fix: address deprecation warnings of dependencies (#2237)

* switch to dtype instead of torch_dtype

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* set __check_model__ to avoid deprecation warnings

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* remove dataloaders warnings in easyocr

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

* suppress with option

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>

---------

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>
This commit is contained in:
Michele Dolfi
2025-09-10 14:38:34 +02:00
committed by GitHub
parent f8cc545bab
commit c6965495a2
6 changed files with 27 additions and 12 deletions

View File

@@ -135,6 +135,8 @@ class EasyOcrOptions(OcrOptions):
     recog_network: Optional[str] = "standard"
     download_enabled: bool = True
+    suppress_mps_warnings: bool = True
+
     model_config = ConfigDict(
         extra="forbid",
         protected_namespaces=(),

View File

@@ -78,6 +78,9 @@ class EasyOcrModel(BaseOcrModel):
             download_enabled = False
             model_storage_directory = str(artifacts_path / self._model_repo_folder)
+        with warnings.catch_warnings():
+            if self.options.suppress_mps_warnings:
+                warnings.filterwarnings("ignore", message=".*pin_memory.*MPS.*")
             self.reader = easyocr.Reader(
                 lang_list=self.options.lang,
                 gpu=use_gpu,
@@ -147,6 +150,13 @@ class EasyOcrModel(BaseOcrModel):
                     scale=self.scale, cropbox=ocr_rect
                 )
                 im = numpy.array(high_res_image)
+                with warnings.catch_warnings():
+                    if self.options.suppress_mps_warnings:
+                        warnings.filterwarnings(
+                            "ignore", message=".*pin_memory.*MPS.*"
+                        )
-                result = self.reader.readtext(im)
+                    result = self.reader.readtext(im)
                 del high_res_image

View File

@@ -67,7 +67,7 @@ class PictureDescriptionVlmModel(
         self.model = AutoModelForImageTextToText.from_pretrained(
             artifacts_path,
             device_map=self.device,
-            torch_dtype=torch.bfloat16,
+            dtype=torch.bfloat16,
             _attn_implementation=(
                 "flash_attention_2"
                 if self.device.startswith("cuda")

View File

@@ -112,7 +112,7 @@ class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownload
         self.vlm_model = model_cls.from_pretrained(
             artifacts_path,
             device_map=self.device,
-            torch_dtype=self.vlm_options.torch_dtype,
+            dtype=self.vlm_options.torch_dtype,
             _attn_implementation=(
                 "flash_attention_2"
                 if self.device.startswith("cuda")

View File

@@ -144,7 +144,7 @@ class NuExtractTransformersModel(BaseVlmModel, HuggingFaceModelDownloadMixin):
         self.vlm_model = AutoModelForImageTextToText.from_pretrained(
             artifacts_path,
             device_map=self.device,
-            torch_dtype=self.vlm_options.torch_dtype,
+            dtype=self.vlm_options.torch_dtype,
             _attn_implementation=(
                 "flash_attention_2"
                 if self.device.startswith("cuda")

View File

@@ -194,6 +194,9 @@ class ExtractionVlmPipeline(BaseExtractionPipeline):
             class ExtractionTemplateFactory(ModelFactory[template]):  # type: ignore
                 __use_examples__ = True  # prefer Field(examples=...) when present
                 __use_defaults__ = True  # use field defaults instead of random values
+                __check_model__ = (
+                    True  # setting the value to avoid deprecation warnings
+                )
             return ExtractionTemplateFactory.build().model_dump_json(indent=2)  # type: ignore
         else: