From 5db127b2e831ac40c4e679e82aef8c16b076ee8f Mon Sep 17 00:00:00 2001
From: Michele Dolfi
Date: Sun, 8 Jun 2025 16:50:57 +0200
Subject: [PATCH] fix: allow custom torch_dtype in vlm models

Signed-off-by: Michele Dolfi
---
 docling/datamodel/pipeline_options_vlm_model.py           | 3 ++-
 docling/models/vlm_models_inline/hf_transformers_model.py | 1 +
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/docling/datamodel/pipeline_options_vlm_model.py b/docling/datamodel/pipeline_options_vlm_model.py
index c1ec28aa..2289c3c7 100644
--- a/docling/datamodel/pipeline_options_vlm_model.py
+++ b/docling/datamodel/pipeline_options_vlm_model.py
@@ -1,5 +1,5 @@
 from enum import Enum
-from typing import Any, Dict, List, Literal
+from typing import Any, Dict, List, Literal, Optional, Union
 
 from pydantic import AnyUrl, BaseModel
 from typing_extensions import deprecated
@@ -42,6 +42,7 @@ class InlineVlmOptions(BaseVlmOptions):
     transformers_model_type: TransformersModelType = TransformersModelType.AUTOMODEL
     response_format: ResponseFormat
 
+    torch_dtype: Optional[str] = None
     supported_devices: List[AcceleratorDevice] = [
         AcceleratorDevice.CPU,
         AcceleratorDevice.CUDA,
diff --git a/docling/models/vlm_models_inline/hf_transformers_model.py b/docling/models/vlm_models_inline/hf_transformers_model.py
index de7f289d..00fdfa58 100644
--- a/docling/models/vlm_models_inline/hf_transformers_model.py
+++ b/docling/models/vlm_models_inline/hf_transformers_model.py
@@ -99,6 +99,7 @@ class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMix
         self.vlm_model = model_cls.from_pretrained(
             artifacts_path,
             device_map=self.device,
+            torch_dtype=self.vlm_options.torch_dtype,
             _attn_implementation=(
                 "flash_attention_2"
                 if self.device.startswith("cuda")
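
Note: a minimal usage sketch of the new field follows. Only torch_dtype and
response_format appear in this patch; the other constructor fields (repo_id,
prompt) and the model id are illustrative assumptions. The torch_dtype string
is passed through unchanged to transformers' from_pretrained(), so any value
that call accepts (e.g. "bfloat16", "float16") should work; None keeps the
library default.

    from docling.datamodel.pipeline_options_vlm_model import (
        InlineVlmOptions,
        ResponseFormat,
    )

    # Hypothetical configuration: repo_id, prompt, and the model id below are
    # assumptions for illustration; torch_dtype is the field added by this patch.
    vlm_options = InlineVlmOptions(
        repo_id="org/some-vlm-model",
        prompt="Convert this page to docling.",
        response_format=ResponseFormat.DOCTAGS,
        torch_dtype="bfloat16",  # forwarded to from_pretrained(); None = default
    )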