mirror of
https://github.com/DS4SD/docling.git
synced 2025-07-24 19:14:23 +00:00
fix: Safe pipeline init, use device_map in transformers models (#1917)
* Use device_map for transformer models

  Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Add accelerate

  Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Relax accelerate min version

  Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Make pipeline cache+init thread-safe

  Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

---------

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
This commit is contained in:
parent
e1e3053695
commit
cca05c45ea
@ -1,6 +1,7 @@
|
|||||||
import hashlib
|
import hashlib
|
||||||
import logging
|
import logging
|
||||||
import sys
|
import sys
|
||||||
|
import threading
|
||||||
import time
|
import time
|
||||||
from collections.abc import Iterable, Iterator
|
from collections.abc import Iterable, Iterator
|
||||||
from functools import partial
|
from functools import partial
|
||||||
@ -49,6 +50,7 @@ from docling.pipeline.standard_pdf_pipeline import StandardPdfPipeline
|
|||||||
from docling.utils.utils import chunkify
|
from docling.utils.utils import chunkify
|
||||||
|
|
||||||
_log = logging.getLogger(__name__)
|
_log = logging.getLogger(__name__)
|
||||||
|
_PIPELINE_CACHE_LOCK = threading.Lock()
|
||||||
|
|
||||||
|
|
||||||
class FormatOption(BaseModel):
|
class FormatOption(BaseModel):
|
||||||
@ -315,17 +317,18 @@ class DocumentConverter:
|
|||||||
# Use a composite key to cache pipelines
|
# Use a composite key to cache pipelines
|
||||||
cache_key = (pipeline_class, options_hash)
|
cache_key = (pipeline_class, options_hash)
|
||||||
|
|
||||||
if cache_key not in self.initialized_pipelines:
|
with _PIPELINE_CACHE_LOCK:
|
||||||
_log.info(
|
if cache_key not in self.initialized_pipelines:
|
||||||
f"Initializing pipeline for {pipeline_class.__name__} with options hash {options_hash}"
|
_log.info(
|
||||||
)
|
f"Initializing pipeline for {pipeline_class.__name__} with options hash {options_hash}"
|
||||||
self.initialized_pipelines[cache_key] = pipeline_class(
|
)
|
||||||
pipeline_options=pipeline_options
|
self.initialized_pipelines[cache_key] = pipeline_class(
|
||||||
)
|
pipeline_options=pipeline_options
|
||||||
else:
|
)
|
||||||
_log.debug(
|
else:
|
||||||
f"Reusing cached pipeline for {pipeline_class.__name__} with options hash {options_hash}"
|
_log.debug(
|
||||||
)
|
f"Reusing cached pipeline for {pipeline_class.__name__} with options hash {options_hash}"
|
||||||
|
)
|
||||||
|
|
||||||
return self.initialized_pipelines[cache_key]
|
return self.initialized_pipelines[cache_key]
|
||||||
|
|
||||||
|
@ -65,6 +65,7 @@ class PictureDescriptionVlmModel(
|
|||||||
self.processor = AutoProcessor.from_pretrained(artifacts_path)
|
self.processor = AutoProcessor.from_pretrained(artifacts_path)
|
||||||
self.model = AutoModelForVision2Seq.from_pretrained(
|
self.model = AutoModelForVision2Seq.from_pretrained(
|
||||||
artifacts_path,
|
artifacts_path,
|
||||||
|
device_map=self.device,
|
||||||
torch_dtype=torch.bfloat16,
|
torch_dtype=torch.bfloat16,
|
||||||
_attn_implementation=(
|
_attn_implementation=(
|
||||||
"flash_attention_2"
|
"flash_attention_2"
|
||||||
@ -72,7 +73,7 @@ class PictureDescriptionVlmModel(
|
|||||||
and accelerator_options.cuda_use_flash_attention2
|
and accelerator_options.cuda_use_flash_attention2
|
||||||
else "eager"
|
else "eager"
|
||||||
),
|
),
|
||||||
).to(self.device)
|
)
|
||||||
|
|
||||||
self.provenance = f"{self.options.repo_id}"
|
self.provenance = f"{self.options.repo_id}"
|
||||||
|
|
||||||
|
@ -70,6 +70,7 @@ dependencies = [
|
|||||||
'scipy (>=1.6.0,<2.0.0)',
|
'scipy (>=1.6.0,<2.0.0)',
|
||||||
# 'scipy (>=1.6.0,<2.0.0) ; python_version >= "3.10"',
|
# 'scipy (>=1.6.0,<2.0.0) ; python_version >= "3.10"',
|
||||||
# 'scipy (>=1.6.0,<1.14.0) ; python_version < "3.10"',
|
# 'scipy (>=1.6.0,<1.14.0) ; python_version < "3.10"',
|
||||||
|
"accelerate>=1.0.0,<2",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.urls]
|
[project.urls]
|
||||||
|
2
uv.lock
generated
2
uv.lock
generated
@ -809,6 +809,7 @@ name = "docling"
|
|||||||
version = "2.41.0"
|
version = "2.41.0"
|
||||||
source = { editable = "." }
|
source = { editable = "." }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
{ name = "accelerate" },
|
||||||
{ name = "beautifulsoup4" },
|
{ name = "beautifulsoup4" },
|
||||||
{ name = "certifi" },
|
{ name = "certifi" },
|
||||||
{ name = "docling-core", extra = ["chunking"] },
|
{ name = "docling-core", extra = ["chunking"] },
|
||||||
@ -902,6 +903,7 @@ examples = [
|
|||||||
|
|
||||||
[package.metadata]
|
[package.metadata]
|
||||||
requires-dist = [
|
requires-dist = [
|
||||||
|
{ name = "accelerate", specifier = ">=1.0.0,<2" },
|
||||||
{ name = "accelerate", marker = "extra == 'vlm'", specifier = ">=1.2.1,<2.0.0" },
|
{ name = "accelerate", marker = "extra == 'vlm'", specifier = ">=1.2.1,<2.0.0" },
|
||||||
{ name = "beautifulsoup4", specifier = ">=4.12.3,<5.0.0" },
|
{ name = "beautifulsoup4", specifier = ">=4.12.3,<5.0.0" },
|
||||||
{ name = "certifi", specifier = ">=2024.7.4" },
|
{ name = "certifi", specifier = ">=2024.7.4" },
|
||||||
|
Loading…
Reference in New Issue
Block a user