diff --git a/docling/datamodel/pipeline_options.py b/docling/datamodel/pipeline_options.py
index 26b2e494..e537894f 100644
--- a/docling/datamodel/pipeline_options.py
+++ b/docling/datamodel/pipeline_options.py
@@ -41,6 +41,7 @@ class AcceleratorOptions(BaseSettings):
 
     num_threads: int = 4
     device: Union[str, AcceleratorDevice] = "auto"
+    cuda_use_flash_attention2: bool = False
 
     @field_validator("device")
     def validate_device(cls, value):
diff --git a/docling/models/hf_vlm_model.py b/docling/models/hf_vlm_model.py
index 619658d3..6d99de66 100644
--- a/docling/models/hf_vlm_model.py
+++ b/docling/models/hf_vlm_model.py
@@ -64,9 +64,12 @@ class HuggingFaceVlmModel(BasePageModel):
                 self.vlm_model = AutoModelForVision2Seq.from_pretrained(
                     artifacts_path,
                     torch_dtype=torch.bfloat16,
-                    # _attn_implementation=(
-                    #     "flash_attention_2" if self.device.startswith("cuda") else "eager"
-                    # ),
+                    _attn_implementation=(
+                        "flash_attention_2"
+                        if self.device.startswith("cuda")
+                        and accelerator_options.cuda_use_flash_attention2
+                        else "eager"
+                    ),
                 ).to(self.device)
 
             else:
@@ -74,9 +77,12 @@ class HuggingFaceVlmModel(BasePageModel):
                     artifacts_path,
                     torch_dtype="auto",
                     quantization_config=self.param_quantization_config,
-                    # _attn_implementation=(
-                    #     "flash_attention_2" if self.device.startswith("cuda") else "eager"
-                    # ),
+                    _attn_implementation=(
+                        "flash_attention_2"
+                        if self.device.startswith("cuda")
+                        and accelerator_options.cuda_use_flash_attention2
+                        else "eager"
+                    ),
                 ).to(self.device)
 
     @staticmethod
diff --git a/docs/examples/minimal_vlm_pipeline.py b/docs/examples/minimal_vlm_pipeline.py
index 42a9ef53..7c9913e9 100644
--- a/docs/examples/minimal_vlm_pipeline.py
+++ b/docs/examples/minimal_vlm_pipeline.py
@@ -6,6 +6,7 @@ import yaml
 
 from docling.datamodel.base_models import InputFormat
 from docling.datamodel.pipeline_options import (
+    AcceleratorDevice,
     VlmPipelineOptions,
     granite_vision_vlm_conversion_options,
     smoldocling_vlm_conversion_options,
@@ -24,9 +25,13 @@ pipeline_options.generate_page_images = True
 # If force_backend_text = True, text from backend will be used instead of generated text
 pipeline_options.force_backend_text = False
 
+## Enable flash_attention_2 with CUDA:
+# pipeline_options.accelerator_options.device = AcceleratorDevice.CUDA
+# pipeline_options.accelerator_options.cuda_use_flash_attention2 = True
+
 pipeline_options.vlm_options = smoldocling_vlm_conversion_options
 
-# Choose alternative VLM models:
+## Choose alternative VLM models:
 # pipeline_options.vlm_options = granite_vision_vlm_conversion_options
 
 from docling_core.types.doc import DocItemLabel, ImageRefMode
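For reference, a minimal usage sketch of the new `cuda_use_flash_attention2` flag (an illustration, not part of the diff), assuming the VLM pipeline wiring already used in docs/examples/minimal_vlm_pipeline.py; the input path is a placeholder:

```python
from docling.datamodel.base_models import InputFormat
from docling.datamodel.pipeline_options import AcceleratorDevice, VlmPipelineOptions
from docling.document_converter import DocumentConverter, PdfFormatOption
from docling.pipeline.vlm_pipeline import VlmPipeline

pipeline_options = VlmPipelineOptions()
# flash_attention_2 is only selected on CUDA devices and stays off by default.
pipeline_options.accelerator_options.device = AcceleratorDevice.CUDA
pipeline_options.accelerator_options.cuda_use_flash_attention2 = True

converter = DocumentConverter(
    format_options={
        InputFormat.PDF: PdfFormatOption(
            pipeline_cls=VlmPipeline,
            pipeline_options=pipeline_options,
        )
    }
)
result = converter.convert("example.pdf")  # placeholder input path
print(result.document.export_to_markdown())
```

This mirrors the commented-out lines added to the example script; note that the flash_attention_2 implementation in transformers additionally requires the flash-attn package and a supported GPU.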