switch to create methods

Signed-off-by: Panos Vagenas <pva@zurich.ibm.com>
This commit is contained in:
Panos Vagenas 2025-02-25 14:43:15 +01:00
parent 3844f2a5cb
commit 1553a125dc
3 changed files with 19 additions and 31 deletions

View File

@@ -366,8 +366,10 @@ def convert(
export_txt = OutputFormat.TEXT in to_formats export_txt = OutputFormat.TEXT in to_formats
export_doctags = OutputFormat.DOCTAGS in to_formats export_doctags = OutputFormat.DOCTAGS in to_formats
ocr_options_class = ocr_factory.get_options_class(kind=str(ocr_engine.value)) # type: ignore ocr_options = ocr_factory.create_options(
ocr_options = ocr_options_class(force_full_page_ocr=force_ocr) kind=str(ocr_engine.value), # type:ignore
force_full_page_ocr=force_ocr,
)
ocr_lang_list = _split_list(ocr_lang) ocr_lang_list = _split_list(ocr_lang)
if ocr_lang_list is not None: if ocr_lang_list is not None:

View File

@@ -39,25 +39,20 @@ class BaseFactory(Generic[A], metaclass=ABCMeta):
def classes(self): def classes(self):
return self._classes return self._classes
def get_class(self, options: BaseOptions, *args, **kwargs) -> Type[A]: def create_instance(self, options: BaseOptions, *args, **kwargs) -> A:
try: try:
return self._classes[type(options)] _cls = self._classes[type(options)]
return _cls(*args, **kwargs)
except KeyError: except KeyError:
return self.on_class_not_found(options.kind, *args, **kwargs) raise RuntimeError(self._err_msg_on_class_not_found(options.kind))
def get_class_by_kind(self, kind: str, *args, **kwargs) -> Type[A]: def create_options(self, kind: str, *args, **kwargs) -> BaseOptions:
for opt, cls in self._classes.items(): for opt_cls, _ in self._classes.items():
if opt.kind == kind: if opt_cls.kind == kind:
return cls return opt_cls(*args, **kwargs)
return self.on_class_not_found(kind, *args, **kwargs) raise RuntimeError(self._err_msg_on_class_not_found(kind))
def get_options_class(self, kind: str, *args, **kwargs) -> Type[BaseOptions]: def _err_msg_on_class_not_found(self, kind: str):
for opt, cls in self._classes.items():
if opt.kind == kind:
return opt
return self.on_class_not_found(kind, *args, **kwargs)
def on_class_not_found(self, kind: str, *args, **kwargs):
msg = [] msg = []
for opt, cls in self._classes.items(): for opt, cls in self._classes.items():
@@ -65,9 +60,7 @@ class BaseFactory(Generic[A], metaclass=ABCMeta):
msg_str = "\n".join(msg) msg_str = "\n".join(msg)
raise RuntimeError( return f"No class found with the name {kind!r}, known classes are:\n{msg_str}"
f"No class found with the name {kind!r}, known classes are:\n{msg_str}"
)
def register(self, cls: Type[A]): def register(self, cls: Type[A]):
opt_type = cls.get_options_type() opt_type = cls.get_options_type()

View File

@@ -147,12 +147,10 @@ class StandardPdfPipeline(PaginatedPipeline):
def get_ocr_model(self, artifacts_path: Optional[Path] = None) -> BaseOcrModel: def get_ocr_model(self, artifacts_path: Optional[Path] = None) -> BaseOcrModel:
factory = get_ocr_factory() factory = get_ocr_factory()
ocr_engine_cls = factory.get_class(options=self.pipeline_options.ocr_options) return factory.create_instance(
options=self.pipeline_options.ocr_options,
return ocr_engine_cls(
enabled=self.pipeline_options.do_ocr, enabled=self.pipeline_options.do_ocr,
artifacts_path=artifacts_path, artifacts_path=artifacts_path,
options=self.pipeline_options.ocr_options,
accelerator_options=self.pipeline_options.accelerator_options, accelerator_options=self.pipeline_options.accelerator_options,
) )
@@ -160,16 +158,11 @@ class StandardPdfPipeline(PaginatedPipeline):
self, artifacts_path: Optional[Path] = None self, artifacts_path: Optional[Path] = None
) -> Optional[PictureDescriptionBaseModel]: ) -> Optional[PictureDescriptionBaseModel]:
factory = get_picture_description_factory() factory = get_picture_description_factory()
return factory.create_instance(
options_cls = factory.get_class( options=self.pipeline_options.picture_description_options,
options=self.pipeline_options.picture_description_options
)
return options_cls(
enabled=self.pipeline_options.do_picture_description, enabled=self.pipeline_options.do_picture_description,
enable_remote_services=self.pipeline_options.enable_remote_services, enable_remote_services=self.pipeline_options.enable_remote_services,
artifacts_path=artifacts_path, artifacts_path=artifacts_path,
options=self.pipeline_options.picture_description_options,
accelerator_options=self.pipeline_options.accelerator_options, accelerator_options=self.pipeline_options.accelerator_options,
) )