mirror of
https://github.com/DS4SD/docling.git
synced 2025-08-02 07:22:14 +00:00
figure classifier
Signed-off-by: Matteo Omenetti <omenetti.matteo@gmail.com>
This commit is contained in:
parent
3213b247ad
commit
8ecb810bb5
@ -221,6 +221,7 @@ class PdfPipelineOptions(PipelineOptions):
|
||||
do_ocr: bool = True # True: perform OCR, replace programmatic PDF text
|
||||
do_code_enrichment: bool = False # True: perform code OCR
|
||||
do_formula_enrichment: bool = False # True: perform formula OCR, return Latex code
|
||||
do_picture_classification: bool = False # True: classify pictures in documents
|
||||
|
||||
table_structure_options: TableStructureOptions = TableStructureOptions()
|
||||
ocr_options: Union[
|
||||
|
187
docling/models/document_picture_classifier.py
Normal file
187
docling/models/document_picture_classifier.py
Normal file
@ -0,0 +1,187 @@
|
||||
from pathlib import Path
|
||||
from typing import Iterable, List, Literal, Optional, Tuple, Union
|
||||
|
||||
from docling_core.types.doc import (
|
||||
DoclingDocument,
|
||||
NodeItem,
|
||||
PictureClassificationClass,
|
||||
PictureClassificationData,
|
||||
PictureItem,
|
||||
)
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel
|
||||
|
||||
from docling.datamodel.pipeline_options import AcceleratorOptions
|
||||
from docling.models.base_model import BaseEnrichmentModel
|
||||
from docling.utils.accelerator_utils import decide_device
|
||||
|
||||
|
||||
class DocumentPictureClassifierOptions(BaseModel):
|
||||
"""
|
||||
Options for configuring the DocumentPictureClassifier.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
kind : Literal["document_picture_classifier"]
|
||||
Identifier for the type of classifier.
|
||||
"""
|
||||
|
||||
kind: Literal["document_picture_classifier"] = "document_picture_classifier"
|
||||
|
||||
|
||||
class DocumentPictureClassifier(BaseEnrichmentModel):
|
||||
"""
|
||||
A model for classifying pictures in documents.
|
||||
|
||||
This class enriches document pictures with predicted classifications
|
||||
based on a predefined set of classes.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
enabled : bool
|
||||
Whether the classifier is enabled for use.
|
||||
options : DocumentPictureClassifierOptions
|
||||
Configuration options for the classifier.
|
||||
document_picture_classifier : DocumentPictureClassifierPredictor
|
||||
The underlying prediction model, loaded if the classifier is enabled.
|
||||
|
||||
Methods
|
||||
-------
|
||||
__init__(enabled, artifacts_path, options, accelerator_options)
|
||||
Initializes the classifier with specified configurations.
|
||||
is_processable(doc, element)
|
||||
Checks if the given element can be processed by the classifier.
|
||||
__call__(doc, element_batch)
|
||||
Processes a batch of elements and adds classification annotations.
|
||||
"""
|
||||
|
||||
images_scale = 2
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
enabled: bool,
|
||||
artifacts_path: Optional[Union[Path, str]],
|
||||
options: DocumentPictureClassifierOptions,
|
||||
accelerator_options: AcceleratorOptions,
|
||||
):
|
||||
"""
|
||||
Initializes the DocumentPictureClassifier.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
enabled : bool
|
||||
Indicates whether the classifier is enabled.
|
||||
artifacts_path : Optional[Union[Path, str]],
|
||||
Path to the directory containing model artifacts.
|
||||
options : DocumentPictureClassifierOptions
|
||||
Configuration options for the classifier.
|
||||
accelerator_options : AcceleratorOptions
|
||||
Options for configuring the device and parallelism.
|
||||
"""
|
||||
self.enabled = enabled
|
||||
self.options = options
|
||||
|
||||
if self.enabled:
|
||||
device = decide_device(accelerator_options.device)
|
||||
from docling_ibm_models.document_figure_classifier_model.document_figure_classifier_predictor import (
|
||||
DocumentFigureClassifierPredictor,
|
||||
)
|
||||
|
||||
if artifacts_path is None:
|
||||
artifacts_path = self.download_models_hf()
|
||||
else:
|
||||
artifacts_path = Path(artifacts_path)
|
||||
|
||||
self.document_picture_classifier = DocumentFigureClassifierPredictor(
|
||||
artifacts_path=artifacts_path,
|
||||
device=device,
|
||||
num_threads=accelerator_options.num_threads,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def download_models_hf(
|
||||
local_dir: Optional[Path] = None, force: bool = False
|
||||
) -> Path:
|
||||
from huggingface_hub import snapshot_download
|
||||
from huggingface_hub.utils import disable_progress_bars
|
||||
|
||||
disable_progress_bars()
|
||||
download_path = snapshot_download(
|
||||
repo_id="ds4sd/DocumentFigureClassifier",
|
||||
force_download=force,
|
||||
local_dir=local_dir,
|
||||
revision="v1.0.0",
|
||||
)
|
||||
|
||||
return Path(download_path)
|
||||
|
||||
def is_processable(self, doc: DoclingDocument, element: NodeItem) -> bool:
|
||||
"""
|
||||
Determines if the given element can be processed by the classifier.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
doc : DoclingDocument
|
||||
The document containing the element.
|
||||
element : NodeItem
|
||||
The element to be checked.
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if the element is a PictureItem and processing is enabled; False otherwise.
|
||||
"""
|
||||
return self.enabled and isinstance(element, PictureItem)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
doc: DoclingDocument,
|
||||
element_batch: Iterable[NodeItem],
|
||||
) -> Iterable[NodeItem]:
|
||||
"""
|
||||
Processes a batch of elements and enriches them with classification predictions.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
doc : DoclingDocument
|
||||
The document containing the elements to be processed.
|
||||
element_batch : Iterable[NodeItem]
|
||||
A batch of pictures to classify.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Iterable[NodeItem]
|
||||
An iterable of NodeItem objects after processing. The field
|
||||
'data.classification' is added containing the classification for each picture.
|
||||
"""
|
||||
if not self.enabled:
|
||||
for element in element_batch:
|
||||
yield element
|
||||
return
|
||||
|
||||
images: List[Image.Image] = []
|
||||
elements: List[PictureItem] = []
|
||||
for el in element_batch:
|
||||
assert isinstance(el, PictureItem)
|
||||
elements.append(el)
|
||||
img = el.get_image(doc)
|
||||
assert img is not None
|
||||
images.append(img)
|
||||
|
||||
outputs = self.document_picture_classifier.predict(images)
|
||||
|
||||
for element, output in zip(elements, outputs):
|
||||
element.annotations.append(
|
||||
PictureClassificationData(
|
||||
provenance="DocumentPictureClassifier",
|
||||
predicted_classes=[
|
||||
PictureClassificationClass(
|
||||
class_name=pred[0],
|
||||
confidence=pred[1],
|
||||
)
|
||||
for pred in output
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
yield element
|
@ -19,6 +19,10 @@ from docling.datamodel.pipeline_options import (
|
||||
)
|
||||
from docling.models.base_ocr_model import BaseOcrModel
|
||||
from docling.models.code_formula_model import CodeFormulaModel, CodeFormulaModelOptions
|
||||
from docling.models.document_picture_classifier import (
|
||||
DocumentPictureClassifier,
|
||||
DocumentPictureClassifierOptions,
|
||||
)
|
||||
from docling.models.ds_glm_model import GlmModel, GlmOptions
|
||||
from docling.models.easyocr_model import EasyOcrModel
|
||||
from docling.models.layout_model import LayoutModel
|
||||
@ -104,6 +108,13 @@ class StandardPdfPipeline(PaginatedPipeline):
|
||||
),
|
||||
accelerator_options=pipeline_options.accelerator_options,
|
||||
),
|
||||
# Document Picture Classifier
|
||||
DocumentPictureClassifier(
|
||||
enabled=pipeline_options.do_picture_classification,
|
||||
artifacts_path=pipeline_options.artifacts_path,
|
||||
options=DocumentPictureClassifierOptions(),
|
||||
accelerator_options=pipeline_options.accelerator_options,
|
||||
),
|
||||
]
|
||||
|
||||
if (
|
||||
|
Loading…
Reference in New Issue
Block a user