mirror of
https://github.com/DS4SD/docling.git
synced 2025-07-31 14:34:40 +00:00
add picture description factory
Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>
This commit is contained in:
parent
8235a246c8
commit
3b9b675a4f
@ -2,6 +2,9 @@ import logging
|
||||
from functools import lru_cache
|
||||
|
||||
from docling.models.factories.ocr_factory import OcrFactory
|
||||
from docling.models.factories.picture_description_factory import (
|
||||
PictureDescriptionFactory,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -10,5 +13,13 @@ logger = logging.getLogger(__name__)
|
||||
def get_ocr_factory():
|
||||
factory = OcrFactory()
|
||||
factory.load_from_plugins()
|
||||
# logger.info("Registered ocr engines: %r", factory.registered_kind)
|
||||
logger.info("Registered ocr engines: %r", factory.registered_kind)
|
||||
return factory
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_picture_description_factory():
|
||||
factory = PictureDescriptionFactory()
|
||||
factory.load_from_plugins()
|
||||
logger.info("Registered picture descriptions: %r", factory.registered_kind)
|
||||
return factory
|
||||
|
@ -1,6 +1,5 @@
|
||||
import logging
|
||||
|
||||
from docling.datamodel.pipeline_options import OcrOptions
|
||||
from docling.models.base_ocr_model import BaseOcrModel
|
||||
from docling.models.factories.base_factory import BaseFactory
|
||||
|
||||
|
11
docling/models/factories/picture_description_factory.py
Normal file
11
docling/models/factories/picture_description_factory.py
Normal file
@ -0,0 +1,11 @@
|
||||
import logging
|
||||
|
||||
from docling.models.factories.base_factory import BaseFactory
|
||||
from docling.models.picture_description_base_model import PictureDescriptionBaseModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PictureDescriptionFactory(BaseFactory[PictureDescriptionBaseModel]):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__("picture_description", *args, **kwargs)
|
@ -1,13 +1,18 @@
|
||||
import base64
|
||||
import io
|
||||
import logging
|
||||
from typing import Iterable, List, Optional
|
||||
from pathlib import Path
|
||||
from typing import Iterable, List, Optional, Type, Union
|
||||
|
||||
import requests
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from docling.datamodel.pipeline_options import PictureDescriptionApiOptions
|
||||
from docling.datamodel.pipeline_options import (
|
||||
AcceleratorOptions,
|
||||
PictureDescriptionApiOptions,
|
||||
PictureDescriptionBaseOptions,
|
||||
)
|
||||
from docling.exceptions import OperationNotAllowed
|
||||
from docling.models.picture_description_base_model import PictureDescriptionBaseModel
|
||||
|
||||
@ -46,13 +51,25 @@ class ApiResponse(BaseModel):
|
||||
class PictureDescriptionApiModel(PictureDescriptionBaseModel):
|
||||
# elements_batch_size = 4
|
||||
|
||||
@classmethod
|
||||
def get_options_type(cls) -> Type[PictureDescriptionBaseOptions]:
|
||||
return PictureDescriptionApiOptions
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
enabled: bool,
|
||||
enable_remote_services: bool,
|
||||
artifacts_path: Optional[Union[Path, str]],
|
||||
options: PictureDescriptionApiOptions,
|
||||
accelerator_options: AcceleratorOptions,
|
||||
):
|
||||
super().__init__(enabled=enabled, options=options)
|
||||
super().__init__(
|
||||
enabled=enabled,
|
||||
enable_remote_services=enable_remote_services,
|
||||
artifacts_path=artifacts_path,
|
||||
options=options,
|
||||
accelerator_options=accelerator_options,
|
||||
)
|
||||
self.options: PictureDescriptionApiOptions
|
||||
|
||||
if self.enabled:
|
||||
|
@ -1,6 +1,7 @@
|
||||
import logging
|
||||
from abc import abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any, Iterable, List, Optional, Union
|
||||
from typing import Any, Iterable, List, Optional, Type, Union
|
||||
|
||||
from docling_core.types.doc import (
|
||||
DoclingDocument,
|
||||
@ -13,20 +14,29 @@ from docling_core.types.doc.document import ( # TODO: move import to docling_co
|
||||
)
|
||||
from PIL import Image
|
||||
|
||||
from docling.datamodel.pipeline_options import PictureDescriptionBaseOptions
|
||||
from docling.datamodel.pipeline_options import (
|
||||
AcceleratorOptions,
|
||||
PictureDescriptionBaseOptions,
|
||||
)
|
||||
from docling.models.base_model import (
|
||||
BaseItemAndImageEnrichmentModel,
|
||||
BaseModelWithOptions,
|
||||
ItemAndImageEnrichmentElement,
|
||||
)
|
||||
|
||||
|
||||
class PictureDescriptionBaseModel(BaseItemAndImageEnrichmentModel):
|
||||
class PictureDescriptionBaseModel(
|
||||
BaseItemAndImageEnrichmentModel, BaseModelWithOptions
|
||||
):
|
||||
images_scale: float = 2.0
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
enabled: bool,
|
||||
enable_remote_services: bool,
|
||||
artifacts_path: Optional[Union[Path, str]],
|
||||
options: PictureDescriptionBaseOptions,
|
||||
accelerator_options: AcceleratorOptions,
|
||||
):
|
||||
self.enabled = enabled
|
||||
self.options = options
|
||||
@ -62,3 +72,8 @@ class PictureDescriptionBaseModel(BaseItemAndImageEnrichmentModel):
|
||||
PictureDescriptionData(text=output, provenance=self.provenance)
|
||||
)
|
||||
yield item
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def get_options_type(cls) -> Type[PictureDescriptionBaseOptions]:
|
||||
pass
|
||||
|
@ -1,10 +1,11 @@
|
||||
from pathlib import Path
|
||||
from typing import Iterable, Optional, Union
|
||||
from typing import Iterable, Optional, Type, Union
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from docling.datamodel.pipeline_options import (
|
||||
AcceleratorOptions,
|
||||
PictureDescriptionBaseOptions,
|
||||
PictureDescriptionVlmOptions,
|
||||
)
|
||||
from docling.models.picture_description_base_model import PictureDescriptionBaseModel
|
||||
@ -13,14 +14,25 @@ from docling.utils.accelerator_utils import decide_device
|
||||
|
||||
class PictureDescriptionVlmModel(PictureDescriptionBaseModel):
|
||||
|
||||
@classmethod
|
||||
def get_options_type(cls) -> Type[PictureDescriptionBaseOptions]:
|
||||
return PictureDescriptionVlmOptions
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
enabled: bool,
|
||||
enable_remote_services: bool,
|
||||
artifacts_path: Optional[Union[Path, str]],
|
||||
options: PictureDescriptionVlmOptions,
|
||||
accelerator_options: AcceleratorOptions,
|
||||
):
|
||||
super().__init__(enabled=enabled, options=options)
|
||||
super().__init__(
|
||||
enabled=enabled,
|
||||
enable_remote_services=enable_remote_services,
|
||||
artifacts_path=artifacts_path,
|
||||
options=options,
|
||||
accelerator_options=accelerator_options,
|
||||
)
|
||||
self.options: PictureDescriptionVlmOptions
|
||||
|
||||
if self.enabled:
|
||||
|
@ -1,7 +1,7 @@
|
||||
import sys
|
||||
|
||||
from docling.models.easyocr_model import EasyOcrModel
|
||||
from docling.models.ocr_mac_model import OcrMacModel
|
||||
from docling.models.picture_description_api_model import PictureDescriptionApiModel
|
||||
from docling.models.picture_description_vlm_model import PictureDescriptionVlmModel
|
||||
from docling.models.rapid_ocr_model import RapidOcrModel
|
||||
from docling.models.tesseract_ocr_cli_model import TesseractOcrCliModel
|
||||
from docling.models.tesseract_ocr_model import TesseractOcrModel
|
||||
@ -17,3 +17,12 @@ def ocr_engines():
|
||||
TesseractOcrCliModel,
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
def picture_description():
|
||||
return {
|
||||
"picture_description": [
|
||||
PictureDescriptionVlmModel,
|
||||
PictureDescriptionApiModel,
|
||||
]
|
||||
}
|
@ -10,16 +10,7 @@ from docling.backend.abstract_backend import AbstractDocumentBackend
|
||||
from docling.backend.pdf_backend import PdfDocumentBackend
|
||||
from docling.datamodel.base_models import AssembledUnit, Page
|
||||
from docling.datamodel.document import ConversionResult
|
||||
from docling.datamodel.pipeline_options import (
|
||||
EasyOcrOptions,
|
||||
OcrMacOptions,
|
||||
PdfPipelineOptions,
|
||||
PictureDescriptionApiOptions,
|
||||
PictureDescriptionVlmOptions,
|
||||
RapidOcrOptions,
|
||||
TesseractCliOcrOptions,
|
||||
TesseractOcrOptions,
|
||||
)
|
||||
from docling.datamodel.pipeline_options import PdfPipelineOptions
|
||||
from docling.datamodel.settings import settings
|
||||
from docling.models.base_ocr_model import BaseOcrModel
|
||||
from docling.models.code_formula_model import CodeFormulaModel, CodeFormulaModelOptions
|
||||
@ -27,23 +18,16 @@ from docling.models.document_picture_classifier import (
|
||||
DocumentPictureClassifier,
|
||||
DocumentPictureClassifierOptions,
|
||||
)
|
||||
from docling.models.easyocr_model import EasyOcrModel
|
||||
from docling.models.factories import get_ocr_factory
|
||||
from docling.models.factories import get_ocr_factory, get_picture_description_factory
|
||||
from docling.models.layout_model import LayoutModel
|
||||
from docling.models.ocr_mac_model import OcrMacModel
|
||||
from docling.models.page_assemble_model import PageAssembleModel, PageAssembleOptions
|
||||
from docling.models.page_preprocessing_model import (
|
||||
PagePreprocessingModel,
|
||||
PagePreprocessingOptions,
|
||||
)
|
||||
from docling.models.picture_description_api_model import PictureDescriptionApiModel
|
||||
from docling.models.picture_description_base_model import PictureDescriptionBaseModel
|
||||
from docling.models.picture_description_vlm_model import PictureDescriptionVlmModel
|
||||
from docling.models.rapid_ocr_model import RapidOcrModel
|
||||
from docling.models.readingorder_model import ReadingOrderModel, ReadingOrderOptions
|
||||
from docling.models.table_structure_model import TableStructureModel
|
||||
from docling.models.tesseract_ocr_cli_model import TesseractOcrCliModel
|
||||
from docling.models.tesseract_ocr_model import TesseractOcrModel
|
||||
from docling.pipeline.base_pipeline import PaginatedPipeline
|
||||
from docling.utils.model_downloader import download_models
|
||||
from docling.utils.profiling import ProfilingScope, TimeRecorder
|
||||
@ -162,10 +146,8 @@ class StandardPdfPipeline(PaginatedPipeline):
|
||||
return output_dir
|
||||
|
||||
def get_ocr_model(self, artifacts_path: Optional[Path] = None) -> BaseOcrModel:
|
||||
ocr_factory = get_ocr_factory()
|
||||
ocr_engine_cls = ocr_factory.get_class(
|
||||
options=self.pipeline_options.ocr_options
|
||||
)
|
||||
factory = get_ocr_factory()
|
||||
ocr_engine_cls = factory.get_class(options=self.pipeline_options.ocr_options)
|
||||
|
||||
return ocr_engine_cls(
|
||||
enabled=self.pipeline_options.do_ocr,
|
||||
@ -177,26 +159,19 @@ class StandardPdfPipeline(PaginatedPipeline):
|
||||
def get_picture_description_model(
|
||||
self, artifacts_path: Optional[Path] = None
|
||||
) -> Optional[PictureDescriptionBaseModel]:
|
||||
if isinstance(
|
||||
self.pipeline_options.picture_description_options,
|
||||
PictureDescriptionApiOptions,
|
||||
):
|
||||
return PictureDescriptionApiModel(
|
||||
enabled=self.pipeline_options.do_picture_description,
|
||||
enable_remote_services=self.pipeline_options.enable_remote_services,
|
||||
options=self.pipeline_options.picture_description_options,
|
||||
)
|
||||
elif isinstance(
|
||||
self.pipeline_options.picture_description_options,
|
||||
PictureDescriptionVlmOptions,
|
||||
):
|
||||
return PictureDescriptionVlmModel(
|
||||
enabled=self.pipeline_options.do_picture_description,
|
||||
artifacts_path=artifacts_path,
|
||||
options=self.pipeline_options.picture_description_options,
|
||||
accelerator_options=self.pipeline_options.accelerator_options,
|
||||
)
|
||||
return None
|
||||
factory = get_picture_description_factory()
|
||||
|
||||
options_cls = factory.get_class(
|
||||
options=self.pipeline_options.picture_description_options
|
||||
)
|
||||
|
||||
return options_cls(
|
||||
enabled=self.pipeline_options.do_ocr,
|
||||
enable_remote_services=self.pipeline_options.enable_remote_services,
|
||||
artifacts_path=artifacts_path,
|
||||
options=self.pipeline_options.picture_description_options,
|
||||
accelerator_options=self.pipeline_options.accelerator_options,
|
||||
)
|
||||
|
||||
def initialize_page(self, conv_res: ConversionResult, page: Page) -> Page:
|
||||
with TimeRecorder(conv_res, "page_init"):
|
||||
|
@ -133,7 +133,7 @@ docling = "docling.cli.main:app"
|
||||
docling-tools = "docling.cli.tools:app"
|
||||
|
||||
[tool.poetry.plugins."docling"]
|
||||
"docling_ocr_engines" = "docling.models.plugins.ocr_engines"
|
||||
"docling_defaults" = "docling.models.plugins.defaults"
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
|
Loading…
Reference in New Issue
Block a user