mirror of
https://github.com/DS4SD/docling.git
synced 2025-07-29 21:44:32 +00:00
add allow_external_plugins option
Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>
This commit is contained in:
parent
66fe0049fb
commit
9b4c2e3fdf
@ -9,6 +9,7 @@ import warnings
|
||||
from pathlib import Path
|
||||
from typing import Annotated, Dict, Iterable, List, Optional, Type
|
||||
|
||||
import rich.table
|
||||
import typer
|
||||
from docling_core.types.doc import ImageRefMode
|
||||
from docling_core.utils.file import resolve_source_to_path
|
||||
@ -29,6 +30,7 @@ from docling.datamodel.pipeline_options import (
|
||||
AcceleratorDevice,
|
||||
AcceleratorOptions,
|
||||
EasyOcrOptions,
|
||||
OcrOptions,
|
||||
PdfBackend,
|
||||
PdfPipelineOptions,
|
||||
TableFormerMode,
|
||||
@ -43,10 +45,11 @@ warnings.filterwarnings(action="ignore", category=FutureWarning, module="easyocr
|
||||
_log = logging.getLogger(__name__)
|
||||
from rich.console import Console
|
||||
|
||||
console = Console()
|
||||
err_console = Console(stderr=True)
|
||||
|
||||
ocr_factory = get_ocr_factory()
|
||||
ocr_engines_enum = ocr_factory.get_enum()
|
||||
ocr_factory_internal = get_ocr_factory(allow_external_plugins=False)
|
||||
ocr_engines_enum_internal = ocr_factory_internal.get_enum()
|
||||
|
||||
app = typer.Typer(
|
||||
name="Docling",
|
||||
@ -74,6 +77,24 @@ def version_callback(value: bool):
|
||||
raise typer.Exit()
|
||||
|
||||
|
||||
def show_external_plugins_callback(value: bool):
    """Print the OCR engines contributed by third-party plugins, then exit.

    Used as the callback of the ``--show-external-plugins`` CLI option.
    When ``value`` is falsy nothing happens. Otherwise the OCR factory is
    obtained with external plugins allowed, every registered engine whose
    providing module is outside the ``docling.`` namespace is listed in a
    table, and the CLI terminates.

    Args:
        value: The parsed value of the flag.

    Raises:
        typer.Exit: After the table has been printed.
    """
    if not value:
        return

    factory_with_externals = get_ocr_factory(allow_external_plugins=True)

    plugins_table = rich.table.Table(title="Available OCR engines")
    plugins_table.add_column("Name", justify="right")
    plugins_table.add_column("Plugin")
    plugins_table.add_column("Package")

    for entry in factory_with_externals.registered_meta.values():
        # Engines shipped by docling itself are not "external"; skip them.
        if entry.module.startswith("docling."):
            continue

        name_cell = f"[bold]{entry.kind}[/bold]"
        # The top-level package is the first component of the module path.
        package_cell = entry.module.split(".")[0]
        plugins_table.add_row(name_cell, entry.plugin_name, package_cell)

    rich.print(plugins_table)
    raise typer.Exit()
|
||||
|
||||
|
||||
def export_documents(
|
||||
conv_results: Iterable[ConversionResult],
|
||||
output_dir: Path,
|
||||
@ -191,10 +212,16 @@ def convert(
|
||||
help="Replace any existing text with OCR generated text over the full content.",
|
||||
),
|
||||
] = False,
|
||||
ocr_engine: Annotated[ # type: ignore
|
||||
ocr_engines_enum,
|
||||
# ocr_factory.get_registered_enum(),
|
||||
typer.Option(..., help="The OCR engine to use."),
|
||||
ocr_engine: Annotated[
|
||||
str,
|
||||
typer.Option(
|
||||
...,
|
||||
help=(
|
||||
f"The OCR engine to use. When --allow-external-plugins is *not* set, the available values are: "
|
||||
f"{', '.join((o.value for o in ocr_engines_enum_internal))}. "
|
||||
f"Use the option --show-external-plugins to see the options allowed with external plugins."
|
||||
),
|
||||
),
|
||||
] = EasyOcrOptions.kind,
|
||||
ocr_lang: Annotated[
|
||||
Optional[str],
|
||||
@ -239,6 +266,21 @@ def convert(
|
||||
..., help="Must be enabled when using models connecting to remote services."
|
||||
),
|
||||
] = False,
|
||||
allow_external_plugins: Annotated[
|
||||
bool,
|
||||
typer.Option(
|
||||
..., help="Must be enabled for loading modules from third-party plugins."
|
||||
),
|
||||
] = False,
|
||||
show_external_plugins: Annotated[
|
||||
bool,
|
||||
typer.Option(
|
||||
...,
|
||||
help="List the third-party plugins which are available when the option --allow-external-plugins is set.",
|
||||
callback=show_external_plugins_callback,
|
||||
is_eager=True,
|
||||
),
|
||||
] = False,
|
||||
abort_on_error: Annotated[
|
||||
bool,
|
||||
typer.Option(
|
||||
@ -366,8 +408,9 @@ def convert(
|
||||
export_txt = OutputFormat.TEXT in to_formats
|
||||
export_doctags = OutputFormat.DOCTAGS in to_formats
|
||||
|
||||
ocr_options = ocr_factory.create_options(
|
||||
kind=str(ocr_engine.value), # type:ignore
|
||||
ocr_factory = get_ocr_factory(allow_external_plugins=allow_external_plugins)
|
||||
ocr_options: OcrOptions = ocr_factory.create_options(
|
||||
kind=ocr_engine,
|
||||
force_full_page_ocr=force_ocr,
|
||||
)
|
||||
|
||||
@ -377,6 +420,7 @@ def convert(
|
||||
|
||||
accelerator_options = AcceleratorOptions(num_threads=num_threads, device=device)
|
||||
pipeline_options = PdfPipelineOptions(
|
||||
allow_external_plugins=allow_external_plugins,
|
||||
enable_remote_services=enable_remote_services,
|
||||
accelerator_options=accelerator_options,
|
||||
do_ocr=ocr,
|
||||
|
@ -322,6 +322,7 @@ class PipelineOptions(BaseModel):
|
||||
document_timeout: Optional[float] = None
|
||||
accelerator_options: AcceleratorOptions = AcceleratorOptions()
|
||||
enable_remote_services: bool = False
|
||||
allow_external_plugins: bool = False
|
||||
|
||||
|
||||
class PaginatedPipelineOptions(PipelineOptions):
|
||||
|
@ -9,17 +9,17 @@ from docling.models.factories.picture_description_factory import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_ocr_factory():
|
||||
@lru_cache()
|
||||
def get_ocr_factory(allow_external_plugins: bool = False):
|
||||
factory = OcrFactory()
|
||||
factory.load_from_plugins()
|
||||
factory.load_from_plugins(allow_external_plugins=allow_external_plugins)
|
||||
logger.info("Registered ocr engines: %r", factory.registered_kind)
|
||||
return factory
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_picture_description_factory():
|
||||
@lru_cache()
|
||||
def get_picture_description_factory(allow_external_plugins: bool = False):
|
||||
factory = PictureDescriptionFactory()
|
||||
factory.load_from_plugins()
|
||||
factory.load_from_plugins(allow_external_plugins=allow_external_plugins)
|
||||
logger.info("Registered picture descriptions: %r", factory.registered_kind)
|
||||
return factory
|
||||
|
@ -4,6 +4,7 @@ from abc import ABCMeta
|
||||
from typing import Generic, Optional, Type, TypeVar
|
||||
|
||||
from pluggy import PluginManager
|
||||
from pydantic import BaseModel
|
||||
|
||||
from docling.datamodel.pipeline_options import BaseOptions
|
||||
from docling.models.base_model import BaseModelWithOptions
|
||||
@ -14,6 +15,12 @@ A = TypeVar("A", bound=BaseModelWithOptions)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FactoryMeta(BaseModel):
    """Provenance record stored by a factory for each class it registers."""

    # The ``kind`` attribute of the options type the class was registered for.
    kind: str

    # Name under which the plugin manager lists the contributing plugin.
    plugin_name: str

    # ``__name__`` of the plugin module that supplied the class.
    module: str
|
||||
|
||||
|
||||
class BaseFactory(Generic[A], metaclass=ABCMeta):
|
||||
default_plugin_name = "docling"
|
||||
|
||||
@ -22,6 +29,7 @@ class BaseFactory(Generic[A], metaclass=ABCMeta):
|
||||
self.plugin_attr_name = plugin_attr_name
|
||||
|
||||
self._classes: dict[Type[BaseOptions], Type[A]] = {}
|
||||
self._meta: dict[Type[BaseOptions], FactoryMeta] = {}
|
||||
|
||||
@property
|
||||
def registered_kind(self) -> list[str]:
|
||||
@ -39,6 +47,10 @@ class BaseFactory(Generic[A], metaclass=ABCMeta):
|
||||
def classes(self):
|
||||
return self._classes
|
||||
|
||||
@property
def registered_meta(self):
    """Return the metadata mapping maintained alongside the registered classes.

    The mapping is keyed by options type and holds one ``FactoryMeta`` per
    registered class. The internal dictionary itself is returned, not a copy.
    """
    meta_by_options_type = self._meta
    return meta_by_options_type
|
||||
|
||||
def create_instance(self, options: BaseOptions, **kwargs) -> A:
|
||||
try:
|
||||
_cls = self._classes[type(options)]
|
||||
@ -62,7 +74,7 @@ class BaseFactory(Generic[A], metaclass=ABCMeta):
|
||||
|
||||
return f"No class found with the name {kind!r}, known classes are:\n{msg_str}"
|
||||
|
||||
def register(self, cls: Type[A]):
|
||||
def register(self, cls: Type[A], plugin_name: str, plugin_module_name: str):
|
||||
opt_type = cls.get_options_type()
|
||||
|
||||
if opt_type in self._classes:
|
||||
@ -71,14 +83,28 @@ class BaseFactory(Generic[A], metaclass=ABCMeta):
|
||||
)
|
||||
|
||||
self._classes[opt_type] = cls
|
||||
self._meta[opt_type] = FactoryMeta(
|
||||
kind=opt_type.kind, plugin_name=plugin_name, module=plugin_module_name
|
||||
)
|
||||
|
||||
def load_from_plugins(self, plugin_name: Optional[str] = None):
|
||||
def load_from_plugins(
|
||||
self, plugin_name: Optional[str] = None, allow_external_plugins: bool = False
|
||||
):
|
||||
plugin_name = plugin_name or self.plugin_name
|
||||
|
||||
plugin_manager = PluginManager(plugin_name)
|
||||
plugin_manager.load_setuptools_entrypoints(plugin_name)
|
||||
|
||||
for plugin_name, plugin_module in plugin_manager.list_name_plugin():
|
||||
plugin_module_name = str(plugin_module.__name__) # type: ignore
|
||||
|
||||
if not allow_external_plugins and not plugin_module_name.startswith(
|
||||
"docling."
|
||||
):
|
||||
logger.warning(
|
||||
f"The plugin {plugin_name} will not be loaded because Docling is being executed with allow_external_plugins=false."
|
||||
)
|
||||
continue
|
||||
|
||||
attr = getattr(plugin_module, self.plugin_attr_name, None)
|
||||
|
||||
@ -86,11 +112,11 @@ class BaseFactory(Generic[A], metaclass=ABCMeta):
|
||||
logger.info("Loading plugin %r", plugin_name)
|
||||
|
||||
config = attr()
|
||||
self.process_plugin(config)
|
||||
self.process_plugin(config, plugin_name, plugin_module_name)
|
||||
|
||||
def process_plugin(self, config):
|
||||
def process_plugin(self, config, plugin_name: str, plugin_module_name: str):
|
||||
for item in config[self.plugin_attr_name]:
|
||||
try:
|
||||
self.register(item)
|
||||
self.register(item, plugin_name, plugin_module_name)
|
||||
except ValueError:
|
||||
logger.warning("%r already registered", item)
|
||||
|
@ -9,19 +9,3 @@ logger = logging.getLogger(__name__)
|
||||
class OcrFactory(BaseFactory[BaseOcrModel]):
    """Factory specialised for OCR models."""

    def __init__(self, *args, **kwargs):
        """Initialise the base factory with the OCR-specific first argument.

        ``"ocr_engines"`` is passed as the first positional argument of
        ``BaseFactory.__init__``; all other arguments are forwarded untouched.
        NOTE(review): presumably this is the plugin attribute name looked up
        on plugin modules - confirm against ``BaseFactory.__init__``.
        """
        ocr_attr_name = "ocr_engines"
        super().__init__(ocr_attr_name, *args, **kwargs)
|
||||
|
||||
|
||||
# def on_class_not_found(self, kind: str, *args, **kwargs):
|
||||
|
||||
# raise NoSuchOcrEngine(kind, self.registered_kind)
|
||||
|
||||
|
||||
# class NoSuchOcrEngine(Exception):
|
||||
# def __init__(self, ocr_engine_kind, known_engines=None):
|
||||
# if known_engines is None:
|
||||
# known_engines = []
|
||||
# super(NoSuchOcrEngine, self).__init__(
|
||||
# "No OCR engine found with the name '%s', known engines are: %r",
|
||||
# ocr_engine_kind,
|
||||
# [cls.__name__ for cls in known_engines],
|
||||
# )
|
||||
|
@ -146,7 +146,9 @@ class StandardPdfPipeline(PaginatedPipeline):
|
||||
return output_dir
|
||||
|
||||
def get_ocr_model(self, artifacts_path: Optional[Path] = None) -> BaseOcrModel:
|
||||
factory = get_ocr_factory()
|
||||
factory = get_ocr_factory(
|
||||
allow_external_plugins=self.pipeline_options.allow_external_plugins
|
||||
)
|
||||
return factory.create_instance(
|
||||
options=self.pipeline_options.ocr_options,
|
||||
enabled=self.pipeline_options.do_ocr,
|
||||
@ -157,7 +159,9 @@ class StandardPdfPipeline(PaginatedPipeline):
|
||||
def get_picture_description_model(
|
||||
self, artifacts_path: Optional[Path] = None
|
||||
) -> Optional[PictureDescriptionBaseModel]:
|
||||
factory = get_picture_description_factory()
|
||||
factory = get_picture_description_factory(
|
||||
allow_external_plugins=self.pipeline_options.allow_external_plugins
|
||||
)
|
||||
return factory.create_instance(
|
||||
options=self.pipeline_options.picture_description_options,
|
||||
enabled=self.pipeline_options.do_picture_description,
|
||||
|
Loading…
Reference in New Issue
Block a user