diff --git a/docling/cli/main.py b/docling/cli/main.py index 456c68a5..f752b8cd 100644 --- a/docling/cli/main.py +++ b/docling/cli/main.py @@ -9,6 +9,7 @@ import warnings from pathlib import Path from typing import Annotated, Dict, Iterable, List, Optional, Type +import rich.table import typer from docling_core.types.doc import ImageRefMode from docling_core.utils.file import resolve_source_to_path @@ -29,6 +30,7 @@ from docling.datamodel.pipeline_options import ( AcceleratorDevice, AcceleratorOptions, EasyOcrOptions, + OcrOptions, PdfBackend, PdfPipelineOptions, TableFormerMode, @@ -43,10 +45,11 @@ warnings.filterwarnings(action="ignore", category=FutureWarning, module="easyocr _log = logging.getLogger(__name__) from rich.console import Console +console = Console() err_console = Console(stderr=True) -ocr_factory = get_ocr_factory() -ocr_engines_enum = ocr_factory.get_enum() +ocr_factory_internal = get_ocr_factory(allow_external_plugins=False) +ocr_engines_enum_internal = ocr_factory_internal.get_enum() app = typer.Typer( name="Docling", @@ -74,6 +77,24 @@ def version_callback(value: bool): raise typer.Exit() +def show_external_plugins_callback(value: bool): + if value: + ocr_factory_all = get_ocr_factory(allow_external_plugins=True) + table = rich.table.Table(title="Available OCR engines") + table.add_column("Name", justify="right") + table.add_column("Plugin") + table.add_column("Package") + for meta in ocr_factory_all.registered_meta.values(): + if not meta.module.startswith("docling."): + table.add_row( + f"[bold]{meta.kind}[/bold]", + meta.plugin_name, + meta.module.split(".")[0], + ) + rich.print(table) + raise typer.Exit() + + def export_documents( conv_results: Iterable[ConversionResult], output_dir: Path, @@ -191,10 +212,16 @@ def convert( help="Replace any existing text with OCR generated text over the full content.", ), ] = False, - ocr_engine: Annotated[ # type: ignore - ocr_engines_enum, - # ocr_factory.get_registered_enum(), - typer.Option(..., help="The OCR engine to use."), + ocr_engine: Annotated[ + str, + typer.Option( + ..., + help=( + f"The OCR engine to use. When --allow-external-plugins is *not* set, the available values are: " + f"{', '.join((o.value for o in ocr_engines_enum_internal))}. " + f"Use the option --show-external-plugins to see the options allowed with external plugins." + ), + ), ] = EasyOcrOptions.kind, ocr_lang: Annotated[ Optional[str], @@ -239,6 +266,21 @@ def convert( ..., help="Must be enabled when using models connecting to remote services." ), ] = False, + allow_external_plugins: Annotated[ + bool, + typer.Option( + ..., help="Must be enabled for loading modules from third-party plugins." + ), + ] = False, + show_external_plugins: Annotated[ + bool, + typer.Option( + ..., + help="List the third-party plugins which are available when the option --allow-external-plugins is set.", + callback=show_external_plugins_callback, + is_eager=True, + ), + ] = False, abort_on_error: Annotated[ bool, typer.Option( @@ -366,8 +408,9 @@ def convert( export_txt = OutputFormat.TEXT in to_formats export_doctags = OutputFormat.DOCTAGS in to_formats - ocr_options = ocr_factory.create_options( - kind=str(ocr_engine.value), # type:ignore + ocr_factory = get_ocr_factory(allow_external_plugins=allow_external_plugins) + ocr_options: OcrOptions = ocr_factory.create_options( + kind=ocr_engine, force_full_page_ocr=force_ocr, ) @@ -377,6 +420,7 @@ def convert( accelerator_options = AcceleratorOptions(num_threads=num_threads, device=device) pipeline_options = PdfPipelineOptions( + allow_external_plugins=allow_external_plugins, enable_remote_services=enable_remote_services, accelerator_options=accelerator_options, do_ocr=ocr, diff --git a/docling/datamodel/pipeline_options.py b/docling/datamodel/pipeline_options.py index 0d47a09c..d3dcd136 100644 --- a/docling/datamodel/pipeline_options.py +++ b/docling/datamodel/pipeline_options.py @@ -322,6 +322,7 @@ class PipelineOptions(BaseModel): document_timeout: Optional[float] = None accelerator_options: AcceleratorOptions = AcceleratorOptions() enable_remote_services: bool = False + allow_external_plugins: bool = False class PaginatedPipelineOptions(PipelineOptions): diff --git a/docling/models/factories/__init__.py b/docling/models/factories/__init__.py index 5d527ed5..6c7bc1bc 100644 --- a/docling/models/factories/__init__.py +++ b/docling/models/factories/__init__.py @@ -9,17 +9,17 @@ from docling.models.factories.picture_description_factory import ( logger = logging.getLogger(__name__) -@lru_cache(maxsize=1) -def get_ocr_factory(): +@lru_cache() +def get_ocr_factory(allow_external_plugins: bool = False): factory = OcrFactory() - factory.load_from_plugins() + factory.load_from_plugins(allow_external_plugins=allow_external_plugins) logger.info("Registered ocr engines: %r", factory.registered_kind) return factory -@lru_cache(maxsize=1) -def get_picture_description_factory(): +@lru_cache() +def get_picture_description_factory(allow_external_plugins: bool = False): factory = PictureDescriptionFactory() - factory.load_from_plugins() + factory.load_from_plugins(allow_external_plugins=allow_external_plugins) logger.info("Registered picture descriptions: %r", factory.registered_kind) return factory diff --git a/docling/models/factories/base_factory.py b/docling/models/factories/base_factory.py index 26a659ef..542fc7e6 100644 --- a/docling/models/factories/base_factory.py +++ b/docling/models/factories/base_factory.py @@ -4,6 +4,7 @@ from abc import ABCMeta from typing import Generic, Optional, Type, TypeVar from pluggy import PluginManager +from pydantic import BaseModel from docling.datamodel.pipeline_options import BaseOptions from docling.models.base_model import BaseModelWithOptions @@ -14,6 +15,12 @@ A = TypeVar("A", bound=BaseModelWithOptions) logger = logging.getLogger(__name__) +class FactoryMeta(BaseModel): + kind: str + plugin_name: str + module: str + + class BaseFactory(Generic[A], metaclass=ABCMeta): default_plugin_name = "docling" @@ -22,6 +29,7 @@ class BaseFactory(Generic[A], metaclass=ABCMeta): self.plugin_attr_name = plugin_attr_name self._classes: dict[Type[BaseOptions], Type[A]] = {} + self._meta: dict[Type[BaseOptions], FactoryMeta] = {} @property def registered_kind(self) -> list[str]: @@ -39,6 +47,10 @@ class BaseFactory(Generic[A], metaclass=ABCMeta): def classes(self): return self._classes + @property + def registered_meta(self): + return self._meta + def create_instance(self, options: BaseOptions, **kwargs) -> A: try: _cls = self._classes[type(options)] @@ -62,7 +74,7 @@ class BaseFactory(Generic[A], metaclass=ABCMeta): return f"No class found with the name {kind!r}, known classes are:\n{msg_str}" - def register(self, cls: Type[A]): + def register(self, cls: Type[A], plugin_name: str, plugin_module_name: str): opt_type = cls.get_options_type() if opt_type in self._classes: @@ -71,14 +83,28 @@ class BaseFactory(Generic[A], metaclass=ABCMeta): ) self._classes[opt_type] = cls + self._meta[opt_type] = FactoryMeta( + kind=opt_type.kind, plugin_name=plugin_name, module=plugin_module_name + ) - def load_from_plugins(self, plugin_name: Optional[str] = None): + def load_from_plugins( + self, plugin_name: Optional[str] = None, allow_external_plugins: bool = False + ): plugin_name = plugin_name or self.plugin_name plugin_manager = PluginManager(plugin_name) plugin_manager.load_setuptools_entrypoints(plugin_name) for plugin_name, plugin_module in plugin_manager.list_name_plugin(): + plugin_module_name = str(plugin_module.__name__) # type: ignore + + if not allow_external_plugins and not plugin_module_name.startswith( + "docling." + ): + logger.warning( + f"The plugin {plugin_name} will not be loaded because Docling is being executed with allow_external_plugins=false." + ) + continue attr = getattr(plugin_module, self.plugin_attr_name, None) @@ -86,11 +112,11 @@ class BaseFactory(Generic[A], metaclass=ABCMeta): logger.info("Loading plugin %r", plugin_name) config = attr() - self.process_plugin(config) + self.process_plugin(config, plugin_name, plugin_module_name) - def process_plugin(self, config): + def process_plugin(self, config, plugin_name: str, plugin_module_name: str): for item in config[self.plugin_attr_name]: try: - self.register(item) + self.register(item, plugin_name, plugin_module_name) except ValueError: logger.warning("%r already registered", item) diff --git a/docling/models/factories/ocr_factory.py b/docling/models/factories/ocr_factory.py index 1153baaa..34fc7c43 100644 --- a/docling/models/factories/ocr_factory.py +++ b/docling/models/factories/ocr_factory.py @@ -9,19 +9,3 @@ logger = logging.getLogger(__name__) class OcrFactory(BaseFactory[BaseOcrModel]): def __init__(self, *args, **kwargs): super().__init__("ocr_engines", *args, **kwargs) - - -# def on_class_not_found(self, kind: str, *args, **kwargs): - -# raise NoSuchOcrEngine(kind, self.registered_kind) - - -# class NoSuchOcrEngine(Exception): -# def __init__(self, ocr_engine_kind, known_engines=None): -# if known_engines is None: -# known_engines = [] -# super(NoSuchOcrEngine, self).__init__( -# "No OCR engine found with the name '%s', known engines are: %r", -# ocr_engine_kind, -# [cls.__name__ for cls in known_engines], -# ) diff --git a/docling/pipeline/standard_pdf_pipeline.py b/docling/pipeline/standard_pdf_pipeline.py index f8b11f57..07940585 100644 --- a/docling/pipeline/standard_pdf_pipeline.py +++ b/docling/pipeline/standard_pdf_pipeline.py @@ -146,7 +146,9 @@ class StandardPdfPipeline(PaginatedPipeline): return output_dir def get_ocr_model(self, artifacts_path: Optional[Path] = None) -> BaseOcrModel: - factory = get_ocr_factory() + factory = get_ocr_factory( + allow_external_plugins=self.pipeline_options.allow_external_plugins + ) return factory.create_instance( options=self.pipeline_options.ocr_options, enabled=self.pipeline_options.do_ocr, @@ -157,7 +159,9 @@ class StandardPdfPipeline(PaginatedPipeline): def get_picture_description_model( self, artifacts_path: Optional[Path] = None ) -> Optional[PictureDescriptionBaseModel]: - factory = get_picture_description_factory() + factory = get_picture_description_factory( + allow_external_plugins=self.pipeline_options.allow_external_plugins + ) return factory.create_instance( options=self.pipeline_options.picture_description_options, enabled=self.pipeline_options.do_picture_description,