add picture description factory

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>
This commit is contained in:
Michele Dolfi 2025-02-24 11:13:14 +01:00
parent 8235a246c8
commit 3b9b675a4f
9 changed files with 104 additions and 55 deletions

View File

@ -2,6 +2,9 @@ import logging
from functools import lru_cache from functools import lru_cache
from docling.models.factories.ocr_factory import OcrFactory from docling.models.factories.ocr_factory import OcrFactory
from docling.models.factories.picture_description_factory import (
PictureDescriptionFactory,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -10,5 +13,13 @@ logger = logging.getLogger(__name__)
def get_ocr_factory():
    """Build the OCR factory and fill it from the available plugins.

    Returns:
        The ``OcrFactory`` instance after ``load_from_plugins()`` has run.
    """
    ocr_factory = OcrFactory()
    ocr_factory.load_from_plugins()
    # Log what got registered so plugin discovery is visible at runtime.
    logger.info("Registered ocr engines: %r", ocr_factory.registered_kind)
    return ocr_factory
@lru_cache(maxsize=1)
def get_picture_description_factory():
    """Build the picture-description factory and fill it from the plugins.

    The function takes no arguments and is wrapped in ``lru_cache``, so the
    factory is constructed (and the plugins loaded) on the first call only;
    later calls return the same cached instance.

    Returns:
        The ``PictureDescriptionFactory`` instance after
        ``load_from_plugins()`` has run.
    """
    picture_factory = PictureDescriptionFactory()
    picture_factory.load_from_plugins()
    # Log what got registered so plugin discovery is visible at runtime.
    logger.info(
        "Registered picture descriptions: %r", picture_factory.registered_kind
    )
    return picture_factory

View File

@ -1,6 +1,5 @@
import logging import logging
from docling.datamodel.pipeline_options import OcrOptions
from docling.models.base_ocr_model import BaseOcrModel from docling.models.base_ocr_model import BaseOcrModel
from docling.models.factories.base_factory import BaseFactory from docling.models.factories.base_factory import BaseFactory

View File

@ -0,0 +1,11 @@
import logging
from docling.models.factories.base_factory import BaseFactory
from docling.models.picture_description_base_model import PictureDescriptionBaseModel
logger = logging.getLogger(__name__)
class PictureDescriptionFactory(BaseFactory[PictureDescriptionBaseModel]):
    """Factory specialised for ``PictureDescriptionBaseModel`` subclasses."""

    def __init__(self, *args, **kwargs):
        """Initialise the base factory with the fixed ``"picture_description"`` name.

        Any further positional and keyword arguments are forwarded to
        ``BaseFactory.__init__`` unchanged.
        """
        # NOTE(review): presumably this must match the name plugins use to
        # expose their picture-description models - confirm in BaseFactory.
        factory_name = "picture_description"
        super().__init__(factory_name, *args, **kwargs)

View File

@ -1,13 +1,18 @@
import base64 import base64
import io import io
import logging import logging
from typing import Iterable, List, Optional from pathlib import Path
from typing import Iterable, List, Optional, Type, Union
import requests import requests
from PIL import Image from PIL import Image
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from docling.datamodel.pipeline_options import PictureDescriptionApiOptions from docling.datamodel.pipeline_options import (
AcceleratorOptions,
PictureDescriptionApiOptions,
PictureDescriptionBaseOptions,
)
from docling.exceptions import OperationNotAllowed from docling.exceptions import OperationNotAllowed
from docling.models.picture_description_base_model import PictureDescriptionBaseModel from docling.models.picture_description_base_model import PictureDescriptionBaseModel
@ -46,13 +51,25 @@ class ApiResponse(BaseModel):
class PictureDescriptionApiModel(PictureDescriptionBaseModel): class PictureDescriptionApiModel(PictureDescriptionBaseModel):
# elements_batch_size = 4 # elements_batch_size = 4
@classmethod
def get_options_type(cls) -> Type[PictureDescriptionBaseOptions]:
    """Return the options class that this API-backed model accepts."""
    options_type = PictureDescriptionApiOptions
    return options_type
def __init__( def __init__(
self, self,
enabled: bool, enabled: bool,
enable_remote_services: bool, enable_remote_services: bool,
artifacts_path: Optional[Union[Path, str]],
options: PictureDescriptionApiOptions, options: PictureDescriptionApiOptions,
accelerator_options: AcceleratorOptions,
): ):
super().__init__(enabled=enabled, options=options) super().__init__(
enabled=enabled,
enable_remote_services=enable_remote_services,
artifacts_path=artifacts_path,
options=options,
accelerator_options=accelerator_options,
)
self.options: PictureDescriptionApiOptions self.options: PictureDescriptionApiOptions
if self.enabled: if self.enabled:

View File

@ -1,6 +1,7 @@
import logging import logging
from abc import abstractmethod
from pathlib import Path from pathlib import Path
from typing import Any, Iterable, List, Optional, Union from typing import Any, Iterable, List, Optional, Type, Union
from docling_core.types.doc import ( from docling_core.types.doc import (
DoclingDocument, DoclingDocument,
@ -13,20 +14,29 @@ from docling_core.types.doc.document import ( # TODO: move import to docling_co
) )
from PIL import Image from PIL import Image
from docling.datamodel.pipeline_options import PictureDescriptionBaseOptions from docling.datamodel.pipeline_options import (
AcceleratorOptions,
PictureDescriptionBaseOptions,
)
from docling.models.base_model import ( from docling.models.base_model import (
BaseItemAndImageEnrichmentModel, BaseItemAndImageEnrichmentModel,
BaseModelWithOptions,
ItemAndImageEnrichmentElement, ItemAndImageEnrichmentElement,
) )
class PictureDescriptionBaseModel(BaseItemAndImageEnrichmentModel): class PictureDescriptionBaseModel(
BaseItemAndImageEnrichmentModel, BaseModelWithOptions
):
images_scale: float = 2.0 images_scale: float = 2.0
def __init__( def __init__(
self, self,
enabled: bool, enabled: bool,
enable_remote_services: bool,
artifacts_path: Optional[Union[Path, str]],
options: PictureDescriptionBaseOptions, options: PictureDescriptionBaseOptions,
accelerator_options: AcceleratorOptions,
): ):
self.enabled = enabled self.enabled = enabled
self.options = options self.options = options
@ -62,3 +72,8 @@ class PictureDescriptionBaseModel(BaseItemAndImageEnrichmentModel):
PictureDescriptionData(text=output, provenance=self.provenance) PictureDescriptionData(text=output, provenance=self.provenance)
) )
yield item yield item
@classmethod
@abstractmethod
def get_options_type(cls) -> Type[PictureDescriptionBaseOptions]:
    """Return the options class that the concrete model accepts.

    Declared abstract: each concrete picture-description model is expected
    to override this and return its own options type.
    """

