mirror of
https://github.com/DS4SD/docling.git
synced 2025-08-02 07:22:14 +00:00
propagate artifacts path usage for ocr models
Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>
This commit is contained in:
parent
0ba08adb26
commit
a11fd1f157
@ -31,6 +31,7 @@ class EasyOcrModel(BaseOcrModel):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
enabled: bool,
|
enabled: bool,
|
||||||
|
artifacts_path: Optional[Path],
|
||||||
options: EasyOcrOptions,
|
options: EasyOcrOptions,
|
||||||
accelerator_options: AcceleratorOptions,
|
accelerator_options: AcceleratorOptions,
|
||||||
):
|
):
|
||||||
@ -68,12 +69,18 @@ class EasyOcrModel(BaseOcrModel):
|
|||||||
)
|
)
|
||||||
use_gpu = self.options.use_gpu
|
use_gpu = self.options.use_gpu
|
||||||
|
|
||||||
|
download_enabled = self.options.download_enabled
|
||||||
|
model_storage_directory = self.options.model_storage_directory
|
||||||
|
if artifacts_path is not None and model_storage_directory is None:
|
||||||
|
download_enabled = False
|
||||||
|
model_storage_directory = str(artifacts_path / self._model_repo_folder)
|
||||||
|
|
||||||
self.reader = easyocr.Reader(
|
self.reader = easyocr.Reader(
|
||||||
lang_list=self.options.lang,
|
lang_list=self.options.lang,
|
||||||
gpu=use_gpu,
|
gpu=use_gpu,
|
||||||
model_storage_directory=self.options.model_storage_directory,
|
model_storage_directory=model_storage_directory,
|
||||||
recog_network=self.options.recog_network,
|
recog_network=self.options.recog_network,
|
||||||
download_enabled=self.options.download_enabled,
|
download_enabled=download_enabled,
|
||||||
verbose=False,
|
verbose=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -64,7 +64,7 @@ class StandardPdfPipeline(PaginatedPipeline):
|
|||||||
|
|
||||||
self.glm_model = GlmModel(options=GlmOptions())
|
self.glm_model = GlmModel(options=GlmOptions())
|
||||||
|
|
||||||
if (ocr_model := self.get_ocr_model()) is None:
|
if (ocr_model := self.get_ocr_model(artifacts_path=artifacts_path)) is None:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"The specified OCR kind is not supported: {pipeline_options.ocr_options.kind}."
|
f"The specified OCR kind is not supported: {pipeline_options.ocr_options.kind}."
|
||||||
)
|
)
|
||||||
@ -157,10 +157,13 @@ class StandardPdfPipeline(PaginatedPipeline):
|
|||||||
|
|
||||||
return local_dir
|
return local_dir
|
||||||
|
|
||||||
def get_ocr_model(self) -> Optional[BaseOcrModel]:
|
def get_ocr_model(
|
||||||
|
self, artifacts_path: Optional[Path] = None
|
||||||
|
) -> Optional[BaseOcrModel]:
|
||||||
if isinstance(self.pipeline_options.ocr_options, EasyOcrOptions):
|
if isinstance(self.pipeline_options.ocr_options, EasyOcrOptions):
|
||||||
return EasyOcrModel(
|
return EasyOcrModel(
|
||||||
enabled=self.pipeline_options.do_ocr,
|
enabled=self.pipeline_options.do_ocr,
|
||||||
|
artifacts_path=artifacts_path,
|
||||||
options=self.pipeline_options.ocr_options,
|
options=self.pipeline_options.ocr_options,
|
||||||
accelerator_options=self.pipeline_options.accelerator_options,
|
accelerator_options=self.pipeline_options.accelerator_options,
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user