mirror of
https://github.com/DS4SD/docling.git
synced 2025-07-26 20:14:47 +00:00
Use device_map for transformer models
Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
This commit is contained in:
parent
931eb55b88
commit
05123c9342
@@ -65,6 +65,7 @@ class PictureDescriptionVlmModel(
|
|||||||
self.processor = AutoProcessor.from_pretrained(artifacts_path)
|
self.processor = AutoProcessor.from_pretrained(artifacts_path)
|
||||||
self.model = AutoModelForVision2Seq.from_pretrained(
|
self.model = AutoModelForVision2Seq.from_pretrained(
|
||||||
artifacts_path,
|
artifacts_path,
|
||||||
|
device_map=self.device,
|
||||||
torch_dtype=torch.bfloat16,
|
torch_dtype=torch.bfloat16,
|
||||||
_attn_implementation=(
|
_attn_implementation=(
|
||||||
"flash_attention_2"
|
"flash_attention_2"
|
||||||
@@ -72,7 +73,7 @@ class PictureDescriptionVlmModel(
|
|||||||
and accelerator_options.cuda_use_flash_attention2
|
and accelerator_options.cuda_use_flash_attention2
|
||||||
else "eager"
|
else "eager"
|
||||||
),
|
),
|
||||||
).to(self.device)
|
)
|
||||||
|
|
||||||
self.provenance = f"{self.options.repo_id}"
|
self.provenance = f"{self.options.repo_id}"
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user