From 1553a125dc059eda44c947cddd744c1842d926b3 Mon Sep 17 00:00:00 2001 From: Panos Vagenas Date: Tue, 25 Feb 2025 14:43:15 +0100 Subject: [PATCH] switch to create methods Signed-off-by: Panos Vagenas --- docling/cli/main.py | 6 +++-- docling/models/factories/base_factory.py | 29 +++++++++-------------- docling/pipeline/standard_pdf_pipeline.py | 15 ++++-------- 3 files changed, 19 insertions(+), 31 deletions(-) diff --git a/docling/cli/main.py b/docling/cli/main.py index d9bb2e16..9e516b4e 100644 --- a/docling/cli/main.py +++ b/docling/cli/main.py @@ -366,8 +366,10 @@ def convert( export_txt = OutputFormat.TEXT in to_formats export_doctags = OutputFormat.DOCTAGS in to_formats - ocr_options_class = ocr_factory.get_options_class(kind=str(ocr_engine.value)) # type: ignore - ocr_options = ocr_options_class(force_full_page_ocr=force_ocr) + ocr_options = ocr_factory.create_options( + kind=str(ocr_engine.value), # type:ignore + force_full_page_ocr=force_ocr, + ) ocr_lang_list = _split_list(ocr_lang) if ocr_lang_list is not None: diff --git a/docling/models/factories/base_factory.py b/docling/models/factories/base_factory.py index adb3fbb0..c1d34b5b 100644 --- a/docling/models/factories/base_factory.py +++ b/docling/models/factories/base_factory.py @@ -39,25 +39,20 @@ class BaseFactory(Generic[A], metaclass=ABCMeta): def classes(self): return self._classes - def get_class(self, options: BaseOptions, *args, **kwargs) -> Type[A]: + def create_instance(self, options: BaseOptions, *args, **kwargs) -> A: try: - return self._classes[type(options)] + _cls = self._classes[type(options)] + return _cls(*args, **kwargs) except KeyError: - return self.on_class_not_found(options.kind, *args, **kwargs) + raise RuntimeError(self._err_msg_on_class_not_found(options.kind)) - def get_class_by_kind(self, kind: str, *args, **kwargs) -> Type[A]: - for opt, cls in self._classes.items(): - if opt.kind == kind: - return cls - return self.on_class_not_found(kind, *args, **kwargs) + def create_options(self, kind: str, *args, **kwargs) -> BaseOptions: + for opt_cls, _ in self._classes.items(): + if opt_cls.kind == kind: + return opt_cls(*args, **kwargs) + raise RuntimeError(self._err_msg_on_class_not_found(kind)) - def get_options_class(self, kind: str, *args, **kwargs) -> Type[BaseOptions]: - for opt, cls in self._classes.items(): - if opt.kind == kind: - return opt - return self.on_class_not_found(kind, *args, **kwargs) - - def on_class_not_found(self, kind: str, *args, **kwargs): + def _err_msg_on_class_not_found(self, kind: str): msg = [] for opt, cls in self._classes.items(): @@ -65,9 +60,7 @@ class BaseFactory(Generic[A], metaclass=ABCMeta): msg_str = "\n".join(msg) - raise RuntimeError( - f"No class found with the name {kind!r}, known classes are:\n{msg_str}" - ) + return f"No class found with the name {kind!r}, known classes are:\n{msg_str}" def register(self, cls: Type[A]): opt_type = cls.get_options_type() diff --git a/docling/pipeline/standard_pdf_pipeline.py b/docling/pipeline/standard_pdf_pipeline.py index 05e58d07..f8b11f57 100644 --- a/docling/pipeline/standard_pdf_pipeline.py +++ b/docling/pipeline/standard_pdf_pipeline.py @@ -147,12 +147,10 @@ class StandardPdfPipeline(PaginatedPipeline): def get_ocr_model(self, artifacts_path: Optional[Path] = None) -> BaseOcrModel: factory = get_ocr_factory() - ocr_engine_cls = factory.get_class(options=self.pipeline_options.ocr_options) - - return ocr_engine_cls( + return factory.create_instance( + options=self.pipeline_options.ocr_options, enabled=self.pipeline_options.do_ocr, artifacts_path=artifacts_path, - options=self.pipeline_options.ocr_options, accelerator_options=self.pipeline_options.accelerator_options, ) @@ -160,16 +158,11 @@ class StandardPdfPipeline(PaginatedPipeline): self, artifacts_path: Optional[Path] = None ) -> Optional[PictureDescriptionBaseModel]: factory = get_picture_description_factory() - - options_cls = factory.get_class( - options=self.pipeline_options.picture_description_options - ) - - return options_cls( + return factory.create_instance( + options=self.pipeline_options.picture_description_options, enabled=self.pipeline_options.do_picture_description, enable_remote_services=self.pipeline_options.enable_remote_services, artifacts_path=artifacts_path, - options=self.pipeline_options.picture_description_options, accelerator_options=self.pipeline_options.accelerator_options, )