From 8ecb810bb52081344f051eb6e9560d2b04e74dbb Mon Sep 17 00:00:00 2001
From: Matteo Omenetti
Date: Fri, 24 Jan 2025 11:35:44 -0500
Subject: [PATCH] figure classifier

Signed-off-by: Matteo Omenetti
---
 docling/datamodel/pipeline_options.py         |   1 +
 docling/models/document_picture_classifier.py | 187 ++++++++++++++++++
 docling/pipeline/standard_pdf_pipeline.py     |  11 ++
 3 files changed, 199 insertions(+)
 create mode 100644 docling/models/document_picture_classifier.py

diff --git a/docling/datamodel/pipeline_options.py b/docling/datamodel/pipeline_options.py
index efdf3b1c..00ab7b41 100644
--- a/docling/datamodel/pipeline_options.py
+++ b/docling/datamodel/pipeline_options.py
@@ -221,6 +221,7 @@ class PdfPipelineOptions(PipelineOptions):
     do_ocr: bool = True  # True: perform OCR, replace programmatic PDF text
     do_code_enrichment: bool = False  # True: perform code OCR
     do_formula_enrichment: bool = False  # True: perform formula OCR, return Latex code
+    do_picture_classification: bool = False  # True: classify pictures in documents
 
     table_structure_options: TableStructureOptions = TableStructureOptions()
     ocr_options: Union[
diff --git a/docling/models/document_picture_classifier.py b/docling/models/document_picture_classifier.py
new file mode 100644
index 00000000..6e2d90b4
--- /dev/null
+++ b/docling/models/document_picture_classifier.py
@@ -0,0 +1,187 @@
+from pathlib import Path
+from typing import Iterable, List, Literal, Optional, Tuple, Union
+
+from docling_core.types.doc import (
+    DoclingDocument,
+    NodeItem,
+    PictureClassificationClass,
+    PictureClassificationData,
+    PictureItem,
+)
+from PIL import Image
+from pydantic import BaseModel
+
+from docling.datamodel.pipeline_options import AcceleratorOptions
+from docling.models.base_model import BaseEnrichmentModel
+from docling.utils.accelerator_utils import decide_device
+
+
+class DocumentPictureClassifierOptions(BaseModel):
+    """
+    Options for configuring the DocumentPictureClassifier.
+
+    Attributes
+    ----------
+    kind : Literal["document_picture_classifier"]
+        Identifier for the type of classifier.
+    """
+
+    kind: Literal["document_picture_classifier"] = "document_picture_classifier"
+
+
+class DocumentPictureClassifier(BaseEnrichmentModel):
+    """
+    A model for classifying pictures in documents.
+
+    This class enriches document pictures with predicted classifications
+    based on a predefined set of classes.
+
+    Attributes
+    ----------
+    enabled : bool
+        Whether the classifier is enabled for use.
+    options : DocumentPictureClassifierOptions
+        Configuration options for the classifier.
+    document_picture_classifier : DocumentFigureClassifierPredictor
+        The underlying prediction model; only set when the classifier is enabled.
+
+    Methods
+    -------
+    __init__(enabled, artifacts_path, options, accelerator_options)
+        Initializes the classifier with specified configurations.
+    is_processable(doc, element)
+        Checks if the given element can be processed by the classifier.
+    __call__(doc, element_batch)
+        Processes a batch of elements and adds classification annotations.
+    """
+
+    images_scale = 2
+
+    def __init__(
+        self,
+        enabled: bool,
+        artifacts_path: Optional[Union[Path, str]],
+        options: DocumentPictureClassifierOptions,
+        accelerator_options: AcceleratorOptions,
+    ):
+        """
+        Initializes the DocumentPictureClassifier.
+
+        Parameters
+        ----------
+        enabled : bool
+            Indicates whether the classifier is enabled.
+        artifacts_path : Optional[Union[Path, str]]
+            Model artifacts directory; if None, download_models_hf() is used.
+        options : DocumentPictureClassifierOptions
+            Configuration options for the classifier.
+        accelerator_options : AcceleratorOptions
+            Options for configuring the device and parallelism.
+        """
+        self.enabled = enabled
+        self.options = options
+
+        if self.enabled:
+            device = decide_device(accelerator_options.device)
+            from docling_ibm_models.document_figure_classifier_model.document_figure_classifier_predictor import (
+                DocumentFigureClassifierPredictor,
+            )
+
+            if artifacts_path is None:
+                artifacts_path = self.download_models_hf()
+            else:
+                artifacts_path = Path(artifacts_path)
+
+            self.document_picture_classifier = DocumentFigureClassifierPredictor(
+                artifacts_path=artifacts_path,
+                device=device,
+                num_threads=accelerator_options.num_threads,
+            )
+
+    @staticmethod
+    def download_models_hf(
+        local_dir: Optional[Path] = None, force: bool = False
+    ) -> Path:
+        from huggingface_hub import snapshot_download
+        from huggingface_hub.utils import disable_progress_bars
+
+        disable_progress_bars()
+        download_path = snapshot_download(
+            repo_id="ds4sd/DocumentFigureClassifier",
+            force_download=force,
+            local_dir=local_dir,
+            revision="v1.0.0",
+        )
+
+        return Path(download_path)
+
+    def is_processable(self, doc: DoclingDocument, element: NodeItem) -> bool:
+        """
+        Determines if the given element can be processed by the classifier.
+
+        Parameters
+        ----------
+        doc : DoclingDocument
+            The document containing the element.
+        element : NodeItem
+            The element to be checked.
+
+        Returns
+        -------
+        bool
+            True if the element is a PictureItem and processing is enabled; False otherwise.
+        """
+        return self.enabled and isinstance(element, PictureItem)
+
+    def __call__(
+        self,
+        doc: DoclingDocument,
+        element_batch: Iterable[NodeItem],
+    ) -> Iterable[NodeItem]:
+        """
+        Processes a batch of elements and enriches them with classification predictions.
+
+        Parameters
+        ----------
+        doc : DoclingDocument
+            The document containing the elements to be processed.
+        element_batch : Iterable[NodeItem]
+            A batch of pictures to classify.
+
+        Returns
+        -------
+        Iterable[NodeItem]
+            The processed elements. When enabled, a PictureClassificationData
+            entry with the predicted classes is appended to each 'annotations'.
+        """
+        if not self.enabled:
+            for element in element_batch:
+                yield element
+            return
+
+        images: List[Image.Image] = []
+        elements: List[PictureItem] = []
+        for el in element_batch:
+            assert isinstance(el, PictureItem)
+            elements.append(el)
+            img = el.get_image(doc)
+            assert img is not None
+            images.append(img)
+
+        outputs = self.document_picture_classifier.predict(images)
+
+        for element, output in zip(elements, outputs):
+            element.annotations.append(
+                PictureClassificationData(
+                    provenance="DocumentPictureClassifier",
+                    predicted_classes=[
+                        PictureClassificationClass(
+                            class_name=pred[0],
+                            confidence=pred[1],
+                        )
+                        for pred in output
+                    ],
+                )
+            )
+
+            yield element
diff --git a/docling/pipeline/standard_pdf_pipeline.py b/docling/pipeline/standard_pdf_pipeline.py
index 97bcc6b6..fe2201d6 100644
--- a/docling/pipeline/standard_pdf_pipeline.py
+++ b/docling/pipeline/standard_pdf_pipeline.py
@@ -19,6 +19,10 @@ from docling.datamodel.pipeline_options import (
 )
 from docling.models.base_ocr_model import BaseOcrModel
 from docling.models.code_formula_model import CodeFormulaModel, CodeFormulaModelOptions
+from docling.models.document_picture_classifier import (
+    DocumentPictureClassifier,
+    DocumentPictureClassifierOptions,
+)
 from docling.models.ds_glm_model import GlmModel, GlmOptions
 from docling.models.easyocr_model import EasyOcrModel
 from docling.models.layout_model import LayoutModel
@@ -104,6 +108,13 @@ class StandardPdfPipeline(PaginatedPipeline):
                 ),
                 accelerator_options=pipeline_options.accelerator_options,
             ),
+            # Document Picture Classifier
+            DocumentPictureClassifier(
+                enabled=pipeline_options.do_picture_classification,
+                artifacts_path=pipeline_options.artifacts_path,
+                options=DocumentPictureClassifierOptions(),
+                accelerator_options=pipeline_options.accelerator_options,
+            ),
         ]
 
         if (