View File

@ -1,10 +1,11 @@
from pathlib import Path from pathlib import Path
from typing import Iterable, Optional, Union from typing import Iterable, Optional, Type, Union
from PIL import Image from PIL import Image
from docling.datamodel.pipeline_options import ( from docling.datamodel.pipeline_options import (
AcceleratorOptions, AcceleratorOptions,
PictureDescriptionBaseOptions,
PictureDescriptionVlmOptions, PictureDescriptionVlmOptions,
) )
from docling.models.picture_description_base_model import PictureDescriptionBaseModel from docling.models.picture_description_base_model import PictureDescriptionBaseModel
@ -13,14 +14,25 @@ from docling.utils.accelerator_utils import decide_device
class PictureDescriptionVlmModel(PictureDescriptionBaseModel): class PictureDescriptionVlmModel(PictureDescriptionBaseModel):
@classmethod
def get_options_type(cls) -> Type[PictureDescriptionBaseOptions]:
    """Return the options class that this VLM-backed model accepts."""
    options_type = PictureDescriptionVlmOptions
    return options_type
def __init__( def __init__(
self, self,
enabled: bool, enabled: bool,
enable_remote_services: bool,
artifacts_path: Optional[Union[Path, str]], artifacts_path: Optional[Union[Path, str]],
options: PictureDescriptionVlmOptions, options: PictureDescriptionVlmOptions,
accelerator_options: AcceleratorOptions, accelerator_options: AcceleratorOptions,
): ):
super().__init__(enabled=enabled, options=options) super().__init__(
enabled=enabled,
enable_remote_services=enable_remote_services,
artifacts_path=artifacts_path,
options=options,
accelerator_options=accelerator_options,
)
self.options: PictureDescriptionVlmOptions self.options: PictureDescriptionVlmOptions
if self.enabled: if self.enabled:

