Use device_map for transformer models

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
This commit is contained in:
Christoph Auer 2025-07-09 16:49:21 +02:00
parent 931eb55b88
commit 05123c9342

View File

@ -65,6 +65,7 @@ class PictureDescriptionVlmModel(
self.processor = AutoProcessor.from_pretrained(artifacts_path)
self.model = AutoModelForVision2Seq.from_pretrained(
artifacts_path,
device_map=self.device,
torch_dtype=torch.bfloat16,
_attn_implementation=(
"flash_attention_2"
@ -72,7 +73,7 @@ class PictureDescriptionVlmModel(
and accelerator_options.cuda_use_flash_attention2
else "eager"
),
).to(self.device)
)
self.provenance = f"{self.options.repo_id}"