propagate artifacts path usage for ocr models

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>
This commit is contained in:
Michele Dolfi 2025-02-05 17:40:52 +01:00
parent 0ba08adb26
commit a11fd1f157
2 changed files with 14 additions and 4 deletions

View File

@ -31,6 +31,7 @@ class EasyOcrModel(BaseOcrModel):
def __init__( def __init__(
self, self,
enabled: bool, enabled: bool,
artifacts_path: Optional[Path],
options: EasyOcrOptions, options: EasyOcrOptions,
accelerator_options: AcceleratorOptions, accelerator_options: AcceleratorOptions,
): ):
@ -68,12 +69,18 @@ class EasyOcrModel(BaseOcrModel):
) )
use_gpu = self.options.use_gpu use_gpu = self.options.use_gpu
download_enabled = self.options.download_enabled
model_storage_directory = self.options.model_storage_directory
if artifacts_path is not None and model_storage_directory is None:
download_enabled = False
model_storage_directory = str(artifacts_path / self._model_repo_folder)
self.reader = easyocr.Reader( self.reader = easyocr.Reader(
lang_list=self.options.lang, lang_list=self.options.lang,
gpu=use_gpu, gpu=use_gpu,
model_storage_directory=self.options.model_storage_directory, model_storage_directory=model_storage_directory,
recog_network=self.options.recog_network, recog_network=self.options.recog_network,
download_enabled=self.options.download_enabled, download_enabled=download_enabled,
verbose=False, verbose=False,
) )

View File

@ -64,7 +64,7 @@ class StandardPdfPipeline(PaginatedPipeline):
self.glm_model = GlmModel(options=GlmOptions()) self.glm_model = GlmModel(options=GlmOptions())
if (ocr_model := self.get_ocr_model()) is None: if (ocr_model := self.get_ocr_model(artifacts_path=artifacts_path)) is None:
raise RuntimeError( raise RuntimeError(
f"The specified OCR kind is not supported: {pipeline_options.ocr_options.kind}." f"The specified OCR kind is not supported: {pipeline_options.ocr_options.kind}."
) )
@ -157,10 +157,13 @@ class StandardPdfPipeline(PaginatedPipeline):
return local_dir return local_dir
def get_ocr_model(self) -> Optional[BaseOcrModel]: def get_ocr_model(
self, artifacts_path: Optional[Path] = None
) -> Optional[BaseOcrModel]:
if isinstance(self.pipeline_options.ocr_options, EasyOcrOptions): if isinstance(self.pipeline_options.ocr_options, EasyOcrOptions):
return EasyOcrModel( return EasyOcrModel(
enabled=self.pipeline_options.do_ocr, enabled=self.pipeline_options.do_ocr,
artifacts_path=artifacts_path,
options=self.pipeline_options.ocr_options, options=self.pipeline_options.ocr_options,
accelerator_options=self.pipeline_options.accelerator_options, accelerator_options=self.pipeline_options.accelerator_options,
) )