fix: Thread-safe pipeline init, use device_map in transformers models (#1917)

* Use device_map for transformer models

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Add accelerate

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Relax accelerate min version

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Make pipeline cache+init thread-safe

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

---------

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
This commit is contained in:
Christoph Auer 2025-07-18 15:14:36 +02:00 committed by GitHub
parent e1e3053695
commit cca05c45ea
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 19 additions and 12 deletions

View File

@ -1,6 +1,7 @@
import hashlib import hashlib
import logging import logging
import sys import sys
import threading
import time import time
from collections.abc import Iterable, Iterator from collections.abc import Iterable, Iterator
from functools import partial from functools import partial
@ -49,6 +50,7 @@ from docling.pipeline.standard_pdf_pipeline import StandardPdfPipeline
from docling.utils.utils import chunkify from docling.utils.utils import chunkify
_log = logging.getLogger(__name__) _log = logging.getLogger(__name__)
_PIPELINE_CACHE_LOCK = threading.Lock()
class FormatOption(BaseModel): class FormatOption(BaseModel):
@ -315,17 +317,18 @@ class DocumentConverter:
# Use a composite key to cache pipelines # Use a composite key to cache pipelines
cache_key = (pipeline_class, options_hash) cache_key = (pipeline_class, options_hash)
if cache_key not in self.initialized_pipelines: with _PIPELINE_CACHE_LOCK:
_log.info( if cache_key not in self.initialized_pipelines:
f"Initializing pipeline for {pipeline_class.__name__} with options hash {options_hash}" _log.info(
) f"Initializing pipeline for {pipeline_class.__name__} with options hash {options_hash}"
self.initialized_pipelines[cache_key] = pipeline_class( )
pipeline_options=pipeline_options self.initialized_pipelines[cache_key] = pipeline_class(
) pipeline_options=pipeline_options
else: )
_log.debug( else:
f"Reusing cached pipeline for {pipeline_class.__name__} with options hash {options_hash}" _log.debug(
) f"Reusing cached pipeline for {pipeline_class.__name__} with options hash {options_hash}"
)
return self.initialized_pipelines[cache_key] return self.initialized_pipelines[cache_key]

View File

@ -65,6 +65,7 @@ class PictureDescriptionVlmModel(
self.processor = AutoProcessor.from_pretrained(artifacts_path) self.processor = AutoProcessor.from_pretrained(artifacts_path)
self.model = AutoModelForVision2Seq.from_pretrained( self.model = AutoModelForVision2Seq.from_pretrained(
artifacts_path, artifacts_path,
device_map=self.device,
torch_dtype=torch.bfloat16, torch_dtype=torch.bfloat16,
_attn_implementation=( _attn_implementation=(
"flash_attention_2" "flash_attention_2"
@ -72,7 +73,7 @@ class PictureDescriptionVlmModel(
and accelerator_options.cuda_use_flash_attention2 and accelerator_options.cuda_use_flash_attention2
else "eager" else "eager"
), ),
).to(self.device) )
self.provenance = f"{self.options.repo_id}" self.provenance = f"{self.options.repo_id}"

View File

@ -70,6 +70,7 @@ dependencies = [
'scipy (>=1.6.0,<2.0.0)', 'scipy (>=1.6.0,<2.0.0)',
# 'scipy (>=1.6.0,<2.0.0) ; python_version >= "3.10"', # 'scipy (>=1.6.0,<2.0.0) ; python_version >= "3.10"',
# 'scipy (>=1.6.0,<1.14.0) ; python_version < "3.10"', # 'scipy (>=1.6.0,<1.14.0) ; python_version < "3.10"',
"accelerate>=1.0.0,<2",
] ]
[project.urls] [project.urls]

2
uv.lock generated
View File

@ -809,6 +809,7 @@ name = "docling"
version = "2.41.0" version = "2.41.0"
source = { editable = "." } source = { editable = "." }
dependencies = [ dependencies = [
{ name = "accelerate" },
{ name = "beautifulsoup4" }, { name = "beautifulsoup4" },
{ name = "certifi" }, { name = "certifi" },
{ name = "docling-core", extra = ["chunking"] }, { name = "docling-core", extra = ["chunking"] },
@ -902,6 +903,7 @@ examples = [
[package.metadata] [package.metadata]
requires-dist = [ requires-dist = [
{ name = "accelerate", specifier = ">=1.0.0,<2" },
{ name = "accelerate", marker = "extra == 'vlm'", specifier = ">=1.2.1,<2.0.0" }, { name = "accelerate", marker = "extra == 'vlm'", specifier = ">=1.2.1,<2.0.0" },
{ name = "beautifulsoup4", specifier = ">=4.12.3,<5.0.0" }, { name = "beautifulsoup4", specifier = ">=4.12.3,<5.0.0" },
{ name = "certifi", specifier = ">=2024.7.4" }, { name = "certifi", specifier = ">=2024.7.4" },