diff --git a/docling/models/factories/__init__.py b/docling/models/factories/__init__.py index a1b50970..5d527ed5 100644 --- a/docling/models/factories/__init__.py +++ b/docling/models/factories/__init__.py @@ -2,6 +2,9 @@ import logging from functools import lru_cache from docling.models.factories.ocr_factory import OcrFactory +from docling.models.factories.picture_description_factory import ( + PictureDescriptionFactory, +) logger = logging.getLogger(__name__) @@ -10,5 +13,13 @@ logger = logging.getLogger(__name__) def get_ocr_factory(): factory = OcrFactory() factory.load_from_plugins() - # logger.info("Registered ocr engines: %r", factory.registered_kind) + logger.info("Registered ocr engines: %r", factory.registered_kind) + return factory + + +@lru_cache(maxsize=1) +def get_picture_description_factory(): + factory = PictureDescriptionFactory() + factory.load_from_plugins() + logger.info("Registered picture descriptions: %r", factory.registered_kind) return factory diff --git a/docling/models/factories/ocr_factory.py b/docling/models/factories/ocr_factory.py index 414d004a..1153baaa 100644 --- a/docling/models/factories/ocr_factory.py +++ b/docling/models/factories/ocr_factory.py @@ -1,6 +1,5 @@ import logging -from docling.datamodel.pipeline_options import OcrOptions from docling.models.base_ocr_model import BaseOcrModel from docling.models.factories.base_factory import BaseFactory diff --git a/docling/models/factories/picture_description_factory.py b/docling/models/factories/picture_description_factory.py new file mode 100644 index 00000000..f66d132f --- /dev/null +++ b/docling/models/factories/picture_description_factory.py @@ -0,0 +1,11 @@ +import logging + +from docling.models.factories.base_factory import BaseFactory +from docling.models.picture_description_base_model import PictureDescriptionBaseModel + +logger = logging.getLogger(__name__) + + +class PictureDescriptionFactory(BaseFactory[PictureDescriptionBaseModel]): + def __init__(self, *args, **kwargs): + super().__init__("picture_description", *args, **kwargs) diff --git a/docling/models/picture_description_api_model.py b/docling/models/picture_description_api_model.py index c64f1bfe..6ef8a7fc 100644 --- a/docling/models/picture_description_api_model.py +++ b/docling/models/picture_description_api_model.py @@ -1,13 +1,18 @@ import base64 import io import logging -from typing import Iterable, List, Optional +from pathlib import Path +from typing import Iterable, List, Optional, Type, Union import requests from PIL import Image from pydantic import BaseModel, ConfigDict -from docling.datamodel.pipeline_options import PictureDescriptionApiOptions +from docling.datamodel.pipeline_options import ( + AcceleratorOptions, + PictureDescriptionApiOptions, + PictureDescriptionBaseOptions, +) from docling.exceptions import OperationNotAllowed from docling.models.picture_description_base_model import PictureDescriptionBaseModel @@ -46,13 +51,25 @@ class ApiResponse(BaseModel): class PictureDescriptionApiModel(PictureDescriptionBaseModel): # elements_batch_size = 4 + @classmethod + def get_options_type(cls) -> Type[PictureDescriptionBaseOptions]: + return PictureDescriptionApiOptions + def __init__( self, enabled: bool, enable_remote_services: bool, + artifacts_path: Optional[Union[Path, str]], options: PictureDescriptionApiOptions, + accelerator_options: AcceleratorOptions, ): - super().__init__(enabled=enabled, options=options) + super().__init__( + enabled=enabled, + enable_remote_services=enable_remote_services, + artifacts_path=artifacts_path, + options=options, + accelerator_options=accelerator_options, + ) self.options: PictureDescriptionApiOptions if self.enabled: diff --git a/docling/models/picture_description_base_model.py b/docling/models/picture_description_base_model.py index b653e0e3..734c18ba 100644 --- a/docling/models/picture_description_base_model.py +++ b/docling/models/picture_description_base_model.py @@ -1,6 +1,7 @@ import logging +from abc import abstractmethod from pathlib import Path -from typing import Any, Iterable, List, Optional, Union +from typing import Any, Iterable, List, Optional, Type, Union from docling_core.types.doc import ( DoclingDocument, @@ -13,20 +14,29 @@ from docling_core.types.doc.document import ( # TODO: move import to docling_co ) from PIL import Image -from docling.datamodel.pipeline_options import PictureDescriptionBaseOptions +from docling.datamodel.pipeline_options import ( + AcceleratorOptions, + PictureDescriptionBaseOptions, +) from docling.models.base_model import ( BaseItemAndImageEnrichmentModel, + BaseModelWithOptions, ItemAndImageEnrichmentElement, ) -class PictureDescriptionBaseModel(BaseItemAndImageEnrichmentModel): +class PictureDescriptionBaseModel( + BaseItemAndImageEnrichmentModel, BaseModelWithOptions +): images_scale: float = 2.0 def __init__( self, enabled: bool, + enable_remote_services: bool, + artifacts_path: Optional[Union[Path, str]], options: PictureDescriptionBaseOptions, + accelerator_options: AcceleratorOptions, ): self.enabled = enabled self.options = options @@ -62,3 +72,8 @@ class PictureDescriptionBaseModel(BaseItemAndImageEnrichmentModel): PictureDescriptionData(text=output, provenance=self.provenance) ) yield item + + @classmethod + @abstractmethod + def get_options_type(cls) -> Type[PictureDescriptionBaseOptions]: + pass diff --git a/docling/models/picture_description_vlm_model.py b/docling/models/picture_description_vlm_model.py index 9fa4826d..75005142 100644 --- a/docling/models/picture_description_vlm_model.py +++ b/docling/models/picture_description_vlm_model.py @@ -1,10 +1,11 @@ from pathlib import Path -from typing import Iterable, Optional, Union +from typing import Iterable, Optional, Type, Union from PIL import Image from docling.datamodel.pipeline_options import ( AcceleratorOptions, + PictureDescriptionBaseOptions, PictureDescriptionVlmOptions, ) from docling.models.picture_description_base_model import PictureDescriptionBaseModel @@ -13,14 +14,25 @@ from docling.utils.accelerator_utils import decide_device class PictureDescriptionVlmModel(PictureDescriptionBaseModel): + @classmethod + def get_options_type(cls) -> Type[PictureDescriptionBaseOptions]: + return PictureDescriptionVlmOptions + def __init__( self, enabled: bool, + enable_remote_services: bool, artifacts_path: Optional[Union[Path, str]], options: PictureDescriptionVlmOptions, accelerator_options: AcceleratorOptions, ): - super().__init__(enabled=enabled, options=options) + super().__init__( + enabled=enabled, + enable_remote_services=enable_remote_services, + artifacts_path=artifacts_path, + options=options, + accelerator_options=accelerator_options, + ) self.options: PictureDescriptionVlmOptions if self.enabled: diff --git a/docling/models/plugins/ocr_engines.py b/docling/models/plugins/defaults.py similarity index 60% rename from docling/models/plugins/ocr_engines.py rename to docling/models/plugins/defaults.py index 646d8782..00873579 100644 --- a/docling/models/plugins/ocr_engines.py +++ b/docling/models/plugins/defaults.py @@ -1,7 +1,7 @@ -import sys - from docling.models.easyocr_model import EasyOcrModel from docling.models.ocr_mac_model import OcrMacModel +from docling.models.picture_description_api_model import PictureDescriptionApiModel +from docling.models.picture_description_vlm_model import PictureDescriptionVlmModel from docling.models.rapid_ocr_model import RapidOcrModel from docling.models.tesseract_ocr_cli_model import TesseractOcrCliModel from docling.models.tesseract_ocr_model import TesseractOcrModel @@ -17,3 +17,12 @@ def ocr_engines(): TesseractOcrCliModel, ] } + + +def picture_description(): + return { + "picture_description": [ + PictureDescriptionVlmModel, + PictureDescriptionApiModel, + ] + } diff --git a/docling/pipeline/standard_pdf_pipeline.py b/docling/pipeline/standard_pdf_pipeline.py index 97d3be3f..4fe2b899 100644 --- a/docling/pipeline/standard_pdf_pipeline.py +++ b/docling/pipeline/standard_pdf_pipeline.py @@ -10,16 +10,7 @@ from docling.backend.abstract_backend import AbstractDocumentBackend from docling.backend.pdf_backend import PdfDocumentBackend from docling.datamodel.base_models import AssembledUnit, Page from docling.datamodel.document import ConversionResult -from docling.datamodel.pipeline_options import ( - EasyOcrOptions, - OcrMacOptions, - PdfPipelineOptions, - PictureDescriptionApiOptions, - PictureDescriptionVlmOptions, - RapidOcrOptions, - TesseractCliOcrOptions, - TesseractOcrOptions, -) +from docling.datamodel.pipeline_options import PdfPipelineOptions from docling.datamodel.settings import settings from docling.models.base_ocr_model import BaseOcrModel from docling.models.code_formula_model import CodeFormulaModel, CodeFormulaModelOptions @@ -27,23 +18,16 @@ from docling.models.document_picture_classifier import ( DocumentPictureClassifier, DocumentPictureClassifierOptions, ) -from docling.models.easyocr_model import EasyOcrModel -from docling.models.factories import get_ocr_factory +from docling.models.factories import get_ocr_factory, get_picture_description_factory from docling.models.layout_model import LayoutModel -from docling.models.ocr_mac_model import OcrMacModel from docling.models.page_assemble_model import PageAssembleModel, PageAssembleOptions from docling.models.page_preprocessing_model import ( PagePreprocessingModel, PagePreprocessingOptions, ) -from docling.models.picture_description_api_model import PictureDescriptionApiModel from docling.models.picture_description_base_model import PictureDescriptionBaseModel -from docling.models.picture_description_vlm_model import PictureDescriptionVlmModel -from docling.models.rapid_ocr_model import RapidOcrModel from docling.models.readingorder_model import ReadingOrderModel, ReadingOrderOptions from docling.models.table_structure_model import TableStructureModel -from docling.models.tesseract_ocr_cli_model import TesseractOcrCliModel -from docling.models.tesseract_ocr_model import TesseractOcrModel from docling.pipeline.base_pipeline import PaginatedPipeline from docling.utils.model_downloader import download_models from docling.utils.profiling import ProfilingScope, TimeRecorder @@ -162,10 +146,8 @@ class StandardPdfPipeline(PaginatedPipeline): return output_dir def get_ocr_model(self, artifacts_path: Optional[Path] = None) -> BaseOcrModel: - ocr_factory = get_ocr_factory() - ocr_engine_cls = ocr_factory.get_class( - options=self.pipeline_options.ocr_options - ) + factory = get_ocr_factory() + ocr_engine_cls = factory.get_class(options=self.pipeline_options.ocr_options) return ocr_engine_cls( enabled=self.pipeline_options.do_ocr, @@ -177,26 +159,19 @@ class StandardPdfPipeline(PaginatedPipeline): def get_picture_description_model( self, artifacts_path: Optional[Path] = None ) -> Optional[PictureDescriptionBaseModel]: - if isinstance( - self.pipeline_options.picture_description_options, - PictureDescriptionApiOptions, - ): - return PictureDescriptionApiModel( - enabled=self.pipeline_options.do_picture_description, - enable_remote_services=self.pipeline_options.enable_remote_services, - options=self.pipeline_options.picture_description_options, - ) - elif isinstance( - self.pipeline_options.picture_description_options, - PictureDescriptionVlmOptions, - ): - return PictureDescriptionVlmModel( - enabled=self.pipeline_options.do_picture_description, - artifacts_path=artifacts_path, - options=self.pipeline_options.picture_description_options, - accelerator_options=self.pipeline_options.accelerator_options, - ) - return None + factory = get_picture_description_factory() + + options_cls = factory.get_class( + options=self.pipeline_options.picture_description_options + ) + + return options_cls( + enabled=self.pipeline_options.do_ocr, + enable_remote_services=self.pipeline_options.enable_remote_services, + artifacts_path=artifacts_path, + options=self.pipeline_options.picture_description_options, + accelerator_options=self.pipeline_options.accelerator_options, + ) def initialize_page(self, conv_res: ConversionResult, page: Page) -> Page: with TimeRecorder(conv_res, "page_init"): diff --git a/pyproject.toml b/pyproject.toml index ce0260f8..eb3e1f1c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -133,7 +133,7 @@ docling = "docling.cli.main:app" docling-tools = "docling.cli.tools:app" [tool.poetry.plugins."docling"] -"docling_ocr_engines" = "docling.models.plugins.ocr_engines" +"docling_defaults" = "docling.models.plugins.defaults" [build-system] requires = ["poetry-core"]