diff --git a/docling/datamodel/pipeline_options.py b/docling/datamodel/pipeline_options.py index b34b7934..58392a5a 100644 --- a/docling/datamodel/pipeline_options.py +++ b/docling/datamodel/pipeline_options.py @@ -135,6 +135,8 @@ class EasyOcrOptions(OcrOptions): recog_network: Optional[str] = "standard" download_enabled: bool = True + suppress_mps_warnings: bool = True + model_config = ConfigDict( extra="forbid", protected_namespaces=(), diff --git a/docling/models/easyocr_model.py b/docling/models/easyocr_model.py index e430465a..36eb5d97 100644 --- a/docling/models/easyocr_model.py +++ b/docling/models/easyocr_model.py @@ -78,14 +78,17 @@ class EasyOcrModel(BaseOcrModel): download_enabled = False model_storage_directory = str(artifacts_path / self._model_repo_folder) - self.reader = easyocr.Reader( - lang_list=self.options.lang, - gpu=use_gpu, - model_storage_directory=model_storage_directory, - recog_network=self.options.recog_network, - download_enabled=download_enabled, - verbose=False, - ) + with warnings.catch_warnings(): + if self.options.suppress_mps_warnings: + warnings.filterwarnings("ignore", message=".*pin_memory.*MPS.*") + self.reader = easyocr.Reader( + lang_list=self.options.lang, + gpu=use_gpu, + model_storage_directory=model_storage_directory, + recog_network=self.options.recog_network, + download_enabled=download_enabled, + verbose=False, + ) @staticmethod def download_models( @@ -147,7 +150,14 @@ class EasyOcrModel(BaseOcrModel): scale=self.scale, cropbox=ocr_rect ) im = numpy.array(high_res_image) - result = self.reader.readtext(im) + + with warnings.catch_warnings(): + if self.options.suppress_mps_warnings: + warnings.filterwarnings( + "ignore", message=".*pin_memory.*MPS.*" + ) + + result = self.reader.readtext(im) del high_res_image del im diff --git a/docling/models/picture_description_vlm_model.py b/docling/models/picture_description_vlm_model.py index 87850773..4b5007fa 100644 --- a/docling/models/picture_description_vlm_model.py +++ b/docling/models/picture_description_vlm_model.py @@ -67,7 +67,7 @@ class PictureDescriptionVlmModel( self.model = AutoModelForImageTextToText.from_pretrained( artifacts_path, device_map=self.device, - torch_dtype=torch.bfloat16, + dtype=torch.bfloat16, _attn_implementation=( "flash_attention_2" if self.device.startswith("cuda") diff --git a/docling/models/vlm_models_inline/hf_transformers_model.py b/docling/models/vlm_models_inline/hf_transformers_model.py index 77cbac05..e7157948 100644 --- a/docling/models/vlm_models_inline/hf_transformers_model.py +++ b/docling/models/vlm_models_inline/hf_transformers_model.py @@ -112,7 +112,7 @@ class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownload self.vlm_model = model_cls.from_pretrained( artifacts_path, device_map=self.device, - torch_dtype=self.vlm_options.torch_dtype, + dtype=self.vlm_options.torch_dtype, _attn_implementation=( "flash_attention_2" if self.device.startswith("cuda") diff --git a/docling/models/vlm_models_inline/nuextract_transformers_model.py b/docling/models/vlm_models_inline/nuextract_transformers_model.py index f0a0c872..3eb64d49 100644 --- a/docling/models/vlm_models_inline/nuextract_transformers_model.py +++ b/docling/models/vlm_models_inline/nuextract_transformers_model.py @@ -144,7 +144,7 @@ class NuExtractTransformersModel(BaseVlmModel, HuggingFaceModelDownloadMixin): self.vlm_model = AutoModelForImageTextToText.from_pretrained( artifacts_path, device_map=self.device, - torch_dtype=self.vlm_options.torch_dtype, + dtype=self.vlm_options.torch_dtype, _attn_implementation=( "flash_attention_2" if self.device.startswith("cuda") diff --git a/docling/pipeline/extraction_vlm_pipeline.py b/docling/pipeline/extraction_vlm_pipeline.py index f2b7ff9f..68222cd5 100644 --- a/docling/pipeline/extraction_vlm_pipeline.py +++ b/docling/pipeline/extraction_vlm_pipeline.py @@ -194,6 +194,9 @@ class ExtractionVlmPipeline(BaseExtractionPipeline): class ExtractionTemplateFactory(ModelFactory[template]): # type: ignore __use_examples__ = True # prefer Field(examples=...) when present __use_defaults__ = True # use field defaults instead of random values + __check_model__ = ( + True # setting the value to avoid deprecation warnings + ) return ExtractionTemplateFactory.build().model_dump_json(indent=2) # type: ignore else: