mirror of https://github.com/DS4SD/docling.git (synced 2025-12-08 12:48:28 +00:00)
fix: address deprecation warnings of dependencies (#2237)
* switch to dtype instead of torch_dtype
* set __check_model__ to avoid deprecation warnings
* remove dataloaders warnings in easyocr
* suppress with option

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>
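All four changes quiet warnings raised by newer releases of the dependencies (transformers, polyfactory, torch/easyocr). One way to reproduce them locally, an assumption about your workflow rather than part of the commit, is to escalate deprecation-style warnings to errors before running the test suite:

import warnings

# Turn deprecation-style warnings into errors so deprecated keyword
# arguments (e.g. torch_dtype=...) fail loudly instead of scrolling by
# in the logs. transformers typically uses FutureWarning for these.
warnings.simplefilter("error", DeprecationWarning)
warnings.simplefilter("error", FutureWarning)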
@@ -135,6 +135,8 @@ class EasyOcrOptions(OcrOptions):
     recog_network: Optional[str] = "standard"
     download_enabled: bool = True
 
+    suppress_mps_warnings: bool = True
+
     model_config = ConfigDict(
         extra="forbid",
         protected_namespaces=(),
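The new field defaults to True, so warning suppression is on unless a caller opts out. A minimal sketch of opting out through the public pipeline options, assuming the usual docling PDF conversion wiring (the input file name is illustrative):

from docling.datamodel.base_models import InputFormat
from docling.datamodel.pipeline_options import EasyOcrOptions, PdfPipelineOptions
from docling.document_converter import DocumentConverter, PdfFormatOption

# Opt back into the raw EasyOCR/MPS warnings by disabling the new flag.
ocr_options = EasyOcrOptions(suppress_mps_warnings=False)
pipeline_options = PdfPipelineOptions(do_ocr=True, ocr_options=ocr_options)

converter = DocumentConverter(
    format_options={InputFormat.PDF: PdfFormatOption(pipeline_options=pipeline_options)}
)
result = converter.convert("report.pdf")  # illustrative input file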
@@ -78,14 +78,17 @@ class EasyOcrModel(BaseOcrModel):
                 download_enabled = False
                 model_storage_directory = str(artifacts_path / self._model_repo_folder)
 
-            self.reader = easyocr.Reader(
-                lang_list=self.options.lang,
-                gpu=use_gpu,
-                model_storage_directory=model_storage_directory,
-                recog_network=self.options.recog_network,
-                download_enabled=download_enabled,
-                verbose=False,
-            )
+            with warnings.catch_warnings():
+                if self.options.suppress_mps_warnings:
+                    warnings.filterwarnings("ignore", message=".*pin_memory.*MPS.*")
+                self.reader = easyocr.Reader(
+                    lang_list=self.options.lang,
+                    gpu=use_gpu,
+                    model_storage_directory=model_storage_directory,
+                    recog_network=self.options.recog_network,
+                    download_enabled=download_enabled,
+                    verbose=False,
+                )
 
     @staticmethod
     def download_models(
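This hunk and the readtext() hunk below use the same pattern: warnings.catch_warnings() scopes the "ignore" filter to the with-block, so global warning behaviour is untouched once the call returns. A self-contained sketch (the warning text is a stand-in, not EasyOCR's exact wording):

import warnings

def build_reader():
    # Stand-in for easyocr.Reader(...): third-party code that may warn about
    # pin_memory not being supported on MPS when running on Apple Silicon.
    warnings.warn("'pin_memory' is set but not supported on MPS", UserWarning)
    return object()

with warnings.catch_warnings():
    # "message" is a regex matched against the start of the warning text;
    # the filter is discarded when the with-block exits.
    warnings.filterwarnings("ignore", message=".*pin_memory.*MPS.*")
    reader = build_reader()  # no warning printed here

build_reader()  # outside the block the warning is emitted normally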
@@ -147,7 +150,14 @@ class EasyOcrModel(BaseOcrModel):
                         scale=self.scale, cropbox=ocr_rect
                     )
                     im = numpy.array(high_res_image)
-                    result = self.reader.readtext(im)
+
+                    with warnings.catch_warnings():
+                        if self.options.suppress_mps_warnings:
+                            warnings.filterwarnings(
+                                "ignore", message=".*pin_memory.*MPS.*"
+                            )
+
+                        result = self.reader.readtext(im)
 
                     del high_res_image
                     del im
@@ -67,7 +67,7 @@ class PictureDescriptionVlmModel(
         self.model = AutoModelForImageTextToText.from_pretrained(
             artifacts_path,
             device_map=self.device,
-            torch_dtype=torch.bfloat16,
+            dtype=torch.bfloat16,
             _attn_implementation=(
                 "flash_attention_2"
                 if self.device.startswith("cuda")
@@ -112,7 +112,7 @@ class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
         self.vlm_model = model_cls.from_pretrained(
             artifacts_path,
             device_map=self.device,
-            torch_dtype=self.vlm_options.torch_dtype,
+            dtype=self.vlm_options.torch_dtype,
             _attn_implementation=(
                 "flash_attention_2"
                 if self.device.startswith("cuda")
@@ -144,7 +144,7 @@ class NuExtractTransformersModel(BaseVlmModel, HuggingFaceModelDownloadMixin):
         self.vlm_model = AutoModelForImageTextToText.from_pretrained(
             artifacts_path,
             device_map=self.device,
-            torch_dtype=self.vlm_options.torch_dtype,
+            dtype=self.vlm_options.torch_dtype,
             _attn_implementation=(
                 "flash_attention_2"
                 if self.device.startswith("cuda")
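The three VLM hunks above are the same one-keyword rename: recent transformers releases accept dtype= on from_pretrained and deprecate torch_dtype= (the vlm_options field keeps its old name, only the keyword changes). A minimal sketch; the checkpoint name is illustrative, not necessarily what docling loads:

import torch
from transformers import AutoModelForImageTextToText

model = AutoModelForImageTextToText.from_pretrained(
    "HuggingFaceTB/SmolVLM-256M-Instruct",  # illustrative checkpoint
    device_map="auto",
    dtype=torch.bfloat16,  # replaces the deprecated torch_dtype=torch.bfloat16
)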
@@ -194,6 +194,9 @@ class ExtractionVlmPipeline(BaseExtractionPipeline):
             class ExtractionTemplateFactory(ModelFactory[template]):  # type: ignore
                 __use_examples__ = True  # prefer Field(examples=...) when present
                 __use_defaults__ = True  # use field defaults instead of random values
+                __check_model__ = (
+                    True  # setting the value to avoid deprecation warnings
+                )
 
             return ExtractionTemplateFactory.build().model_dump_json(indent=2)  # type: ignore
         else:
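For context, ExtractionTemplateFactory is a polyfactory ModelFactory built around a pydantic template; per the inline comment, __check_model__ is set explicitly to avoid the library's deprecation warning about relying on its implicit default. A standalone sketch with a hypothetical template model (Invoice and its fields are made up for illustration):

from polyfactory.factories.pydantic_factory import ModelFactory
from pydantic import BaseModel, Field


class Invoice(BaseModel):  # hypothetical extraction template
    vendor: str = Field("unknown", examples=["ACME Corp"])
    total: float = 0.0


class InvoiceFactory(ModelFactory[Invoice]):
    __use_defaults__ = True  # keep field defaults instead of random values
    __check_model__ = True   # explicit value, mirroring the commit


# Produces a JSON skeleton of the template, analogous to the pipeline's
# ExtractionTemplateFactory.build().model_dump_json(indent=2) call.
print(InvoiceFactory.build().model_dump_json(indent=2))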