mirror of
https://github.com/DS4SD/docling.git
synced 2025-12-08 12:48:28 +00:00
fix: address deprecation warnings of dependencies (#2237)
* switch to dtype instead of torch_dtype
  Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>
* set __check_model__ to avoid deprecation warnings
  Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>
* remove dataloaders warnings in easyocr
  Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>
* suppress with option
  Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>
---------
Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>
This commit is contained in:
@@ -135,6 +135,8 @@ class EasyOcrOptions(OcrOptions):
|
|||||||
recog_network: Optional[str] = "standard"
|
recog_network: Optional[str] = "standard"
|
||||||
download_enabled: bool = True
|
download_enabled: bool = True
|
||||||
|
|
||||||
|
suppress_mps_warnings: bool = True
|
||||||
|
|
||||||
model_config = ConfigDict(
|
model_config = ConfigDict(
|
||||||
extra="forbid",
|
extra="forbid",
|
||||||
protected_namespaces=(),
|
protected_namespaces=(),
|
||||||
|
|||||||
@@ -78,14 +78,17 @@ class EasyOcrModel(BaseOcrModel):
|
|||||||
download_enabled = False
|
download_enabled = False
|
||||||
model_storage_directory = str(artifacts_path / self._model_repo_folder)
|
model_storage_directory = str(artifacts_path / self._model_repo_folder)
|
||||||
|
|
||||||
self.reader = easyocr.Reader(
|
with warnings.catch_warnings():
|
||||||
lang_list=self.options.lang,
|
if self.options.suppress_mps_warnings:
|
||||||
gpu=use_gpu,
|
warnings.filterwarnings("ignore", message=".*pin_memory.*MPS.*")
|
||||||
model_storage_directory=model_storage_directory,
|
self.reader = easyocr.Reader(
|
||||||
recog_network=self.options.recog_network,
|
lang_list=self.options.lang,
|
||||||
download_enabled=download_enabled,
|
gpu=use_gpu,
|
||||||
verbose=False,
|
model_storage_directory=model_storage_directory,
|
||||||
)
|
recog_network=self.options.recog_network,
|
||||||
|
download_enabled=download_enabled,
|
||||||
|
verbose=False,
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def download_models(
|
def download_models(
|
||||||
@@ -147,7 +150,14 @@ class EasyOcrModel(BaseOcrModel):
|
|||||||
scale=self.scale, cropbox=ocr_rect
|
scale=self.scale, cropbox=ocr_rect
|
||||||
)
|
)
|
||||||
im = numpy.array(high_res_image)
|
im = numpy.array(high_res_image)
|
||||||
result = self.reader.readtext(im)
|
|
||||||
|
with warnings.catch_warnings():
|
||||||
|
if self.options.suppress_mps_warnings:
|
||||||
|
warnings.filterwarnings(
|
||||||
|
"ignore", message=".*pin_memory.*MPS.*"
|
||||||
|
)
|
||||||
|
|
||||||
|
result = self.reader.readtext(im)
|
||||||
|
|
||||||
del high_res_image
|
del high_res_image
|
||||||
del im
|
del im
|
||||||
|
|||||||
@@ -67,7 +67,7 @@ class PictureDescriptionVlmModel(
|
|||||||
self.model = AutoModelForImageTextToText.from_pretrained(
|
self.model = AutoModelForImageTextToText.from_pretrained(
|
||||||
artifacts_path,
|
artifacts_path,
|
||||||
device_map=self.device,
|
device_map=self.device,
|
||||||
torch_dtype=torch.bfloat16,
|
dtype=torch.bfloat16,
|
||||||
_attn_implementation=(
|
_attn_implementation=(
|
||||||
"flash_attention_2"
|
"flash_attention_2"
|
||||||
if self.device.startswith("cuda")
|
if self.device.startswith("cuda")
|
||||||
|
|||||||
@@ -112,7 +112,7 @@ class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownload
|
|||||||
self.vlm_model = model_cls.from_pretrained(
|
self.vlm_model = model_cls.from_pretrained(
|
||||||
artifacts_path,
|
artifacts_path,
|
||||||
device_map=self.device,
|
device_map=self.device,
|
||||||
torch_dtype=self.vlm_options.torch_dtype,
|
dtype=self.vlm_options.torch_dtype,
|
||||||
_attn_implementation=(
|
_attn_implementation=(
|
||||||
"flash_attention_2"
|
"flash_attention_2"
|
||||||
if self.device.startswith("cuda")
|
if self.device.startswith("cuda")
|
||||||
|
|||||||
@@ -144,7 +144,7 @@ class NuExtractTransformersModel(BaseVlmModel, HuggingFaceModelDownloadMixin):
|
|||||||
self.vlm_model = AutoModelForImageTextToText.from_pretrained(
|
self.vlm_model = AutoModelForImageTextToText.from_pretrained(
|
||||||
artifacts_path,
|
artifacts_path,
|
||||||
device_map=self.device,
|
device_map=self.device,
|
||||||
torch_dtype=self.vlm_options.torch_dtype,
|
dtype=self.vlm_options.torch_dtype,
|
||||||
_attn_implementation=(
|
_attn_implementation=(
|
||||||
"flash_attention_2"
|
"flash_attention_2"
|
||||||
if self.device.startswith("cuda")
|
if self.device.startswith("cuda")
|
||||||
|
|||||||
@@ -194,6 +194,9 @@ class ExtractionVlmPipeline(BaseExtractionPipeline):
|
|||||||
class ExtractionTemplateFactory(ModelFactory[template]): # type: ignore
|
class ExtractionTemplateFactory(ModelFactory[template]): # type: ignore
|
||||||
__use_examples__ = True # prefer Field(examples=...) when present
|
__use_examples__ = True # prefer Field(examples=...) when present
|
||||||
__use_defaults__ = True # use field defaults instead of random values
|
__use_defaults__ = True # use field defaults instead of random values
|
||||||
|
__check_model__ = (
|
||||||
|
True # setting the value to avoid deprecation warnings
|
||||||
|
)
|
||||||
|
|
||||||
return ExtractionTemplateFactory.build().model_dump_json(indent=2) # type: ignore
|
return ExtractionTemplateFactory.build().model_dump_json(indent=2) # type: ignore
|
||||||
else:
|
else:
|
||||||
|
|||||||
Reference in New Issue
Block a user