add allow_external_plugins option

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>
This commit is contained in:
Michele Dolfi 2025-03-18 09:07:28 +01:00
parent 66fe0049fb
commit 9b4c2e3fdf
6 changed files with 96 additions and 37 deletions

View File

@ -9,6 +9,7 @@ import warnings
from pathlib import Path
from typing import Annotated, Dict, Iterable, List, Optional, Type
import rich.table
import typer
from docling_core.types.doc import ImageRefMode
from docling_core.utils.file import resolve_source_to_path
@ -29,6 +30,7 @@ from docling.datamodel.pipeline_options import (
AcceleratorDevice,
AcceleratorOptions,
EasyOcrOptions,
OcrOptions,
PdfBackend,
PdfPipelineOptions,
TableFormerMode,
@ -43,10 +45,11 @@ warnings.filterwarnings(action="ignore", category=FutureWarning, module="easyocr
_log = logging.getLogger(__name__)
from rich.console import Console
console = Console()
err_console = Console(stderr=True)
ocr_factory = get_ocr_factory()
ocr_engines_enum = ocr_factory.get_enum()
ocr_factory_internal = get_ocr_factory(allow_external_plugins=False)
ocr_engines_enum_internal = ocr_factory_internal.get_enum()
app = typer.Typer(
name="Docling",
@ -74,6 +77,24 @@ def version_callback(value: bool):
raise typer.Exit()
def show_external_plugins_callback(value: bool):
if value:
ocr_factory_all = get_ocr_factory(allow_external_plugins=True)
table = rich.table.Table(title="Available OCR engines")
table.add_column("Name", justify="right")
table.add_column("Plugin")
table.add_column("Package")
for meta in ocr_factory_all.registered_meta.values():
if not meta.module.startswith("docling."):
table.add_row(
f"[bold]{meta.kind}[/bold]",
meta.plugin_name,
meta.module.split(".")[0],
)
rich.print(table)
raise typer.Exit()
def export_documents(
conv_results: Iterable[ConversionResult],
output_dir: Path,
@ -191,10 +212,16 @@ def convert(
help="Replace any existing text with OCR generated text over the full content.",
),
] = False,
ocr_engine: Annotated[ # type: ignore
ocr_engines_enum,
# ocr_factory.get_registered_enum(),
typer.Option(..., help="The OCR engine to use."),
ocr_engine: Annotated[
str,
typer.Option(
...,
help=(
f"The OCR engine to use. When --allow-external-plugins is *not* set, the available values are: "
f"{', '.join((o.value for o in ocr_engines_enum_internal))}. "
f"Use the option --show-external-plugins to see the options allowed with external plugins."
),
),
] = EasyOcrOptions.kind,
ocr_lang: Annotated[
Optional[str],
@ -239,6 +266,21 @@ def convert(
..., help="Must be enabled when using models connecting to remote services."
),
] = False,
allow_external_plugins: Annotated[
bool,
typer.Option(
..., help="Must be enabled for loading modules from third-party plugins."
),
] = False,
show_external_plugins: Annotated[
bool,
typer.Option(
...,
help="List the third-party plugins which are available when the option --allow-external-plugins is set.",
callback=show_external_plugins_callback,
is_eager=True,
),
] = False,
abort_on_error: Annotated[
bool,
typer.Option(
@ -366,8 +408,9 @@ def convert(
export_txt = OutputFormat.TEXT in to_formats
export_doctags = OutputFormat.DOCTAGS in to_formats
ocr_options = ocr_factory.create_options(
kind=str(ocr_engine.value), # type:ignore
ocr_factory = get_ocr_factory(allow_external_plugins=allow_external_plugins)
ocr_options: OcrOptions = ocr_factory.create_options(
kind=ocr_engine,
force_full_page_ocr=force_ocr,
)
@ -377,6 +420,7 @@ def convert(
accelerator_options = AcceleratorOptions(num_threads=num_threads, device=device)
pipeline_options = PdfPipelineOptions(
allow_external_plugins=allow_external_plugins,
enable_remote_services=enable_remote_services,
accelerator_options=accelerator_options,
do_ocr=ocr,

View File

@ -322,6 +322,7 @@ class PipelineOptions(BaseModel):
document_timeout: Optional[float] = None
accelerator_options: AcceleratorOptions = AcceleratorOptions()
enable_remote_services: bool = False
allow_external_plugins: bool = False
class PaginatedPipelineOptions(PipelineOptions):

View File

@ -9,17 +9,17 @@ from docling.models.factories.picture_description_factory import (
logger = logging.getLogger(__name__)
@lru_cache(maxsize=1)
def get_ocr_factory():
@lru_cache()
def get_ocr_factory(allow_external_plugins: bool = False):
factory = OcrFactory()
factory.load_from_plugins()
factory.load_from_plugins(allow_external_plugins=allow_external_plugins)
logger.info("Registered ocr engines: %r", factory.registered_kind)
return factory
@lru_cache(maxsize=1)
def get_picture_description_factory():
@lru_cache()
def get_picture_description_factory(allow_external_plugins: bool = False):
factory = PictureDescriptionFactory()
factory.load_from_plugins()
factory.load_from_plugins(allow_external_plugins=allow_external_plugins)
logger.info("Registered picture descriptions: %r", factory.registered_kind)
return factory

View File

@ -4,6 +4,7 @@ from abc import ABCMeta
from typing import Generic, Optional, Type, TypeVar
from pluggy import PluginManager
from pydantic import BaseModel
from docling.datamodel.pipeline_options import BaseOptions
from docling.models.base_model import BaseModelWithOptions
@ -14,6 +15,12 @@ A = TypeVar("A", bound=BaseModelWithOptions)
logger = logging.getLogger(__name__)
class FactoryMeta(BaseModel):
kind: str
plugin_name: str
module: str
class BaseFactory(Generic[A], metaclass=ABCMeta):
default_plugin_name = "docling"
@ -22,6 +29,7 @@ class BaseFactory(Generic[A], metaclass=ABCMeta):
self.plugin_attr_name = plugin_attr_name
self._classes: dict[Type[BaseOptions], Type[A]] = {}
self._meta: dict[Type[BaseOptions], FactoryMeta] = {}
@property
def registered_kind(self) -> list[str]:
@ -39,6 +47,10 @@ class BaseFactory(Generic[A], metaclass=ABCMeta):
def classes(self):
return self._classes
@property
def registered_meta(self):
return self._meta
def create_instance(self, options: BaseOptions, **kwargs) -> A:
try:
_cls = self._classes[type(options)]
@ -62,7 +74,7 @@ class BaseFactory(Generic[A], metaclass=ABCMeta):
return f"No class found with the name {kind!r}, known classes are:\n{msg_str}"
def register(self, cls: Type[A]):
def register(self, cls: Type[A], plugin_name: str, plugin_module_name: str):
opt_type = cls.get_options_type()
if opt_type in self._classes:
@ -71,14 +83,28 @@ class BaseFactory(Generic[A], metaclass=ABCMeta):
)
self._classes[opt_type] = cls
self._meta[opt_type] = FactoryMeta(
kind=opt_type.kind, plugin_name=plugin_name, module=plugin_module_name
)
def load_from_plugins(self, plugin_name: Optional[str] = None):
def load_from_plugins(
self, plugin_name: Optional[str] = None, allow_external_plugins: bool = False
):
plugin_name = plugin_name or self.plugin_name
plugin_manager = PluginManager(plugin_name)
plugin_manager.load_setuptools_entrypoints(plugin_name)
for plugin_name, plugin_module in plugin_manager.list_name_plugin():
plugin_module_name = str(plugin_module.__name__) # type: ignore
if not allow_external_plugins and not plugin_module_name.startswith(
"docling."
):
logger.warning(
f"The plugin {plugin_name} will not be loaded because Docling is being executed with allow_external_plugins=false."
)
continue
attr = getattr(plugin_module, self.plugin_attr_name, None)
@ -86,11 +112,11 @@ class BaseFactory(Generic[A], metaclass=ABCMeta):
logger.info("Loading plugin %r", plugin_name)
config = attr()
self.process_plugin(config)
self.process_plugin(config, plugin_name, plugin_module_name)
def process_plugin(self, config):
def process_plugin(self, config, plugin_name: str, plugin_module_name: str):
for item in config[self.plugin_attr_name]:
try:
self.register(item)
self.register(item, plugin_name, plugin_module_name)
except ValueError:
logger.warning("%r already registered", item)

View File

@ -9,19 +9,3 @@ logger = logging.getLogger(__name__)
class OcrFactory(BaseFactory[BaseOcrModel]):
def __init__(self, *args, **kwargs):
super().__init__("ocr_engines", *args, **kwargs)
# def on_class_not_found(self, kind: str, *args, **kwargs):
# raise NoSuchOcrEngine(kind, self.registered_kind)
# class NoSuchOcrEngine(Exception):
# def __init__(self, ocr_engine_kind, known_engines=None):
# if known_engines is None:
# known_engines = []
# super(NoSuchOcrEngine, self).__init__(
# "No OCR engine found with the name '%s', known engines are: %r",
# ocr_engine_kind,
# [cls.__name__ for cls in known_engines],
# )

View File

@ -146,7 +146,9 @@ class StandardPdfPipeline(PaginatedPipeline):
return output_dir
def get_ocr_model(self, artifacts_path: Optional[Path] = None) -> BaseOcrModel:
factory = get_ocr_factory()
factory = get_ocr_factory(
allow_external_plugins=self.pipeline_options.allow_external_plugins
)
return factory.create_instance(
options=self.pipeline_options.ocr_options,
enabled=self.pipeline_options.do_ocr,
@ -157,7 +159,9 @@ class StandardPdfPipeline(PaginatedPipeline):
def get_picture_description_model(
self, artifacts_path: Optional[Path] = None
) -> Optional[PictureDescriptionBaseModel]:
factory = get_picture_description_factory()
factory = get_picture_description_factory(
allow_external_plugins=self.pipeline_options.allow_external_plugins
)
return factory.create_instance(
options=self.pipeline_options.picture_description_options,
enabled=self.pipeline_options.do_picture_description,