View File

@ -1,7 +1,7 @@
import sys
from docling.models.easyocr_model import EasyOcrModel from docling.models.easyocr_model import EasyOcrModel
from docling.models.ocr_mac_model import OcrMacModel from docling.models.ocr_mac_model import OcrMacModel
from docling.models.picture_description_api_model import PictureDescriptionApiModel
from docling.models.picture_description_vlm_model import PictureDescriptionVlmModel
from docling.models.rapid_ocr_model import RapidOcrModel from docling.models.rapid_ocr_model import RapidOcrModel
from docling.models.tesseract_ocr_cli_model import TesseractOcrCliModel from docling.models.tesseract_ocr_cli_model import TesseractOcrCliModel
from docling.models.tesseract_ocr_model import TesseractOcrModel from docling.models.tesseract_ocr_model import TesseractOcrModel
@ -17,3 +17,12 @@ def ocr_engines():
TesseractOcrCliModel, TesseractOcrCliModel,
] ]
} }
def picture_description():
    """Return the built-in picture-description model classes.

    Returns:
        A dict with the single key ``"picture_description"`` mapping to the
        list of model classes shipped with this package.
    """
    builtin_models = [
        PictureDescriptionVlmModel,
        PictureDescriptionApiModel,
    ]
    return {"picture_description": builtin_models}

View File

