switch to create methods

Signed-off-by: Panos Vagenas <pva@zurich.ibm.com>
This commit is contained in:
Panos Vagenas 2025-02-25 14:43:15 +01:00
parent 3844f2a5cb
commit 1553a125dc
3 changed files with 19 additions and 31 deletions

View File

@@ -366,8 +366,10 @@ def convert(
export_txt = OutputFormat.TEXT in to_formats
export_doctags = OutputFormat.DOCTAGS in to_formats
ocr_options_class = ocr_factory.get_options_class(kind=str(ocr_engine.value)) # type: ignore
ocr_options = ocr_options_class(force_full_page_ocr=force_ocr)
ocr_options = ocr_factory.create_options(
kind=str(ocr_engine.value), # type:ignore
force_full_page_ocr=force_ocr,
)
ocr_lang_list = _split_list(ocr_lang)
if ocr_lang_list is not None:

View File

@@ -39,25 +39,20 @@ class BaseFactory(Generic[A], metaclass=ABCMeta):
def classes(self):
return self._classes
def get_class(self, options: BaseOptions, *args, **kwargs) -> Type[A]:
def create_instance(self, options: BaseOptions, *args, **kwargs) -> A:
try:
return self._classes[type(options)]
_cls = self._classes[type(options)]
return _cls(*args, **kwargs)
except KeyError:
return self.on_class_not_found(options.kind, *args, **kwargs)
raise RuntimeError(self._err_msg_on_class_not_found(options.kind))
def get_class_by_kind(self, kind: str, *args, **kwargs) -> Type[A]:
for opt, cls in self._classes.items():
if opt.kind == kind:
return cls
return self.on_class_not_found(kind, *args, **kwargs)
def create_options(self, kind: str, *args, **kwargs) -> BaseOptions:
for opt_cls, _ in self._classes.items():
if opt_cls.kind == kind:
return opt_cls(*args, **kwargs)
raise RuntimeError(self._err_msg_on_class_not_found(kind))
def get_options_class(self, kind: str, *args, **kwargs) -> Type[BaseOptions]:
for opt, cls in self._classes.items():
if opt.kind == kind:
return opt
return self.on_class_not_found(kind, *args, **kwargs)
def on_class_not_found(self, kind: str, *args, **kwargs):
def _err_msg_on_class_not_found(self, kind: str):
msg = []
for opt, cls in self._classes.items():
@@ -65,9 +60,7 @@ class BaseFactory(Generic[A], metaclass=ABCMeta):
msg_str = "\n".join(msg)
raise RuntimeError(
f"No class found with the name {kind!r}, known classes are:\n{msg_str}"
)
return f"No class found with the name {kind!r}, known classes are:\n{msg_str}"
def register(self, cls: Type[A]):
opt_type = cls.get_options_type()

View File

@@ -147,12 +147,10 @@ class StandardPdfPipeline(PaginatedPipeline):
def get_ocr_model(self, artifacts_path: Optional[Path] = None) -> BaseOcrModel:
factory = get_ocr_factory()
ocr_engine_cls = factory.get_class(options=self.pipeline_options.ocr_options)
return ocr_engine_cls(
return factory.create_instance(
options=self.pipeline_options.ocr_options,
enabled=self.pipeline_options.do_ocr,
artifacts_path=artifacts_path,
options=self.pipeline_options.ocr_options,
accelerator_options=self.pipeline_options.accelerator_options,
)
@@ -160,16 +158,11 @@ class StandardPdfPipeline(PaginatedPipeline):
self, artifacts_path: Optional[Path] = None
) -> Optional[PictureDescriptionBaseModel]:
factory = get_picture_description_factory()
options_cls = factory.get_class(
options=self.pipeline_options.picture_description_options
)
return options_cls(
return factory.create_instance(
options=self.pipeline_options.picture_description_options,
enabled=self.pipeline_options.do_picture_description,
enable_remote_services=self.pipeline_options.enable_remote_services,
artifacts_path=artifacts_path,
options=self.pipeline_options.picture_description_options,
accelerator_options=self.pipeline_options.accelerator_options,
)