mirror of
https://github.com/DS4SD/docling.git
synced 2025-08-02 15:32:30 +00:00
figure classifier
Signed-off-by: Matteo Omenetti <omenetti.matteo@gmail.com>
This commit is contained in:
parent
3213b247ad
commit
8ecb810bb5
@ -221,6 +221,7 @@ class PdfPipelineOptions(PipelineOptions):
|
|||||||
do_ocr: bool = True # True: perform OCR, replace programmatic PDF text
|
do_ocr: bool = True # True: perform OCR, replace programmatic PDF text
|
||||||
do_code_enrichment: bool = False # True: perform code OCR
|
do_code_enrichment: bool = False # True: perform code OCR
|
||||||
do_formula_enrichment: bool = False # True: perform formula OCR, return Latex code
|
do_formula_enrichment: bool = False # True: perform formula OCR, return Latex code
|
||||||
|
do_picture_classification: bool = False # True: classify pictures in documents
|
||||||
|
|
||||||
table_structure_options: TableStructureOptions = TableStructureOptions()
|
table_structure_options: TableStructureOptions = TableStructureOptions()
|
||||||
ocr_options: Union[
|
ocr_options: Union[
|
||||||
|
187
docling/models/document_picture_classifier.py
Normal file
187
docling/models/document_picture_classifier.py
Normal file
@ -0,0 +1,187 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
from typing import Iterable, List, Literal, Optional, Tuple, Union
|
||||||
|
|
||||||
|
from docling_core.types.doc import (
|
||||||
|
DoclingDocument,
|
||||||
|
NodeItem,
|
||||||
|
PictureClassificationClass,
|
||||||
|
PictureClassificationData,
|
||||||
|
PictureItem,
|
||||||
|
)
|
||||||
|
from PIL import Image
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from docling.datamodel.pipeline_options import AcceleratorOptions
|
||||||
|
from docling.models.base_model import BaseEnrichmentModel
|
||||||
|
from docling.utils.accelerator_utils import decide_device
|
||||||
|
|
||||||
|
|
||||||
|
class DocumentPictureClassifierOptions(BaseModel):
|
||||||
|
"""
|
||||||
|
Options for configuring the DocumentPictureClassifier.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
kind : Literal["document_picture_classifier"]
|
||||||
|
Identifier for the type of classifier.
|
||||||
|
"""
|
||||||
|
|
||||||
|
kind: Literal["document_picture_classifier"] = "document_picture_classifier"
|
||||||
|
|
||||||
|
|
||||||
|
class DocumentPictureClassifier(BaseEnrichmentModel):
|
||||||
|
"""
|
||||||
|
A model for classifying pictures in documents.
|
||||||
|
|
||||||
|
This class enriches document pictures with predicted classifications
|
||||||
|
based on a predefined set of classes.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
enabled : bool
|
||||||
|
Whether the classifier is enabled for use.
|
||||||
|
options : DocumentPictureClassifierOptions
|
||||||
|
Configuration options for the classifier.
|
||||||
|
document_picture_classifier : DocumentPictureClassifierPredictor
|
||||||
|
The underlying prediction model, loaded if the classifier is enabled.
|
||||||
|
|
||||||
|
Methods
|
||||||
|
-------
|
||||||
|
__init__(enabled, artifacts_path, options, accelerator_options)
|
||||||
|
Initializes the classifier with specified configurations.
|
||||||
|
is_processable(doc, element)
|
||||||
|
Checks if the given element can be processed by the classifier.
|
||||||
|
__call__(doc, element_batch)
|
||||||
|
Processes a batch of elements and adds classification annotations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
images_scale = 2
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
enabled: bool,
|
||||||
|
artifacts_path: Optional[Union[Path, str]],
|
||||||
|
options: DocumentPictureClassifierOptions,
|
||||||
|
accelerator_options: AcceleratorOptions,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initializes the DocumentPictureClassifier.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
enabled : bool
|
||||||
|
Indicates whether the classifier is enabled.
|
||||||
|
artifacts_path : Optional[Union[Path, str]],
|
||||||
|
Path to the directory containing model artifacts.
|
||||||
|
options : DocumentPictureClassifierOptions
|
||||||
|
Configuration options for the classifier.
|
||||||
|
accelerator_options : AcceleratorOptions
|
||||||
|
Options for configuring the device and parallelism.
|
||||||
|
"""
|
||||||
|
self.enabled = enabled
|
||||||
|
self.options = options
|
||||||
|
|
||||||
|
if self.enabled:
|
||||||
|
device = decide_device(accelerator_options.device)
|
||||||
|
from docling_ibm_models.document_figure_classifier_model.document_figure_classifier_predictor import (
|
||||||
|
DocumentFigureClassifierPredictor,
|
||||||
|
)
|
||||||
|
|
||||||
|
if artifacts_path is None:
|
||||||
|
artifacts_path = self.download_models_hf()
|
||||||
|
else:
|
||||||
|
artifacts_path = Path(artifacts_path)
|
||||||
|
|
||||||
|
self.document_picture_classifier = DocumentFigureClassifierPredictor(
|
||||||
|
artifacts_path=artifacts_path,
|
||||||
|
device=device,
|
||||||
|
num_threads=accelerator_options.num_threads,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def download_models_hf(
|
||||||
|
local_dir: Optional[Path] = None, force: bool = False
|
||||||
|
) -> Path:
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
from huggingface_hub.utils import disable_progress_bars
|
||||||
|
|
||||||
|
disable_progress_bars()
|
||||||
|
download_path = snapshot_download(
|
||||||
|
repo_id="ds4sd/DocumentFigureClassifier",
|
||||||
|
force_download=force,
|
||||||
|
local_dir=local_dir,
|
||||||
|
revision="v1.0.0",
|
||||||
|
)
|
||||||
|
|
||||||
|
return Path(download_path)
|
||||||
|
|
||||||
|
def is_processable(self, doc: DoclingDocument, element: NodeItem) -> bool:
|
||||||
|
"""
|
||||||
|
Determines if the given element can be processed by the classifier.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
doc : DoclingDocument
|
||||||
|
The document containing the element.
|
||||||
|
element : NodeItem
|
||||||
|
The element to be checked.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
bool
|
||||||
|
True if the element is a PictureItem and processing is enabled; False otherwise.
|
||||||
|
"""
|
||||||
|
return self.enabled and isinstance(element, PictureItem)
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
doc: DoclingDocument,
|
||||||
|
element_batch: Iterable[NodeItem],
|
||||||
|
) -> Iterable[NodeItem]:
|
||||||
|
"""
|
||||||
|
Processes a batch of elements and enriches them with classification predictions.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
doc : DoclingDocument
|
||||||
|
The document containing the elements to be processed.
|
||||||
|
element_batch : Iterable[NodeItem]
|
||||||
|
A batch of pictures to classify.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Iterable[NodeItem]
|
||||||
|
An iterable of NodeItem objects after processing. The field
|
||||||
|
'data.classification' is added containing the classification for each picture.
|
||||||
|
"""
|
||||||
|
if not self.enabled:
|
||||||
|
for element in element_batch:
|
||||||
|
yield element
|
||||||
|
return
|
||||||
|
|
||||||
|
images: List[Image.Image] = []
|
||||||
|
elements: List[PictureItem] = []
|
||||||
|
for el in element_batch:
|
||||||
|
assert isinstance(el, PictureItem)
|
||||||
|
elements.append(el)
|
||||||
|
img = el.get_image(doc)
|
||||||
|
assert img is not None
|
||||||
|
images.append(img)
|
||||||
|
|
||||||
|
outputs = self.document_picture_classifier.predict(images)
|
||||||
|
|
||||||
|
for element, output in zip(elements, outputs):
|
||||||
|
element.annotations.append(
|
||||||
|
PictureClassificationData(
|
||||||
|
provenance="DocumentPictureClassifier",
|
||||||
|
predicted_classes=[
|
||||||
|
PictureClassificationClass(
|
||||||
|
class_name=pred[0],
|
||||||
|
confidence=pred[1],
|
||||||
|
)
|
||||||
|
for pred in output
|
||||||
|
],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
yield element
|
@ -19,6 +19,10 @@ from docling.datamodel.pipeline_options import (
|
|||||||
)
|
)
|
||||||
from docling.models.base_ocr_model import BaseOcrModel
|
from docling.models.base_ocr_model import BaseOcrModel
|
||||||
from docling.models.code_formula_model import CodeFormulaModel, CodeFormulaModelOptions
|
from docling.models.code_formula_model import CodeFormulaModel, CodeFormulaModelOptions
|
||||||
|
from docling.models.document_picture_classifier import (
|
||||||
|
DocumentPictureClassifier,
|
||||||
|
DocumentPictureClassifierOptions,
|
||||||
|
)
|
||||||
from docling.models.ds_glm_model import GlmModel, GlmOptions
|
from docling.models.ds_glm_model import GlmModel, GlmOptions
|
||||||
from docling.models.easyocr_model import EasyOcrModel
|
from docling.models.easyocr_model import EasyOcrModel
|
||||||
from docling.models.layout_model import LayoutModel
|
from docling.models.layout_model import LayoutModel
|
||||||
@ -104,6 +108,13 @@ class StandardPdfPipeline(PaginatedPipeline):
|
|||||||
),
|
),
|
||||||
accelerator_options=pipeline_options.accelerator_options,
|
accelerator_options=pipeline_options.accelerator_options,
|
||||||
),
|
),
|
||||||
|
# Document Picture Classifier
|
||||||
|
DocumentPictureClassifier(
|
||||||
|
enabled=pipeline_options.do_picture_classification,
|
||||||
|
artifacts_path=pipeline_options.artifacts_path,
|
||||||
|
options=DocumentPictureClassifierOptions(),
|
||||||
|
accelerator_options=pipeline_options.accelerator_options,
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
if (
|
if (
|
||||||
|
Loading…
Reference in New Issue
Block a user