add allow_external_plugins option

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>
This commit is contained in:
Michele Dolfi 2025-03-18 09:07:28 +01:00
parent 66fe0049fb
commit 9b4c2e3fdf
6 changed files with 96 additions and 37 deletions

View File

@ -9,6 +9,7 @@ import warnings
from pathlib import Path from pathlib import Path
from typing import Annotated, Dict, Iterable, List, Optional, Type from typing import Annotated, Dict, Iterable, List, Optional, Type
import rich.table
import typer import typer
from docling_core.types.doc import ImageRefMode from docling_core.types.doc import ImageRefMode
from docling_core.utils.file import resolve_source_to_path from docling_core.utils.file import resolve_source_to_path
@ -29,6 +30,7 @@ from docling.datamodel.pipeline_options import (
AcceleratorDevice, AcceleratorDevice,
AcceleratorOptions, AcceleratorOptions,
EasyOcrOptions, EasyOcrOptions,
OcrOptions,
PdfBackend, PdfBackend,
PdfPipelineOptions, PdfPipelineOptions,
TableFormerMode, TableFormerMode,
@ -43,10 +45,11 @@ warnings.filterwarnings(action="ignore", category=FutureWarning, module="easyocr
_log = logging.getLogger(__name__) _log = logging.getLogger(__name__)
from rich.console import Console from rich.console import Console
console = Console()
err_console = Console(stderr=True) err_console = Console(stderr=True)
ocr_factory = get_ocr_factory() ocr_factory_internal = get_ocr_factory(allow_external_plugins=False)
ocr_engines_enum = ocr_factory.get_enum() ocr_engines_enum_internal = ocr_factory_internal.get_enum()
app = typer.Typer( app = typer.Typer(
name="Docling", name="Docling",
@ -74,6 +77,24 @@ def version_callback(value: bool):
raise typer.Exit() raise typer.Exit()
def show_external_plugins_callback(value: bool):
    """Typer callback backing the ``--show-external-plugins`` option.

    When ``value`` is falsy the callback does nothing. Otherwise it builds an
    OCR factory with external plugins allowed, prints a table of every
    registered engine whose module is outside the ``docling.`` namespace, and
    ends the CLI run by raising ``typer.Exit``.
    """
    if not value:
        return

    # A factory created with external plugins enabled also registers
    # third-party engines, which is what this listing is about.
    factory = get_ocr_factory(allow_external_plugins=True)

    engines_table = rich.table.Table(title="Available OCR engines")
    engines_table.add_column("Name", justify="right")
    engines_table.add_column("Plugin")
    engines_table.add_column("Package")

    for entry in factory.registered_meta.values():
        # Engines under the ``docling.`` namespace are built in; skip them.
        if entry.module.startswith("docling."):
            continue
        # The top-level package is the first dotted component of the module.
        package = entry.module.split(".")[0]
        engines_table.add_row(
            f"[bold]{entry.kind}[/bold]",
            entry.plugin_name,
            package,
        )

    rich.print(engines_table)
    raise typer.Exit()
def export_documents( def export_documents(
conv_results: Iterable[ConversionResult], conv_results: Iterable[ConversionResult],
output_dir: Path, output_dir: Path,
@ -191,10 +212,16 @@ def convert(
help="Replace any existing text with OCR generated text over the full content.", help="Replace any existing text with OCR generated text over the full content.",
), ),
] = False, ] = False,
ocr_engine: Annotated[ # type: ignore ocr_engine: Annotated[
ocr_engines_enum, str,
# ocr_factory.get_registered_enum(), typer.Option(
typer.Option(..., help="The OCR engine to use."), ...,
help=(
f"The OCR engine to use. When --allow-external-plugins is *not* set, the available values are: "
f"{', '.join((o.value for o in ocr_engines_enum_internal))}. "
f"Use the option --show-external-plugins to see the options allowed with external plugins."
),
),
] = EasyOcrOptions.kind, ] = EasyOcrOptions.kind,
ocr_lang: Annotated[ ocr_lang: Annotated[
Optional[str], Optional[str],
@ -239,6 +266,21 @@ def convert(
..., help="Must be enabled when using models connecting to remote services." ..., help="Must be enabled when using models connecting to remote services."
), ),
] = False, ] = False,
allow_external_plugins: Annotated[
bool,
typer.Option(
..., help="Must be enabled for loading modules from third-party plugins."
),
] = False,
show_external_plugins: Annotated[
bool,
typer.Option(
...,
help="List the third-party plugins which are available when the option --allow-external-plugins is set.",
callback=show_external_plugins_callback,
is_eager=True,
),
] = False,
abort_on_error: Annotated[ abort_on_error: Annotated[
bool, bool,
typer.Option( typer.Option(
@ -366,8 +408,9 @@ def convert(
export_txt = OutputFormat.TEXT in to_formats export_txt = OutputFormat.TEXT in to_formats
export_doctags = OutputFormat.DOCTAGS in to_formats export_doctags = OutputFormat.DOCTAGS in to_formats
ocr_options = ocr_factory.create_options( ocr_factory = get_ocr_factory(allow_external_plugins=allow_external_plugins)
kind=str(ocr_engine.value), # type:ignore ocr_options: OcrOptions = ocr_factory.create_options(
kind=ocr_engine,
force_full_page_ocr=force_ocr, force_full_page_ocr=force_ocr,
) )
@ -377,6 +420,7 @@ def convert(
accelerator_options = AcceleratorOptions(num_threads=num_threads, device=device) accelerator_options = AcceleratorOptions(num_threads=num_threads, device=device)
pipeline_options = PdfPipelineOptions( pipeline_options = PdfPipelineOptions(
allow_external_plugins=allow_external_plugins,
enable_remote_services=enable_remote_services, enable_remote_services=enable_remote_services,
accelerator_options=accelerator_options, accelerator_options=accelerator_options,
do_ocr=ocr, do_ocr=ocr,

View File

@ -322,6 +322,7 @@ class PipelineOptions(BaseModel):
document_timeout: Optional[float] = None document_timeout: Optional[float] = None
accelerator_options: AcceleratorOptions = AcceleratorOptions() accelerator_options: AcceleratorOptions = AcceleratorOptions()
enable_remote_services: bool = False enable_remote_services: bool = False
allow_external_plugins: bool = False
class PaginatedPipelineOptions(PipelineOptions): class PaginatedPipelineOptions(PipelineOptions):

View File

@ -9,17 +9,17 @@ from docling.models.factories.picture_description_factory import (
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@lru_cache(maxsize=1) @lru_cache()
def get_ocr_factory(): def get_ocr_factory(allow_external_plugins: bool = False):
factory = OcrFactory() factory = OcrFactory()
factory.load_from_plugins() factory.load_from_plugins(allow_external_plugins=allow_external_plugins)
logger.info("Registered ocr engines: %r", factory.registered_kind) logger.info("Registered ocr engines: %r", factory.registered_kind)
return factory return factory
@lru_cache(maxsize=1) @lru_cache()
def get_picture_description_factory(): def get_picture_description_factory(allow_external_plugins: bool = False):
factory = PictureDescriptionFactory() factory = PictureDescriptionFactory()
factory.load_from_plugins() factory.load_from_plugins(allow_external_plugins=allow_external_plugins)
logger.info("Registered picture descriptions: %r", factory.registered_kind) logger.info("Registered picture descriptions: %r", factory.registered_kind)
return factory return factory

View File

@ -4,6 +4,7 @@ from abc import ABCMeta
from typing import Generic, Optional, Type, TypeVar from typing import Generic, Optional, Type, TypeVar
from pluggy import PluginManager from pluggy import PluginManager
from pydantic import BaseModel
from docling.datamodel.pipeline_options import BaseOptions from docling.datamodel.pipeline_options import BaseOptions
from docling.models.base_model import BaseModelWithOptions from docling.models.base_model import BaseModelWithOptions
@ -14,6 +15,12 @@ A = TypeVar("A", bound=BaseModelWithOptions)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class FactoryMeta(BaseModel):
    """Metadata recorded for each class registered in a factory."""

    # ``kind`` identifier taken from the registered options type.
    kind: str
    # Name of the plugin that contributed the registered class.
    plugin_name: str
    # ``__name__`` of the plugin module the class was loaded from.
    module: str
class BaseFactory(Generic[A], metaclass=ABCMeta): class BaseFactory(Generic[A], metaclass=ABCMeta):
default_plugin_name = "docling" default_plugin_name = "docling"
@ -22,6 +29,7 @@ class BaseFactory(Generic[A], metaclass=ABCMeta):
self.plugin_attr_name = plugin_attr_name self.plugin_attr_name = plugin_attr_name
self._classes: dict[Type[BaseOptions], Type[A]] = {} self._classes: dict[Type[BaseOptions], Type[A]] = {}
self._meta: dict[Type[BaseOptions], FactoryMeta] = {}
@property @property
def registered_kind(self) -> list[str]: def registered_kind(self) -> list[str]:
@ -39,6 +47,10 @@ class BaseFactory(Generic[A], metaclass=ABCMeta):
def classes(self): def classes(self):
return self._classes return self._classes
@property
def registered_meta(self) -> dict[Type[BaseOptions], FactoryMeta]:
    """Return the mapping from options type to its ``FactoryMeta`` record."""
    return self._meta
def create_instance(self, options: BaseOptions, **kwargs) -> A: def create_instance(self, options: BaseOptions, **kwargs) -> A:
try: try:
_cls = self._classes[type(options)] _cls = self._classes[type(options)]
@ -62,7 +74,7 @@ class BaseFactory(Generic[A], metaclass=ABCMeta):
return f"No class found with the name {kind!r}, known classes are:\n{msg_str}" return f"No class found with the name {kind!r}, known classes are:\n{msg_str}"
def register(self, cls: Type[A]): def register(self, cls: Type[A], plugin_name: str, plugin_module_name: str):
opt_type = cls.get_options_type() opt_type = cls.get_options_type()
if opt_type in self._classes: if opt_type in self._classes:
@ -71,14 +83,28 @@ class BaseFactory(Generic[A], metaclass=ABCMeta):
) )
self._classes[opt_type] = cls self._classes[opt_type] = cls
self._meta[opt_type] = FactoryMeta(
kind=opt_type.kind, plugin_name=plugin_name, module=plugin_module_name
)
def load_from_plugins(self, plugin_name: Optional[str] = None): def load_from_plugins(
self, plugin_name: Optional[str] = None, allow_external_plugins: bool = False
):
plugin_name = plugin_name or self.plugin_name plugin_name = plugin_name or self.plugin_name
plugin_manager = PluginManager(plugin_name) plugin_manager = PluginManager(plugin_name)
plugin_manager.load_setuptools_entrypoints(plugin_name) plugin_manager.load_setuptools_entrypoints(plugin_name)
for plugin_name, plugin_module in plugin_manager.list_name_plugin(): for plugin_name, plugin_module in plugin_manager.list_name_plugin():
plugin_module_name = str(plugin_module.__name__) # type: ignore
if not allow_external_plugins and not plugin_module_name.startswith(
"docling."
):
logger.warning(
f"The plugin {plugin_name} will not be loaded because Docling is being executed with allow_external_plugins=false."
)
continue
attr = getattr(plugin_module, self.plugin_attr_name, None) attr = getattr(plugin_module, self.plugin_attr_name, None)
@ -86,11 +112,11 @@ class BaseFactory(Generic[A], metaclass=ABCMeta):
logger.info("Loading plugin %r", plugin_name) logger.info("Loading plugin %r", plugin_name)
config = attr() config = attr()
self.process_plugin(config) self.process_plugin(config, plugin_name, plugin_module_name)
def process_plugin(self, config): def process_plugin(self, config, plugin_name: str, plugin_module_name: str):
for item in config[self.plugin_attr_name]: for item in config[self.plugin_attr_name]:
try: try:
self.register(item) self.register(item, plugin_name, plugin_module_name)
except ValueError: except ValueError:
logger.warning("%r already registered", item) logger.warning("%r already registered", item)

View File

@ -9,19 +9,3 @@ logger = logging.getLogger(__name__)
class OcrFactory(BaseFactory[BaseOcrModel]): class OcrFactory(BaseFactory[BaseOcrModel]):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__("ocr_engines", *args, **kwargs) super().__init__("ocr_engines", *args, **kwargs)
# def on_class_not_found(self, kind: str, *args, **kwargs):
# raise NoSuchOcrEngine(kind, self.registered_kind)
# class NoSuchOcrEngine(Exception):
# def __init__(self, ocr_engine_kind, known_engines=None):
# if known_engines is None:
# known_engines = []
# super(NoSuchOcrEngine, self).__init__(
# "No OCR engine found with the name '%s', known engines are: %r",
# ocr_engine_kind,
# [cls.__name__ for cls in known_engines],
# )

View File

@ -146,7 +146,9 @@ class StandardPdfPipeline(PaginatedPipeline):
return output_dir return output_dir
def get_ocr_model(self, artifacts_path: Optional[Path] = None) -> BaseOcrModel: def get_ocr_model(self, artifacts_path: Optional[Path] = None) -> BaseOcrModel:
factory = get_ocr_factory() factory = get_ocr_factory(
allow_external_plugins=self.pipeline_options.allow_external_plugins
)
return factory.create_instance( return factory.create_instance(
options=self.pipeline_options.ocr_options, options=self.pipeline_options.ocr_options,
enabled=self.pipeline_options.do_ocr, enabled=self.pipeline_options.do_ocr,
@ -157,7 +159,9 @@ class StandardPdfPipeline(PaginatedPipeline):
def get_picture_description_model( def get_picture_description_model(
self, artifacts_path: Optional[Path] = None self, artifacts_path: Optional[Path] = None
) -> Optional[PictureDescriptionBaseModel]: ) -> Optional[PictureDescriptionBaseModel]:
factory = get_picture_description_factory() factory = get_picture_description_factory(
allow_external_plugins=self.pipeline_options.allow_external_plugins
)
return factory.create_instance( return factory.create_instance(
options=self.pipeline_options.picture_description_options, options=self.pipeline_options.picture_description_options,
enabled=self.pipeline_options.do_picture_description, enabled=self.pipeline_options.do_picture_description,