fix: enable cuda_use_flash_attention2 for PictureDescriptionVlmModel

Signed-off-by: Zach Cox <zach.s.cox@gmail.com>
Zach Cox 2025-04-29 19:33:58 -04:00
parent 976e92e289
commit 0961cda5fb


@@ -57,7 +57,10 @@ class PictureDescriptionVlmModel(PictureDescriptionBaseModel):
                 artifacts_path,
                 torch_dtype=torch.bfloat16,
                 _attn_implementation=(
-                    "flash_attention_2" if self.device.startswith("cuda") else "eager"
+                    "flash_attention_2"
+                    if self.device.startswith("cuda")
+                    and accelerator_options.cuda_use_flash_attention2
+                    else "eager"
                ),
            ).to(self.device)
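
For reference, a minimal, self-contained sketch of the selection logic this change introduces. The AcceleratorOptionsStub class below is a hypothetical stand-in for the real accelerator options object; the only assumption carried over from the diff is that it exposes a cuda_use_flash_attention2 flag.

from dataclasses import dataclass

# Hypothetical stand-in for the accelerator options passed to the model;
# the real object in the diff above exposes cuda_use_flash_attention2.
@dataclass
class AcceleratorOptionsStub:
    cuda_use_flash_attention2: bool = False

def pick_attn_implementation(device: str, opts: AcceleratorOptionsStub) -> str:
    # Mirrors the new condition: use flash_attention_2 only on CUDA devices
    # and only when the user has opted in; otherwise fall back to eager.
    if device.startswith("cuda") and opts.cuda_use_flash_attention2:
        return "flash_attention_2"
    return "eager"

# Flash attention requires both CUDA and the explicit opt-in flag.
assert pick_attn_implementation("cuda:0", AcceleratorOptionsStub(cuda_use_flash_attention2=True)) == "flash_attention_2"
assert pick_attn_implementation("cuda:0", AcceleratorOptionsStub()) == "eager"
assert pick_attn_implementation("cpu", AcceleratorOptionsStub(cuda_use_flash_attention2=True)) == "eager"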