@ -10,16 +10,7 @@ from docling.backend.abstract_backend import AbstractDocumentBackend
from docling.backend.pdf_backend import PdfDocumentBackend from docling.backend.pdf_backend import PdfDocumentBackend
from docling.datamodel.base_models import AssembledUnit, Page from docling.datamodel.base_models import AssembledUnit, Page
from docling.datamodel.document import ConversionResult from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import ( from docling.datamodel.pipeline_options import PdfPipelineOptions
EasyOcrOptions,
OcrMacOptions,
PdfPipelineOptions,
PictureDescriptionApiOptions,
PictureDescriptionVlmOptions,
RapidOcrOptions,
TesseractCliOcrOptions,
TesseractOcrOptions,
)
from docling.datamodel.settings import settings from docling.datamodel.settings import settings
from docling.models.base_ocr_model import BaseOcrModel from docling.models.base_ocr_model import BaseOcrModel
from docling.models.code_formula_model import CodeFormulaModel, CodeFormulaModelOptions from docling.models.code_formula_model import CodeFormulaModel, CodeFormulaModelOptions
@ -27,23 +18,16 @@ from docling.models.document_picture_classifier import (
DocumentPictureClassifier, DocumentPictureClassifier,
DocumentPictureClassifierOptions, DocumentPictureClassifierOptions,
) )
from docling.models.easyocr_model import EasyOcrModel from docling.models.factories import get_ocr_factory, get_picture_description_factory
from docling.models.factories import get_ocr_factory
from docling.models.layout_model import LayoutModel from docling.models.layout_model import LayoutModel
from docling.models.ocr_mac_model import OcrMacModel
from docling.models.page_assemble_model import PageAssembleModel, PageAssembleOptions from docling.models.page_assemble_model import PageAssembleModel, PageAssembleOptions
from docling.models.page_preprocessing_model import ( from docling.models.page_preprocessing_model import (
PagePreprocessingModel, PagePreprocessingModel,
PagePreprocessingOptions, PagePreprocessingOptions,
) )
from docling.models.picture_description_api_model import PictureDescriptionApiModel
from docling.models.picture_description_base_model import PictureDescriptionBaseModel from docling.models.picture_description_base_model import PictureDescriptionBaseModel
from docling.models.picture_description_vlm_model import PictureDescriptionVlmModel
from docling.models.rapid_ocr_model import RapidOcrModel
from docling.models.readingorder_model import ReadingOrderModel, ReadingOrderOptions from docling.models.readingorder_model import ReadingOrderModel, ReadingOrderOptions
from docling.models.table_structure_model import TableStructureModel from docling.models.table_structure_model import TableStructureModel
from docling.models.tesseract_ocr_cli_model import TesseractOcrCliModel
from docling.models.tesseract_ocr_model import TesseractOcrModel
from docling.pipeline.base_pipeline import PaginatedPipeline from docling.pipeline.base_pipeline import PaginatedPipeline
from docling.utils.model_downloader import download_models from docling.utils.model_downloader import download_models
from docling.utils.profiling import ProfilingScope, TimeRecorder from docling.utils.profiling import ProfilingScope, TimeRecorder
@ -162,10 +146,8 @@ class StandardPdfPipeline(PaginatedPipeline):
return output_dir return output_dir
def get_ocr_model(self, artifacts_path: Optional[Path] = None) -> BaseOcrModel: def get_ocr_model(self, artifacts_path: Optional[Path] = None) -> BaseOcrModel:
ocr_factory = get_ocr_factory() factory = get_ocr_factory()
ocr_engine_cls = ocr_factory.get_class( ocr_engine_cls = factory.get_class(options=self.pipeline_options.ocr_options)
options=self.pipeline_options.ocr_options
)
return ocr_engine_cls( return ocr_engine_cls(
enabled=self.pipeline_options.do_ocr, enabled=self.pipeline_options.do_ocr,
@ -177,26 +159,19 @@ class StandardPdfPipeline(PaginatedPipeline):
def get_picture_description_model(
    self, artifacts_path: Optional[Path] = None
) -> Optional[PictureDescriptionBaseModel]:
    """Instantiate the picture-description model selected by the pipeline options.

    The model class is obtained from the picture-description factory by
    passing ``pipeline_options.picture_description_options`` to
    ``factory.get_class``; it is then constructed with the uniform argument
    set shared by all picture-description models.

    Args:
        artifacts_path: Optional location of local model artifacts,
            forwarded to the model constructor as-is.

    Returns:
        The constructed picture-description model instance.
    """
    factory = get_picture_description_factory()
    # ``get_class`` returns the *model* class matching the given options.
    model_cls = factory.get_class(
        options=self.pipeline_options.picture_description_options
    )

    return model_cls(
        # Gate on the picture-description switch. Using ``do_ocr`` here would
        # enable/disable picture description together with OCR.
        enabled=self.pipeline_options.do_picture_description,
        enable_remote_services=self.pipeline_options.enable_remote_services,
        artifacts_path=artifacts_path,
        options=self.pipeline_options.picture_description_options,
        accelerator_options=self.pipeline_options.accelerator_options,
    )
def initialize_page(self, conv_res: ConversionResult, page: Page) -> Page: def initialize_page(self, conv_res: ConversionResult, page: Page) -> Page:
with TimeRecorder(conv_res, "page_init"): with TimeRecorder(conv_res, "page_init"):

View File

@ -133,7 +133,7 @@ docling = "docling.cli.main:app"
docling-tools = "docling.cli.tools:app" docling-tools = "docling.cli.tools:app"
[tool.poetry.plugins."docling"] [tool.poetry.plugins."docling"]
"docling_ocr_engines" = "docling.models.plugins.ocr_engines" "docling_defaults" = "docling.models.plugins.defaults"
[build-system] [build-system]
requires = ["poetry-core"] requires = ["poetry-core"]