mirror of
https://github.com/DS4SD/docling.git
synced 2025-07-30 14:04:27 +00:00
add allow_external_plugins option
Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>
This commit is contained in:
parent
66fe0049fb
commit
9b4c2e3fdf
@ -9,6 +9,7 @@ import warnings
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Annotated, Dict, Iterable, List, Optional, Type
|
from typing import Annotated, Dict, Iterable, List, Optional, Type
|
||||||
|
|
||||||
|
import rich.table
|
||||||
import typer
|
import typer
|
||||||
from docling_core.types.doc import ImageRefMode
|
from docling_core.types.doc import ImageRefMode
|
||||||
from docling_core.utils.file import resolve_source_to_path
|
from docling_core.utils.file import resolve_source_to_path
|
||||||
@ -29,6 +30,7 @@ from docling.datamodel.pipeline_options import (
|
|||||||
AcceleratorDevice,
|
AcceleratorDevice,
|
||||||
AcceleratorOptions,
|
AcceleratorOptions,
|
||||||
EasyOcrOptions,
|
EasyOcrOptions,
|
||||||
|
OcrOptions,
|
||||||
PdfBackend,
|
PdfBackend,
|
||||||
PdfPipelineOptions,
|
PdfPipelineOptions,
|
||||||
TableFormerMode,
|
TableFormerMode,
|
||||||
@ -43,10 +45,11 @@ warnings.filterwarnings(action="ignore", category=FutureWarning, module="easyocr
|
|||||||
_log = logging.getLogger(__name__)
|
_log = logging.getLogger(__name__)
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
|
|
||||||
|
console = Console()
|
||||||
err_console = Console(stderr=True)
|
err_console = Console(stderr=True)
|
||||||
|
|
||||||
ocr_factory = get_ocr_factory()
|
ocr_factory_internal = get_ocr_factory(allow_external_plugins=False)
|
||||||
ocr_engines_enum = ocr_factory.get_enum()
|
ocr_engines_enum_internal = ocr_factory_internal.get_enum()
|
||||||
|
|
||||||
app = typer.Typer(
|
app = typer.Typer(
|
||||||
name="Docling",
|
name="Docling",
|
||||||
@ -74,6 +77,24 @@ def version_callback(value: bool):
|
|||||||
raise typer.Exit()
|
raise typer.Exit()
|
||||||
|
|
||||||
|
|
||||||
|
def show_external_plugins_callback(value: bool):
|
||||||
|
if value:
|
||||||
|
ocr_factory_all = get_ocr_factory(allow_external_plugins=True)
|
||||||
|
table = rich.table.Table(title="Available OCR engines")
|
||||||
|
table.add_column("Name", justify="right")
|
||||||
|
table.add_column("Plugin")
|
||||||
|
table.add_column("Package")
|
||||||
|
for meta in ocr_factory_all.registered_meta.values():
|
||||||
|
if not meta.module.startswith("docling."):
|
||||||
|
table.add_row(
|
||||||
|
f"[bold]{meta.kind}[/bold]",
|
||||||
|
meta.plugin_name,
|
||||||
|
meta.module.split(".")[0],
|
||||||
|
)
|
||||||
|
rich.print(table)
|
||||||
|
raise typer.Exit()
|
||||||
|
|
||||||
|
|
||||||
def export_documents(
|
def export_documents(
|
||||||
conv_results: Iterable[ConversionResult],
|
conv_results: Iterable[ConversionResult],
|
||||||
output_dir: Path,
|
output_dir: Path,
|
||||||
@ -191,10 +212,16 @@ def convert(
|
|||||||
help="Replace any existing text with OCR generated text over the full content.",
|
help="Replace any existing text with OCR generated text over the full content.",
|
||||||
),
|
),
|
||||||
] = False,
|
] = False,
|
||||||
ocr_engine: Annotated[ # type: ignore
|
ocr_engine: Annotated[
|
||||||
ocr_engines_enum,
|
str,
|
||||||
# ocr_factory.get_registered_enum(),
|
typer.Option(
|
||||||
typer.Option(..., help="The OCR engine to use."),
|
...,
|
||||||
|
help=(
|
||||||
|
f"The OCR engine to use. When --allow-external-plugins is *not* set, the available values are: "
|
||||||
|
f"{', '.join((o.value for o in ocr_engines_enum_internal))}. "
|
||||||
|
f"Use the option --show-external-plugins to see the options allowed with external plugins."
|
||||||
|
),
|
||||||
|
),
|
||||||
] = EasyOcrOptions.kind,
|
] = EasyOcrOptions.kind,
|
||||||
ocr_lang: Annotated[
|
ocr_lang: Annotated[
|
||||||
Optional[str],
|
Optional[str],
|
||||||
@ -239,6 +266,21 @@ def convert(
|
|||||||
..., help="Must be enabled when using models connecting to remote services."
|
..., help="Must be enabled when using models connecting to remote services."
|
||||||
),
|
),
|
||||||
] = False,
|
] = False,
|
||||||
|
allow_external_plugins: Annotated[
|
||||||
|
bool,
|
||||||
|
typer.Option(
|
||||||
|
..., help="Must be enabled for loading modules from third-party plugins."
|
||||||
|
),
|
||||||
|
] = False,
|
||||||
|
show_external_plugins: Annotated[
|
||||||
|
bool,
|
||||||
|
typer.Option(
|
||||||
|
...,
|
||||||
|
help="List the third-party plugins which are available when the option --allow-external-plugins is set.",
|
||||||
|
callback=show_external_plugins_callback,
|
||||||
|
is_eager=True,
|
||||||
|
),
|
||||||
|
] = False,
|
||||||
abort_on_error: Annotated[
|
abort_on_error: Annotated[
|
||||||
bool,
|
bool,
|
||||||
typer.Option(
|
typer.Option(
|
||||||
@ -366,8 +408,9 @@ def convert(
|
|||||||
export_txt = OutputFormat.TEXT in to_formats
|
export_txt = OutputFormat.TEXT in to_formats
|
||||||
export_doctags = OutputFormat.DOCTAGS in to_formats
|
export_doctags = OutputFormat.DOCTAGS in to_formats
|
||||||
|
|
||||||
ocr_options = ocr_factory.create_options(
|
ocr_factory = get_ocr_factory(allow_external_plugins=allow_external_plugins)
|
||||||
kind=str(ocr_engine.value), # type:ignore
|
ocr_options: OcrOptions = ocr_factory.create_options(
|
||||||
|
kind=ocr_engine,
|
||||||
force_full_page_ocr=force_ocr,
|
force_full_page_ocr=force_ocr,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -377,6 +420,7 @@ def convert(
|
|||||||
|
|
||||||
accelerator_options = AcceleratorOptions(num_threads=num_threads, device=device)
|
accelerator_options = AcceleratorOptions(num_threads=num_threads, device=device)
|
||||||
pipeline_options = PdfPipelineOptions(
|
pipeline_options = PdfPipelineOptions(
|
||||||
|
allow_external_plugins=allow_external_plugins,
|
||||||
enable_remote_services=enable_remote_services,
|
enable_remote_services=enable_remote_services,
|
||||||
accelerator_options=accelerator_options,
|
accelerator_options=accelerator_options,
|
||||||
do_ocr=ocr,
|
do_ocr=ocr,
|
||||||
|
@ -322,6 +322,7 @@ class PipelineOptions(BaseModel):
|
|||||||
document_timeout: Optional[float] = None
|
document_timeout: Optional[float] = None
|
||||||
accelerator_options: AcceleratorOptions = AcceleratorOptions()
|
accelerator_options: AcceleratorOptions = AcceleratorOptions()
|
||||||
enable_remote_services: bool = False
|
enable_remote_services: bool = False
|
||||||
|
allow_external_plugins: bool = False
|
||||||
|
|
||||||
|
|
||||||
class PaginatedPipelineOptions(PipelineOptions):
|
class PaginatedPipelineOptions(PipelineOptions):
|
||||||
|
@ -9,17 +9,17 @@ from docling.models.factories.picture_description_factory import (
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@lru_cache(maxsize=1)
|
@lru_cache()
|
||||||
def get_ocr_factory():
|
def get_ocr_factory(allow_external_plugins: bool = False):
|
||||||
factory = OcrFactory()
|
factory = OcrFactory()
|
||||||
factory.load_from_plugins()
|
factory.load_from_plugins(allow_external_plugins=allow_external_plugins)
|
||||||
logger.info("Registered ocr engines: %r", factory.registered_kind)
|
logger.info("Registered ocr engines: %r", factory.registered_kind)
|
||||||
return factory
|
return factory
|
||||||
|
|
||||||
|
|
||||||
@lru_cache(maxsize=1)
|
@lru_cache()
|
||||||
def get_picture_description_factory():
|
def get_picture_description_factory(allow_external_plugins: bool = False):
|
||||||
factory = PictureDescriptionFactory()
|
factory = PictureDescriptionFactory()
|
||||||
factory.load_from_plugins()
|
factory.load_from_plugins(allow_external_plugins=allow_external_plugins)
|
||||||
logger.info("Registered picture descriptions: %r", factory.registered_kind)
|
logger.info("Registered picture descriptions: %r", factory.registered_kind)
|
||||||
return factory
|
return factory
|
||||||
|
@ -4,6 +4,7 @@ from abc import ABCMeta
|
|||||||
from typing import Generic, Optional, Type, TypeVar
|
from typing import Generic, Optional, Type, TypeVar
|
||||||
|
|
||||||
from pluggy import PluginManager
|
from pluggy import PluginManager
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from docling.datamodel.pipeline_options import BaseOptions
|
from docling.datamodel.pipeline_options import BaseOptions
|
||||||
from docling.models.base_model import BaseModelWithOptions
|
from docling.models.base_model import BaseModelWithOptions
|
||||||
@ -14,6 +15,12 @@ A = TypeVar("A", bound=BaseModelWithOptions)
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class FactoryMeta(BaseModel):
|
||||||
|
kind: str
|
||||||
|
plugin_name: str
|
||||||
|
module: str
|
||||||
|
|
||||||
|
|
||||||
class BaseFactory(Generic[A], metaclass=ABCMeta):
|
class BaseFactory(Generic[A], metaclass=ABCMeta):
|
||||||
default_plugin_name = "docling"
|
default_plugin_name = "docling"
|
||||||
|
|
||||||
@ -22,6 +29,7 @@ class BaseFactory(Generic[A], metaclass=ABCMeta):
|
|||||||
self.plugin_attr_name = plugin_attr_name
|
self.plugin_attr_name = plugin_attr_name
|
||||||
|
|
||||||
self._classes: dict[Type[BaseOptions], Type[A]] = {}
|
self._classes: dict[Type[BaseOptions], Type[A]] = {}
|
||||||
|
self._meta: dict[Type[BaseOptions], FactoryMeta] = {}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def registered_kind(self) -> list[str]:
|
def registered_kind(self) -> list[str]:
|
||||||
@ -39,6 +47,10 @@ class BaseFactory(Generic[A], metaclass=ABCMeta):
|
|||||||
def classes(self):
|
def classes(self):
|
||||||
return self._classes
|
return self._classes
|
||||||
|
|
||||||
|
@property
|
||||||
|
def registered_meta(self):
|
||||||
|
return self._meta
|
||||||
|
|
||||||
def create_instance(self, options: BaseOptions, **kwargs) -> A:
|
def create_instance(self, options: BaseOptions, **kwargs) -> A:
|
||||||
try:
|
try:
|
||||||
_cls = self._classes[type(options)]
|
_cls = self._classes[type(options)]
|
||||||
@ -62,7 +74,7 @@ class BaseFactory(Generic[A], metaclass=ABCMeta):
|
|||||||
|
|
||||||
return f"No class found with the name {kind!r}, known classes are:\n{msg_str}"
|
return f"No class found with the name {kind!r}, known classes are:\n{msg_str}"
|
||||||
|
|
||||||
def register(self, cls: Type[A]):
|
def register(self, cls: Type[A], plugin_name: str, plugin_module_name: str):
|
||||||
opt_type = cls.get_options_type()
|
opt_type = cls.get_options_type()
|
||||||
|
|
||||||
if opt_type in self._classes:
|
if opt_type in self._classes:
|
||||||
@ -71,14 +83,28 @@ class BaseFactory(Generic[A], metaclass=ABCMeta):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self._classes[opt_type] = cls
|
self._classes[opt_type] = cls
|
||||||
|
self._meta[opt_type] = FactoryMeta(
|
||||||
|
kind=opt_type.kind, plugin_name=plugin_name, module=plugin_module_name
|
||||||
|
)
|
||||||
|
|
||||||
def load_from_plugins(self, plugin_name: Optional[str] = None):
|
def load_from_plugins(
|
||||||
|
self, plugin_name: Optional[str] = None, allow_external_plugins: bool = False
|
||||||
|
):
|
||||||
plugin_name = plugin_name or self.plugin_name
|
plugin_name = plugin_name or self.plugin_name
|
||||||
|
|
||||||
plugin_manager = PluginManager(plugin_name)
|
plugin_manager = PluginManager(plugin_name)
|
||||||
plugin_manager.load_setuptools_entrypoints(plugin_name)
|
plugin_manager.load_setuptools_entrypoints(plugin_name)
|
||||||
|
|
||||||
for plugin_name, plugin_module in plugin_manager.list_name_plugin():
|
for plugin_name, plugin_module in plugin_manager.list_name_plugin():
|
||||||
|
plugin_module_name = str(plugin_module.__name__) # type: ignore
|
||||||
|
|
||||||
|
if not allow_external_plugins and not plugin_module_name.startswith(
|
||||||
|
"docling."
|
||||||
|
):
|
||||||
|
logger.warning(
|
||||||
|
f"The plugin {plugin_name} will not be loaded because Docling is being executed with allow_external_plugins=false."
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
attr = getattr(plugin_module, self.plugin_attr_name, None)
|
attr = getattr(plugin_module, self.plugin_attr_name, None)
|
||||||
|
|
||||||
@ -86,11 +112,11 @@ class BaseFactory(Generic[A], metaclass=ABCMeta):
|
|||||||
logger.info("Loading plugin %r", plugin_name)
|
logger.info("Loading plugin %r", plugin_name)
|
||||||
|
|
||||||
config = attr()
|
config = attr()
|
||||||
self.process_plugin(config)
|
self.process_plugin(config, plugin_name, plugin_module_name)
|
||||||
|
|
||||||
def process_plugin(self, config):
|
def process_plugin(self, config, plugin_name: str, plugin_module_name: str):
|
||||||
for item in config[self.plugin_attr_name]:
|
for item in config[self.plugin_attr_name]:
|
||||||
try:
|
try:
|
||||||
self.register(item)
|
self.register(item, plugin_name, plugin_module_name)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
logger.warning("%r already registered", item)
|
logger.warning("%r already registered", item)
|
||||||
|
@ -9,19 +9,3 @@ logger = logging.getLogger(__name__)
|
|||||||
class OcrFactory(BaseFactory[BaseOcrModel]):
|
class OcrFactory(BaseFactory[BaseOcrModel]):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__("ocr_engines", *args, **kwargs)
|
super().__init__("ocr_engines", *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
# def on_class_not_found(self, kind: str, *args, **kwargs):
|
|
||||||
|
|
||||||
# raise NoSuchOcrEngine(kind, self.registered_kind)
|
|
||||||
|
|
||||||
|
|
||||||
# class NoSuchOcrEngine(Exception):
|
|
||||||
# def __init__(self, ocr_engine_kind, known_engines=None):
|
|
||||||
# if known_engines is None:
|
|
||||||
# known_engines = []
|
|
||||||
# super(NoSuchOcrEngine, self).__init__(
|
|
||||||
# "No OCR engine found with the name '%s', known engines are: %r",
|
|
||||||
# ocr_engine_kind,
|
|
||||||
# [cls.__name__ for cls in known_engines],
|
|
||||||
# )
|
|
||||||
|
@ -146,7 +146,9 @@ class StandardPdfPipeline(PaginatedPipeline):
|
|||||||
return output_dir
|
return output_dir
|
||||||
|
|
||||||
def get_ocr_model(self, artifacts_path: Optional[Path] = None) -> BaseOcrModel:
|
def get_ocr_model(self, artifacts_path: Optional[Path] = None) -> BaseOcrModel:
|
||||||
factory = get_ocr_factory()
|
factory = get_ocr_factory(
|
||||||
|
allow_external_plugins=self.pipeline_options.allow_external_plugins
|
||||||
|
)
|
||||||
return factory.create_instance(
|
return factory.create_instance(
|
||||||
options=self.pipeline_options.ocr_options,
|
options=self.pipeline_options.ocr_options,
|
||||||
enabled=self.pipeline_options.do_ocr,
|
enabled=self.pipeline_options.do_ocr,
|
||||||
@ -157,7 +159,9 @@ class StandardPdfPipeline(PaginatedPipeline):
|
|||||||
def get_picture_description_model(
|
def get_picture_description_model(
|
||||||
self, artifacts_path: Optional[Path] = None
|
self, artifacts_path: Optional[Path] = None
|
||||||
) -> Optional[PictureDescriptionBaseModel]:
|
) -> Optional[PictureDescriptionBaseModel]:
|
||||||
factory = get_picture_description_factory()
|
factory = get_picture_description_factory(
|
||||||
|
allow_external_plugins=self.pipeline_options.allow_external_plugins
|
||||||
|
)
|
||||||
return factory.create_instance(
|
return factory.create_instance(
|
||||||
options=self.pipeline_options.picture_description_options,
|
options=self.pipeline_options.picture_description_options,
|
||||||
enabled=self.pipeline_options.do_picture_description,
|
enabled=self.pipeline_options.do_picture_description,
|
||||||
|
Loading…
Reference in New Issue
Block a user