Expose control over using flash_attention_2

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
This commit is contained in:
Christoph Auer 2025-02-25 15:31:32 +01:00
parent 84c77c9fcb
commit 10f64a948c
3 changed files with 19 additions and 7 deletions

View File

@ -41,6 +41,7 @@ class AcceleratorOptions(BaseSettings):
num_threads: int = 4
device: Union[str, AcceleratorDevice] = "auto"
cuda_use_flash_attention2: bool = False
@field_validator("device")
def validate_device(cls, value):

View File

@ -64,9 +64,12 @@ class HuggingFaceVlmModel(BasePageModel):
self.vlm_model = AutoModelForVision2Seq.from_pretrained(
artifacts_path,
torch_dtype=torch.bfloat16,
# _attn_implementation=(
# "flash_attention_2" if self.device.startswith("cuda") else "eager"
# ),
_attn_implementation=(
"flash_attention_2"
if self.device.startswith("cuda")
and accelerator_options.cuda_use_flash_attention2
else "eager"
),
).to(self.device)
else:
@ -74,9 +77,12 @@ class HuggingFaceVlmModel(BasePageModel):
artifacts_path,
torch_dtype="auto",
quantization_config=self.param_quantization_config,
# _attn_implementation=(
# "flash_attention_2" if self.device.startswith("cuda") else "eager"
# ),
_attn_implementation=(
"flash_attention_2"
if self.device.startswith("cuda")
and accelerator_options.cuda_use_flash_attention2
else "eager"
),
).to(self.device)
@staticmethod

View File

@ -6,6 +6,7 @@ import yaml
from docling.datamodel.base_models import InputFormat
from docling.datamodel.pipeline_options import (
AcceleratorDevice,
VlmPipelineOptions,
granite_vision_vlm_conversion_options,
smoldocling_vlm_conversion_options,
@ -24,9 +25,13 @@ pipeline_options.generate_page_images = True
# If force_backend_text = True, text from backend will be used instead of generated text
pipeline_options.force_backend_text = False
## Enable flash_attention_2 with CUDA:
# pipeline_options.accelerator_options.device = AcceleratorDevice.CUDA
# pipeline_options.accelerator_options.cuda_use_flash_attention2 = True
pipeline_options.vlm_options = smoldocling_vlm_conversion_options
# Choose alternative VLM models:
## Choose alternative VLM models:
# pipeline_options.vlm_options = granite_vision_vlm_conversion_options
from docling_core.types.doc import DocItemLabel, ImageRefMode