diff --git a/docling/models/easyocr_model.py b/docling/models/easyocr_model.py index 45de8f1b..9b1b2a02 100644 --- a/docling/models/easyocr_model.py +++ b/docling/models/easyocr_model.py @@ -31,6 +31,7 @@ class EasyOcrModel(BaseOcrModel): def __init__( self, enabled: bool, + artifacts_path: Optional[Path], options: EasyOcrOptions, accelerator_options: AcceleratorOptions, ): @@ -68,12 +69,18 @@ class EasyOcrModel(BaseOcrModel): ) use_gpu = self.options.use_gpu + download_enabled = self.options.download_enabled + model_storage_directory = self.options.model_storage_directory + if artifacts_path is not None and model_storage_directory is None: + download_enabled = False + model_storage_directory = str(artifacts_path / self._model_repo_folder) + self.reader = easyocr.Reader( lang_list=self.options.lang, gpu=use_gpu, - model_storage_directory=self.options.model_storage_directory, + model_storage_directory=model_storage_directory, recog_network=self.options.recog_network, - download_enabled=self.options.download_enabled, + download_enabled=download_enabled, verbose=False, ) diff --git a/docling/pipeline/standard_pdf_pipeline.py b/docling/pipeline/standard_pdf_pipeline.py index 8eb2866f..fbdc22c3 100644 --- a/docling/pipeline/standard_pdf_pipeline.py +++ b/docling/pipeline/standard_pdf_pipeline.py @@ -64,7 +64,7 @@ class StandardPdfPipeline(PaginatedPipeline): self.glm_model = GlmModel(options=GlmOptions()) - if (ocr_model := self.get_ocr_model()) is None: + if (ocr_model := self.get_ocr_model(artifacts_path=artifacts_path)) is None: raise RuntimeError( f"The specified OCR kind is not supported: {pipeline_options.ocr_options.kind}." ) @@ -157,10 +157,13 @@ class StandardPdfPipeline(PaginatedPipeline): return local_dir - def get_ocr_model(self) -> Optional[BaseOcrModel]: + def get_ocr_model( + self, artifacts_path: Optional[Path] = None + ) -> Optional[BaseOcrModel]: if isinstance(self.pipeline_options.ocr_options, EasyOcrOptions): return EasyOcrModel( enabled=self.pipeline_options.do_ocr, + artifacts_path=artifacts_path, options=self.pipeline_options.ocr_options, accelerator_options=self.pipeline_options.accelerator_options, )