feat: batching support for VLMs in transformers backend, add initial VLLM backend (#2094)

* Prepare existing code for use with the new multi-stage VLM pipeline

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Add multithreaded VLM pipeline

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Add VLM task interpreters

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Add VLM task interpreters

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Remove prints

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Fix KeyboardInterrupt behaviour

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Add VLLM backend support, optimize process_images

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Tweak defaults

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Implement proper batch inference for HuggingFaceTransformersVlmModel

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Small fixes

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Cleanup hf_transformers_model batching impl

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Adjust example instantiation of the multi-stage VLM pipeline

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Add GoT OCR 2.0

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Factor out changes without multi-stage pipeline

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Reset defaults for generation

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Cleanup

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Add torch.compile, fix temperature setting in gen_kwargs

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Expose page_batch_size on CLI

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Add torch_dtype bfloat16 to SMOLDOCLING and SMOLVLM model spec

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Clip off pad_token

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

---------

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
Christoph Auer
2025-08-22 13:17:33 +02:00
committed by GitHub
parent 3f03709885
commit 3c660c0511
17 changed files with 2837 additions and 319 deletions

View File

@@ -60,10 +60,12 @@ from docling.datamodel.pipeline_options import (
)
from docling.datamodel.settings import settings
from docling.datamodel.vlm_model_specs import (
GOT2_TRANSFORMERS,
GRANITE_VISION_OLLAMA,
GRANITE_VISION_TRANSFORMERS,
SMOLDOCLING_MLX,
SMOLDOCLING_TRANSFORMERS,
SMOLDOCLING_VLLM,
VlmModelType,
)
from docling.document_converter import (
@@ -477,6 +479,13 @@ def convert( # noqa: C901
"--logo", callback=logo_callback, is_eager=True, help="Docling logo"
),
] = None,
page_batch_size: Annotated[
int,
typer.Option(
"--page-batch-size",
help=f"Number of pages processed in one batch. Default: {settings.perf.page_batch_size}",
),
] = settings.perf.page_batch_size,
):
log_format = "%(asctime)s\t%(levelname)s\t%(name)s: %(message)s"
@@ -491,6 +500,7 @@ def convert( # noqa: C901
settings.debug.visualize_layout = debug_visualize_layout
settings.debug.visualize_tables = debug_visualize_tables
settings.debug.visualize_ocr = debug_visualize_ocr
settings.perf.page_batch_size = page_batch_size
if from_formats is None:
from_formats = list(InputFormat)
@@ -631,6 +641,8 @@ def convert( # noqa: C901
pipeline_options.vlm_options = GRANITE_VISION_TRANSFORMERS
elif vlm_model == VlmModelType.GRANITE_VISION_OLLAMA:
pipeline_options.vlm_options = GRANITE_VISION_OLLAMA
elif vlm_model == VlmModelType.GOT_OCR_2:
pipeline_options.vlm_options = GOT2_TRANSFORMERS
elif vlm_model == VlmModelType.SMOLDOCLING:
pipeline_options.vlm_options = SMOLDOCLING_TRANSFORMERS
if sys.platform == "darwin":
@@ -643,6 +655,8 @@ def convert( # noqa: C901
"To run SmolDocling faster, please install mlx-vlm:\n"
"pip install mlx-vlm"
)
elif vlm_model == VlmModelType.SMOLDOCLING_VLLM:
pipeline_options.vlm_options = SMOLDOCLING_VLLM
pdf_format_option = PdfFormatOption(
pipeline_cls=VlmPipeline, pipeline_options=pipeline_options
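
A minimal usage sketch (not part of the diff): the new --page-batch-size CLI option writes into the global performance settings, and the same knob can be set from Python before a conversion; the value 16 below is illustrative.

from docling.datamodel.settings import settings

# The CLI stores --page-batch-size here; setting it programmatically controls
# how many pages are handed to the VLM backend per batch.
settings.perf.page_batch_size = 16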

View File

@@ -1,7 +1,7 @@
import math
from collections import defaultdict
from enum import Enum
from typing import TYPE_CHECKING, Annotated, Dict, List, Literal, Optional, Union
from typing import TYPE_CHECKING, Dict, List, Optional, Union
import numpy as np
from docling_core.types.doc import (

View File

@@ -282,6 +282,9 @@ class LayoutOptions(BaseModel):
keep_empty_clusters: bool = (
False # Whether to keep clusters that contain no text cells
)
skip_cell_assignment: bool = (
False # Skip cell-to-cluster assignment for VLM-only processing
)
model_spec: LayoutModelConfig = DOCLING_LAYOUT_V2
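
A short sketch of the new flag, assuming LayoutOptions is constructed directly from docling.datamodel.pipeline_options as in current releases; with skip_cell_assignment the layout postprocessor (see further below) no longer maps text cells onto clusters, which is what a VLM-only pipeline wants.

from docling.datamodel.pipeline_options import LayoutOptions

# Sketch: layout detection without cell-to-cluster assignment, for pipelines
# where a VLM produces the text and only the cluster geometry is needed.
layout_options = LayoutOptions(
    skip_cell_assignment=True,
    keep_empty_clusters=True,  # clusters carry no cells, so keep them anyway
)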

View File

@@ -26,11 +26,14 @@ class ResponseFormat(str, Enum):
DOCTAGS = "doctags"
MARKDOWN = "markdown"
HTML = "html"
OTSL = "otsl"
PLAINTEXT = "plaintext"
class InferenceFramework(str, Enum):
MLX = "mlx"
TRANSFORMERS = "transformers"
VLLM = "vllm"
class TransformersModelType(str, Enum):
@@ -43,6 +46,7 @@ class TransformersModelType(str, Enum):
class TransformersPromptStyle(str, Enum):
CHAT = "chat"
RAW = "raw"
NONE = "none"
class InlineVlmOptions(BaseVlmOptions):
@@ -68,6 +72,7 @@ class InlineVlmOptions(BaseVlmOptions):
stop_strings: List[str] = []
extra_generation_config: Dict[str, Any] = {}
extra_processor_kwargs: Dict[str, Any] = {}
use_kv_cache: bool = True
max_new_tokens: int = 4096
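
A hedged sketch of a custom spec that exercises the new fields and enum values; the field values simply mirror the SMOLDOCLING_VLLM preset added later in this commit, so only the variable name is new.

from docling.datamodel.accelerator_options import AcceleratorDevice
from docling.datamodel.pipeline_options_vlm_model import (
    InferenceFramework,
    InlineVlmOptions,
    ResponseFormat,
)

# Custom inline VLM spec targeting the new VLLM inference framework.
MY_VLLM_OPTIONS = InlineVlmOptions(
    repo_id="ds4sd/SmolDocling-256M-preview",
    prompt="Convert this page to docling.",
    response_format=ResponseFormat.DOCTAGS,
    inference_framework=InferenceFramework.VLLM,
    supported_devices=[AcceleratorDevice.CUDA],
    scale=2.0,
    temperature=0.0,
    stop_strings=["</doctag>", "<end_of_utterance>"],
)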

View File

@@ -12,6 +12,7 @@ from docling.datamodel.pipeline_options_vlm_model import (
InlineVlmOptions,
ResponseFormat,
TransformersModelType,
TransformersPromptStyle,
)
_log = logging.getLogger(__name__)
@@ -26,6 +27,7 @@ SMOLDOCLING_MLX = InlineVlmOptions(
supported_devices=[AcceleratorDevice.MPS],
scale=2.0,
temperature=0.0,
stop_strings=["</doctag>", "<end_of_utterance>"],
)
SMOLDOCLING_TRANSFORMERS = InlineVlmOptions(
@@ -33,16 +35,74 @@ SMOLDOCLING_TRANSFORMERS = InlineVlmOptions(
prompt="Convert this page to docling.",
response_format=ResponseFormat.DOCTAGS,
inference_framework=InferenceFramework.TRANSFORMERS,
transformers_model_type=TransformersModelType.AUTOMODEL_VISION2SEQ,
transformers_model_type=TransformersModelType.AUTOMODEL_IMAGETEXTTOTEXT,
supported_devices=[
AcceleratorDevice.CPU,
AcceleratorDevice.CUDA,
],
torch_dtype="bfloat16",
scale=2.0,
temperature=0.0,
stop_strings=["</doctag>", "<end_of_utterance>"],
)
SMOLDOCLING_VLLM = InlineVlmOptions(
repo_id="ds4sd/SmolDocling-256M-preview",
prompt="Convert this page to docling.",
response_format=ResponseFormat.DOCTAGS,
inference_framework=InferenceFramework.VLLM,
supported_devices=[
AcceleratorDevice.CUDA,
],
scale=2.0,
temperature=0.0,
stop_strings=["</doctag>", "<end_of_utterance>"],
)
# SmolVLM-256M-Instruct
SMOLVLM256_TRANSFORMERS = InlineVlmOptions(
repo_id="HuggingFaceTB/SmolVLM-256M-Instruct",
prompt="Transcribe this image to plain text.",
response_format=ResponseFormat.PLAINTEXT,
inference_framework=InferenceFramework.TRANSFORMERS,
transformers_model_type=TransformersModelType.AUTOMODEL_IMAGETEXTTOTEXT,
supported_devices=[
AcceleratorDevice.CPU,
AcceleratorDevice.CUDA,
# AcceleratorDevice.MPS,
],
torch_dtype="bfloat16",
scale=2.0,
temperature=0.0,
)
# SmolVLM2-2.2b-Instruct
SMOLVLM256_MLX = InlineVlmOptions(
repo_id="moot20/SmolVLM-256M-Instruct-MLX",
prompt="Extract the text.",
response_format=ResponseFormat.DOCTAGS,
inference_framework=InferenceFramework.MLX,
transformers_model_type=TransformersModelType.AUTOMODEL_IMAGETEXTTOTEXT,
supported_devices=[
AcceleratorDevice.MPS,
],
scale=2.0,
temperature=0.0,
)
SMOLVLM256_VLLM = InlineVlmOptions(
repo_id="HuggingFaceTB/SmolVLM-256M-Instruct",
prompt="Transcribe this image to plain text.",
response_format=ResponseFormat.PLAINTEXT,
inference_framework=InferenceFramework.VLLM,
supported_devices=[
AcceleratorDevice.CUDA,
],
scale=2.0,
temperature=0.0,
)
# GraniteVision
GRANITE_VISION_TRANSFORMERS = InlineVlmOptions(
repo_id="ibm-granite/granite-vision-3.2-2b",
@@ -59,6 +119,18 @@ GRANITE_VISION_TRANSFORMERS = InlineVlmOptions(
temperature=0.0,
)
GRANITE_VISION_VLLM = InlineVlmOptions(
repo_id="ibm-granite/granite-vision-3.2-2b",
prompt="Convert this page to markdown. Do not miss any text and only output the bare markdown!",
response_format=ResponseFormat.MARKDOWN,
inference_framework=InferenceFramework.VLLM,
supported_devices=[
AcceleratorDevice.CUDA,
],
scale=2.0,
temperature=0.0,
)
GRANITE_VISION_OLLAMA = ApiVlmOptions(
url=AnyUrl("http://localhost:11434/v1/chat/completions"),
params={"model": "granite3.2-vision:2b"},
@@ -116,6 +188,26 @@ QWEN25_VL_3B_MLX = InlineVlmOptions(
temperature=0.0,
)
# GoT 2.0
GOT2_TRANSFORMERS = InlineVlmOptions(
repo_id="stepfun-ai/GOT-OCR-2.0-hf",
prompt="",
response_format=ResponseFormat.MARKDOWN,
inference_framework=InferenceFramework.TRANSFORMERS,
transformers_prompt_style=TransformersPromptStyle.NONE,
transformers_model_type=TransformersModelType.AUTOMODEL_IMAGETEXTTOTEXT,
supported_devices=[
AcceleratorDevice.CPU,
AcceleratorDevice.CUDA,
# AcceleratorDevice.MPS,
],
scale=2.0,
temperature=0.0,
stop_strings=["<|im_end|>"],
extra_processor_kwargs={"format": True},
)
# Gemma-3
GEMMA3_12B_MLX = InlineVlmOptions(
repo_id="mlx-community/gemma-3-12b-it-bf16",
@@ -137,8 +229,29 @@ GEMMA3_27B_MLX = InlineVlmOptions(
temperature=0.0,
)
# Dolphin
DOLPHIN_TRANSFORMERS = InlineVlmOptions(
repo_id="ByteDance/Dolphin",
prompt="<s>Read text in the image. <Answer/>",
response_format=ResponseFormat.MARKDOWN,
inference_framework=InferenceFramework.TRANSFORMERS,
transformers_model_type=TransformersModelType.AUTOMODEL_IMAGETEXTTOTEXT,
transformers_prompt_style=TransformersPromptStyle.RAW,
supported_devices=[
AcceleratorDevice.CUDA,
AcceleratorDevice.CPU,
AcceleratorDevice.MPS,
],
scale=2.0,
temperature=0.0,
)
class VlmModelType(str, Enum):
SMOLDOCLING = "smoldocling"
SMOLDOCLING_VLLM = "smoldocling_vllm"
GRANITE_VISION = "granite_vision"
GRANITE_VISION_VLLM = "granite_vision_vllm"
GRANITE_VISION_OLLAMA = "granite_vision_ollama"
GOT_OCR_2 = "got_ocr_2"
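
A usage sketch for the new presets, following the converter setup the VLM pipeline already uses; SMOLDOCLING_VLLM is the spec added above, everything else is the standard docling conversion API.

from docling.datamodel import vlm_model_specs
from docling.datamodel.base_models import InputFormat
from docling.datamodel.pipeline_options import VlmPipelineOptions
from docling.document_converter import DocumentConverter, PdfFormatOption
from docling.pipeline.vlm_pipeline import VlmPipeline

# Sketch: convert a PDF with the vLLM-backed SmolDocling preset added above.
pipeline_options = VlmPipelineOptions(vlm_options=vlm_model_specs.SMOLDOCLING_VLLM)
converter = DocumentConverter(
    format_options={
        InputFormat.PDF: PdfFormatOption(
            pipeline_cls=VlmPipeline, pipeline_options=pipeline_options
        )
    }
)
result = converter.convert("document.pdf")
print(result.document.export_to_markdown())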

View File

@@ -1,13 +1,24 @@
import logging
from abc import ABC, abstractmethod
from collections.abc import Iterable
from typing import Generic, Optional, Protocol, Type
from typing import Any, Generic, Optional, Protocol, Type, Union
import numpy as np
from docling_core.types.doc import BoundingBox, DocItem, DoclingDocument, NodeItem
from PIL.Image import Image
from typing_extensions import TypeVar
from docling.datamodel.base_models import ItemAndImageEnrichmentElement, Page
from docling.datamodel.base_models import (
ItemAndImageEnrichmentElement,
Page,
VlmPrediction,
)
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import BaseOptions
from docling.datamodel.pipeline_options_vlm_model import (
InlineVlmOptions,
TransformersPromptStyle,
)
from docling.datamodel.settings import settings
@@ -26,6 +37,88 @@ class BasePageModel(ABC):
pass
class BaseVlmModel(ABC):
"""Base class for Vision-Language Models that adds image processing capability."""
@abstractmethod
def process_images(
self,
image_batch: Iterable[Union[Image, np.ndarray]],
prompt: Union[str, list[str]],
) -> Iterable[VlmPrediction]:
"""Process raw images without page metadata.
Args:
image_batch: Iterable of PIL Images or numpy arrays
prompt: Either:
- str: Single prompt used for all images
- list[str]: List of prompts (one per image, must match image count)
Raises:
ValueError: If prompt list length doesn't match image count.
"""
class BaseVlmPageModel(BasePageModel, BaseVlmModel):
"""Base implementation for VLM models that inherit from BasePageModel.
Provides a default __call__ implementation that extracts images from pages,
processes them using process_images, and attaches results back to pages.
"""
# Type annotations for attributes that subclasses must initialize
vlm_options: InlineVlmOptions
processor: Any
@abstractmethod
def __call__(
self, conv_res: ConversionResult, page_batch: Iterable[Page]
) -> Iterable[Page]:
"""Extract images from pages, process them, and attach results back."""
def formulate_prompt(self, user_prompt: str) -> str:
"""Formulate a prompt for the VLM."""
_log = logging.getLogger(__name__)
if self.vlm_options.transformers_prompt_style == TransformersPromptStyle.RAW:
return user_prompt
elif self.vlm_options.repo_id == "microsoft/Phi-4-multimodal-instruct":
_log.debug("Using specialized prompt for Phi-4")
# Note: This might need adjustment for VLLM vs transformers
user_prompt_prefix = "<|user|>"
assistant_prompt = "<|assistant|>"
prompt_suffix = "<|end|>"
prompt = f"{user_prompt_prefix}<|image_1|>{user_prompt}{prompt_suffix}{assistant_prompt}"
_log.debug(f"prompt for {self.vlm_options.repo_id}: {prompt}")
return prompt
elif self.vlm_options.transformers_prompt_style == TransformersPromptStyle.CHAT:
messages = [
{
"role": "user",
"content": [
{
"type": "text",
"text": "This is a page from a document.",
},
{"type": "image"},
{"type": "text", "text": user_prompt},
],
}
]
prompt = self.processor.apply_chat_template(
messages, add_generation_prompt=True
)
return prompt
raise RuntimeError(
f"Unknown prompt style `{self.vlm_options.transformers_prompt_style}`. Valid values are {', '.join(s.value for s in TransformersPromptStyle)}."
)
EnrichElementT = TypeVar("EnrichElementT", default=NodeItem)
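
A minimal sketch of the process_images contract defined above: any concrete BaseVlmModel accepts raw PIL images (or numpy arrays) plus either one prompt for all images or one prompt per image, and yields VlmPrediction objects. The helper below takes an already constructed model as an argument.

from PIL import Image

from docling.datamodel.base_models import VlmPrediction
from docling.models.base_model import BaseVlmModel


def run_image_batch(
    model: BaseVlmModel, image_paths: list[str], prompt: str
) -> list[VlmPrediction]:
    """Run batched VLM inference on raw images via the new process_images API."""
    images = [Image.open(p).convert("RGB") for p in image_paths]
    # A single string prompt is broadcast to every image; a list[str] of the
    # same length would give each image its own prompt.
    return list(model.process_images(images, prompt))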

View File

@@ -17,6 +17,9 @@ from docling.utils.profiling import TimeRecorder
class PagePreprocessingOptions(BaseModel):
images_scale: Optional[float]
skip_cell_extraction: bool = (
False # Skip text cell extraction for VLM-only processing
)
class PagePreprocessingModel(BasePageModel):
@@ -41,7 +44,8 @@ class PagePreprocessingModel(BasePageModel):
else:
with TimeRecorder(conv_res, "page_parse"):
page = self._populate_page_images(page)
page = self._parse_page_cells(conv_res, page)
if not self.options.skip_cell_extraction:
page = self._parse_page_cells(conv_res, page)
yield page
# Generate the page image and store it in the page object
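
A small sketch of the new preprocessing switch; the constructor call assumes PagePreprocessingModel takes only its options object, as in current releases.

from docling.models.page_preprocessing_model import (
    PagePreprocessingModel,
    PagePreprocessingOptions,
)

# Sketch: build page images but skip programmatic text-cell extraction, since
# the VLM will produce the text itself.
preprocessing = PagePreprocessingModel(
    options=PagePreprocessingOptions(images_scale=2.0, skip_cell_extraction=True)
)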

View File

@@ -4,6 +4,7 @@ from pathlib import Path
from typing import Optional, Type, Union
from PIL import Image
from transformers import AutoModelForImageTextToText
from docling.datamodel.accelerator_options import AcceleratorOptions
from docling.datamodel.pipeline_options import (
@@ -63,7 +64,7 @@ class PictureDescriptionVlmModel(
# Initialize processor and model
with _model_init_lock:
self.processor = AutoProcessor.from_pretrained(artifacts_path)
self.model = AutoModelForVision2Seq.from_pretrained(
self.model = AutoModelForImageTextToText.from_pretrained(
artifacts_path,
device_map=self.device,
torch_dtype=torch.bfloat16,
@@ -71,9 +72,10 @@ class PictureDescriptionVlmModel(
"flash_attention_2"
if self.device.startswith("cuda")
and accelerator_options.cuda_use_flash_attention2
else "eager"
else "sdpa"
),
)
self.model = torch.compile(self.model) # type: ignore
self.provenance = f"{self.options.repo_id}"
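
The switch from "eager" to "sdpa" attention and the added torch.compile call are plain transformers/PyTorch mechanisms; below is a minimal sketch of the same loading pattern on the SmolVLM checkpoint used elsewhere in this commit (the first generate call pays a one-time compilation cost, later calls are faster).

import torch
from transformers import AutoModelForImageTextToText

# Sketch of the loading pattern above: SDPA attention plus torch.compile.
model = AutoModelForImageTextToText.from_pretrained(
    "HuggingFaceTB/SmolVLM-256M-Instruct",
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
)
model = torch.compile(model)  # compiles on first forward pass, then runs faster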

View File

@@ -3,7 +3,11 @@ import logging
import time
from collections.abc import Iterable
from pathlib import Path
from typing import Any, Optional
from typing import Any, Optional, Union
import numpy as np
from PIL.Image import Image
from transformers import StoppingCriteriaList, StopStringCriteria
from docling.datamodel.accelerator_options import (
AcceleratorOptions,
@@ -15,7 +19,7 @@ from docling.datamodel.pipeline_options_vlm_model import (
TransformersModelType,
TransformersPromptStyle,
)
from docling.models.base_model import BasePageModel
from docling.models.base_model import BaseVlmPageModel
from docling.models.utils.hf_model_download import (
HuggingFaceModelDownloadMixin,
)
@@ -25,7 +29,7 @@ from docling.utils.profiling import TimeRecorder
_log = logging.getLogger(__name__)
class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMixin):
class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
def __init__(
self,
enabled: bool,
@@ -103,6 +107,8 @@ class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMix
artifacts_path,
trust_remote_code=vlm_options.trust_remote_code,
)
self.processor.tokenizer.padding_side = "left"
self.vlm_model = model_cls.from_pretrained(
artifacts_path,
device_map=self.device,
@@ -111,10 +117,11 @@ class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMix
"flash_attention_2"
if self.device.startswith("cuda")
and accelerator_options.cuda_use_flash_attention2
else "eager"
else "sdpa"
),
trust_remote_code=vlm_options.trust_remote_code,
)
self.vlm_model = torch.compile(self.vlm_model) # type: ignore
# Load generation config
self.generation_config = GenerationConfig.from_pretrained(artifacts_path)
@@ -122,93 +129,186 @@ class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMix
def __call__(
self, conv_res: ConversionResult, page_batch: Iterable[Page]
) -> Iterable[Page]:
for page in page_batch:
page_list = list(page_batch)
if not page_list:
return
valid_pages = []
invalid_pages = []
for page in page_list:
assert page._backend is not None
if not page._backend.is_valid():
yield page
invalid_pages.append(page)
else:
with TimeRecorder(conv_res, "vlm"):
assert page.size is not None
valid_pages.append(page)
# Process valid pages in batch
if valid_pages:
with TimeRecorder(conv_res, "vlm"):
# Prepare images and prompts for batch processing
images = []
user_prompts = []
pages_with_images = []
for page in valid_pages:
assert page.size is not None
hi_res_image = page.get_image(
scale=self.vlm_options.scale, max_size=self.vlm_options.max_size
)
# Define prompt structure
user_prompt = self.vlm_options.build_prompt(page.parsed_page)
prompt = self.formulate_prompt(user_prompt)
# Only process pages with valid images
if hi_res_image is not None:
images.append(hi_res_image)
inputs = self.processor(
text=prompt, images=[hi_res_image], return_tensors="pt"
).to(self.device)
# Define prompt structure
user_prompt = self.vlm_options.build_prompt(page.parsed_page)
start_time = time.time()
# Call model to generate:
generated_ids = self.vlm_model.generate(
**inputs,
max_new_tokens=self.max_new_tokens,
use_cache=self.use_cache,
temperature=self.temperature,
generation_config=self.generation_config,
**self.vlm_options.extra_generation_config,
)
user_prompts.append(user_prompt)
pages_with_images.append(page)
generation_time = time.time() - start_time
generated_texts = self.processor.batch_decode(
generated_ids[:, inputs["input_ids"].shape[1] :],
skip_special_tokens=False,
)[0]
# Use process_images for the actual inference
if images: # Only if we have valid images
predictions = list(self.process_images(images, user_prompts))
num_tokens = len(generated_ids[0])
_log.debug(
f"Generated {num_tokens} tokens in time {generation_time:.2f} seconds."
)
generated_texts = self.vlm_options.decode_response(generated_texts)
page.predictions.vlm_response = VlmPrediction(
text=generated_texts,
generation_time=generation_time,
)
# Attach results to pages
for page, prediction in zip(pages_with_images, predictions):
page.predictions.vlm_response = prediction
yield page
# Yield all pages (valid and invalid)
for page in invalid_pages:
yield page
for page in valid_pages:
yield page
def formulate_prompt(self, user_prompt: str) -> str:
"""Formulate a prompt for the VLM."""
def process_images(
self,
image_batch: Iterable[Union[Image, np.ndarray]],
prompt: Union[str, list[str]],
) -> Iterable[VlmPrediction]:
"""
Batched inference for Hugging Face Image-Text-to-Text VLMs (e.g., SmolDocling / SmolVLM).
- Lets the processor handle all padding & batching for text+images.
- Trims generated sequences per row using attention_mask (no pad-id fallbacks).
- Keeps your formulate_prompt() exactly as-is.
"""
import numpy as np
import torch
from PIL import Image as PILImage
if self.vlm_options.transformers_prompt_style == TransformersPromptStyle.RAW:
return user_prompt
# -- Normalize images to RGB PIL (SmolDocling & friends accept PIL/np via processor)
pil_images: list[Image] = []
for img in image_batch:
if isinstance(img, np.ndarray):
if img.ndim == 3 and img.shape[2] in (3, 4):
pil_img = PILImage.fromarray(img.astype(np.uint8))
elif img.ndim == 2:
pil_img = PILImage.fromarray(img.astype(np.uint8), mode="L")
else:
raise ValueError(f"Unsupported numpy array shape: {img.shape}")
else:
pil_img = img
if pil_img.mode != "RGB":
pil_img = pil_img.convert("RGB")
pil_images.append(pil_img)
elif self.vlm_options.repo_id == "microsoft/Phi-4-multimodal-instruct":
_log.debug("Using specialized prompt for Phi-4")
# more info here: https://huggingface.co/microsoft/Phi-4-multimodal-instruct#loading-the-model-locally
if not pil_images:
return
user_prompt = "<|user|>"
assistant_prompt = "<|assistant|>"
prompt_suffix = "<|end|>"
# -- Normalize prompts (1 per image)
if isinstance(prompt, str):
user_prompts = [prompt] * len(pil_images)
else:
if len(prompt) != len(pil_images):
raise ValueError(
f"Number of prompts ({len(prompt)}) must match number of images ({len(pil_images)})"
)
user_prompts = prompt
prompt = f"{user_prompt}<|image_1|>{user_prompt}{prompt_suffix}{assistant_prompt}"
_log.debug(f"prompt for {self.vlm_options.repo_id}: {prompt}")
return prompt
elif self.vlm_options.transformers_prompt_style == TransformersPromptStyle.CHAT:
messages = [
{
"role": "user",
"content": [
{
"type": "text",
"text": "This is a page from a document.",
},
{"type": "image"},
{"type": "text", "text": user_prompt},
],
}
]
prompt = self.processor.apply_chat_template(
messages, add_generation_prompt=False
# Use your prompt formatter verbatim
if self.vlm_options.transformers_prompt_style == TransformersPromptStyle.NONE:
inputs = self.processor(
pil_images,
return_tensors="pt",
padding=True, # pad across batch for both text and vision
**self.vlm_options.extra_processor_kwargs,
)
return prompt
else:
prompts: list[str] = [self.formulate_prompt(p) for p in user_prompts]
raise RuntimeError(
f"Uknown prompt style `{self.vlm_options.transformers_prompt_style}`. Valid values are {', '.join(s.value for s in TransformersPromptStyle)}."
# -- Processor performs BOTH text+image preprocessing + batch padding (recommended)
inputs = self.processor(
text=prompts,
images=pil_images,
return_tensors="pt",
padding=True, # pad across batch for both text and vision
**self.vlm_options.extra_processor_kwargs,
)
inputs = {k: v.to(self.device) for k, v in inputs.items()}
# -- Optional stopping criteria
stopping_criteria = None
if self.vlm_options.stop_strings:
stopping_criteria = StoppingCriteriaList(
[
StopStringCriteria(
stop_strings=self.vlm_options.stop_strings,
tokenizer=self.processor.tokenizer,
)
]
)
# -- Generate (Image-Text-to-Text class expects these inputs from processor)
gen_kwargs = {
**inputs,
"max_new_tokens": self.max_new_tokens,
"use_cache": self.use_cache,
"generation_config": self.generation_config,
**self.vlm_options.extra_generation_config,
}
if self.temperature > 0:
gen_kwargs["do_sample"] = True
gen_kwargs["temperature"] = self.temperature
else:
gen_kwargs["do_sample"] = False
if stopping_criteria is not None:
gen_kwargs["stopping_criteria"] = stopping_criteria
start_time = time.time()
with torch.inference_mode():
generated_ids = self.vlm_model.generate(**gen_kwargs)
generation_time = time.time() - start_time
input_len = inputs["input_ids"].shape[1] # common right-aligned prompt length
trimmed_sequences = generated_ids[:, input_len:] # only newly generated tokens
# -- Decode with the processor/tokenizer (skip specials, keep DocTags as text)
decode_fn = getattr(self.processor, "batch_decode", None)
if decode_fn is None and getattr(self.processor, "tokenizer", None) is not None:
decode_fn = self.processor.tokenizer.batch_decode
if decode_fn is None:
raise RuntimeError(
"Neither processor.batch_decode nor tokenizer.batch_decode is available."
)
decoded_texts: list[str] = decode_fn(
trimmed_sequences, skip_special_tokens=False
)
# -- Clip off pad tokens from decoded texts
pad_token = self.processor.tokenizer.pad_token
if pad_token:
decoded_texts = [text.rstrip(pad_token) for text in decoded_texts]
# -- Optional logging
if generated_ids.shape[0] > 0:
_log.debug(
f"Generated {int(generated_ids[0].shape[0])} tokens in {generation_time:.2f}s "
f"for batch size {generated_ids.shape[0]}."
)
for text in decoded_texts:
# Apply decode_response to the output text
decoded_text = self.vlm_options.decode_response(text)
yield VlmPrediction(text=decoded_text, generation_time=generation_time)
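
A stripped-down sketch of the batching recipe implemented above, outside the pipeline machinery: left padding so every prompt ends at the same position, one processor call over all texts and images, and trimming the shared prompt length before decoding. The repo id and chat-template prompt mirror the SmolDocling spec in this commit; treat it as an illustration, not the model class itself.

import torch
from PIL import Image
from transformers import AutoModelForImageTextToText, AutoProcessor

repo_id = "ds4sd/SmolDocling-256M-preview"  # illustrative checkpoint
processor = AutoProcessor.from_pretrained(repo_id)
processor.tokenizer.padding_side = "left"  # prompts end right-aligned -> common trim point
model = AutoModelForImageTextToText.from_pretrained(repo_id, torch_dtype=torch.bfloat16)

images = [Image.open("p1.png").convert("RGB"), Image.open("p2.png").convert("RGB")]
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": "Convert this page to docling."},
        ],
    }
]
prompt = processor.apply_chat_template(messages, add_generation_prompt=True)

inputs = processor(
    text=[prompt] * len(images), images=images, return_tensors="pt", padding=True
)
with torch.inference_mode():
    generated = model.generate(**inputs, max_new_tokens=64, do_sample=False)

new_tokens = generated[:, inputs["input_ids"].shape[1]:]  # keep only generated tokens
texts = processor.batch_decode(new_tokens, skip_special_tokens=False)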

View File

@@ -1,8 +1,12 @@
import logging
import threading
import time
from collections.abc import Iterable
from pathlib import Path
from typing import Optional
from typing import Optional, Union
import numpy as np
from PIL.Image import Image
from docling.datamodel.accelerator_options import (
AcceleratorOptions,
@@ -10,7 +14,7 @@ from docling.datamodel.accelerator_options import (
from docling.datamodel.base_models import Page, VlmPrediction, VlmPredictionToken
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options_vlm_model import InlineVlmOptions
from docling.models.base_model import BasePageModel
from docling.models.base_model import BaseVlmPageModel
from docling.models.utils.hf_model_download import (
HuggingFaceModelDownloadMixin,
)
@@ -18,8 +22,12 @@ from docling.utils.profiling import TimeRecorder
_log = logging.getLogger(__name__)
# Global lock for MLX model calls - MLX models are not thread-safe
# All MLX models share this lock to prevent concurrent MLX operations
_MLX_GLOBAL_LOCK = threading.Lock()
class HuggingFaceMlxModel(BasePageModel, HuggingFaceModelDownloadMixin):
class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
def __init__(
self,
enabled: bool,
@@ -63,87 +71,190 @@ class HuggingFaceMlxModel(BasePageModel, HuggingFaceModelDownloadMixin):
def __call__(
self, conv_res: ConversionResult, page_batch: Iterable[Page]
) -> Iterable[Page]:
for page in page_batch:
page_list = list(page_batch)
if not page_list:
return
valid_pages = []
invalid_pages = []
for page in page_list:
assert page._backend is not None
if not page._backend.is_valid():
yield page
invalid_pages.append(page)
else:
with TimeRecorder(conv_res, f"vlm-mlx-{self.vlm_options.repo_id}"):
assert page.size is not None
valid_pages.append(page)
# Process valid pages in batch
if valid_pages:
with TimeRecorder(conv_res, f"vlm-mlx-{self.vlm_options.repo_id}"):
# Prepare images and prompts for batch processing
images = []
user_prompts = []
pages_with_images = []
for page in valid_pages:
assert page.size is not None
hi_res_image = page.get_image(
scale=self.vlm_options.scale, max_size=self.vlm_options.max_size
)
# Only process pages with valid images
if hi_res_image is not None:
im_width, im_height = hi_res_image.size
images.append(hi_res_image)
# populate page_tags with predicted doc tags
page_tags = ""
if hi_res_image:
if hi_res_image.mode != "RGB":
hi_res_image = hi_res_image.convert("RGB")
user_prompt = self.vlm_options.build_prompt(page.parsed_page)
prompt = self.apply_chat_template(
self.processor, self.config, user_prompt, num_images=1
)
start_time = time.time()
_log.debug("start generating ...")
# Call model to generate:
tokens: list[VlmPredictionToken] = []
output = ""
for token in self.stream_generate(
self.vlm_model,
self.processor,
prompt,
[hi_res_image],
max_tokens=self.max_tokens,
verbose=False,
temp=self.temperature,
):
if len(token.logprobs.shape) == 1:
tokens.append(
VlmPredictionToken(
text=token.text,
token=token.token,
logprob=token.logprobs[token.token],
)
)
elif (
len(token.logprobs.shape) == 2
and token.logprobs.shape[0] == 1
):
tokens.append(
VlmPredictionToken(
text=token.text,
token=token.token,
logprob=token.logprobs[0, token.token],
)
)
# Define prompt structure
if callable(self.vlm_options.prompt):
user_prompt = self.vlm_options.prompt(page.parsed_page)
else:
_log.warning(
f"incompatible shape for logprobs: {token.logprobs.shape}"
)
user_prompt = self.vlm_options.prompt
output += token.text
if "</doctag>" in token.text:
user_prompts.append(user_prompt)
pages_with_images.append(page)
# Use process_images for the actual inference
if images: # Only if we have valid images
predictions = list(self.process_images(images, user_prompts))
# Attach results to pages
for page, prediction in zip(pages_with_images, predictions):
page.predictions.vlm_response = prediction
# Yield all pages (valid and invalid)
for page in invalid_pages:
yield page
for page in valid_pages:
yield page
def process_images(
self,
image_batch: Iterable[Union[Image, np.ndarray]],
prompt: Union[str, list[str]],
) -> Iterable[VlmPrediction]:
"""Process raw images without page metadata.
Args:
image_batch: Iterable of PIL Images or numpy arrays
prompt: Either:
- str: Single prompt used for all images
- list[str]: List of prompts (one per image, must match image count)
Raises:
ValueError: If prompt list length doesn't match image count.
"""
# Convert image batch to list for length validation
image_list = list(image_batch)
if len(image_list) == 0:
return
# Handle prompt parameter
if isinstance(prompt, str):
# Single prompt for all images
user_prompts = [prompt] * len(image_list)
elif isinstance(prompt, list):
# List of prompts (one per image)
if len(prompt) != len(image_list):
raise ValueError(
f"Number of prompts ({len(prompt)}) must match number of images ({len(image_list)})"
)
user_prompts = prompt
else:
raise ValueError(f"prompt must be str or list[str], got {type(prompt)}")
# MLX models are not thread-safe - use global lock to serialize access
with _MLX_GLOBAL_LOCK:
_log.debug("MLX model: Acquired global lock for thread safety")
for image, user_prompt in zip(image_list, user_prompts):
# Convert numpy array to PIL Image if needed
if isinstance(image, np.ndarray):
if image.ndim == 3 and image.shape[2] in [3, 4]:
# RGB or RGBA array
from PIL import Image as PILImage
image = PILImage.fromarray(image.astype(np.uint8))
elif image.ndim == 2:
# Grayscale array
from PIL import Image as PILImage
image = PILImage.fromarray(image.astype(np.uint8), mode="L")
else:
raise ValueError(
f"Unsupported numpy array shape: {image.shape}"
)
# Ensure image is in RGB mode (handles RGBA, L, etc.)
if image.mode != "RGB":
image = image.convert("RGB")
# Use the MLX chat template approach like in the __call__ method
formatted_prompt = self.apply_chat_template(
self.processor, self.config, user_prompt, num_images=1
)
# Stream generate with stop strings support
start_time = time.time()
_log.debug("start generating ...")
tokens: list[VlmPredictionToken] = []
output = ""
# Use stream_generate for proper stop string handling
for token in self.stream_generate(
self.vlm_model,
self.processor,
formatted_prompt,
[image], # MLX stream_generate expects list of images
max_tokens=self.max_tokens,
verbose=False,
temp=self.temperature,
):
# Collect token information
if len(token.logprobs.shape) == 1:
tokens.append(
VlmPredictionToken(
text=token.text,
token=token.token,
logprob=token.logprobs[token.token],
)
)
elif (
len(token.logprobs.shape) == 2 and token.logprobs.shape[0] == 1
):
tokens.append(
VlmPredictionToken(
text=token.text,
token=token.token,
logprob=token.logprobs[0, token.token],
)
)
else:
_log.warning(
f"incompatible shape for logprobs: {token.logprobs.shape}"
)
output += token.text
# Check for any configured stop strings
if self.vlm_options.stop_strings:
if any(
stop_str in output
for stop_str in self.vlm_options.stop_strings
):
_log.debug("Stopping generation due to stop string match")
break
generation_time = time.time() - start_time
page_tags = output
generation_time = time.time() - start_time
_log.debug(
f"{generation_time:.2f} seconds for {len(tokens)} tokens ({len(tokens) / generation_time} tokens/sec)."
)
page_tags = self.vlm_options.decode_response(page_tags)
page.predictions.vlm_response = VlmPrediction(
text=page_tags,
generation_time=generation_time,
generated_tokens=tokens,
)
_log.debug(
f"{generation_time:.2f} seconds for {len(tokens)} tokens ({len(tokens) / generation_time:.1f} tokens/sec)."
)
yield page
# Apply decode_response to the output before yielding
decoded_output = self.vlm_options.decode_response(output)
yield VlmPrediction(
text=decoded_output,
generation_time=generation_time,
generated_tokens=tokens,
)
_log.debug("MLX model: Released global lock")
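
MLX models are not thread-safe, which is why the model above serializes every generation behind a module-level lock. A minimal sketch of that pattern in isolation (the backend callable is an assumption; only the locking pattern is the point):

import threading

# One lock shared by the whole process: concurrent pipeline threads queue up
# instead of entering the non-thread-safe backend at the same time.
_BACKEND_LOCK = threading.Lock()


def generate_serialized(generate_fn, *args, **kwargs):
    """Call a non-thread-safe generate function while holding the global lock."""
    with _BACKEND_LOCK:
        return generate_fn(*args, **kwargs)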

View File

@@ -0,0 +1,235 @@
import logging
import time
from collections.abc import Iterable
from pathlib import Path
from typing import Any, Dict, Optional, Union
import numpy as np
from PIL.Image import Image
from docling.datamodel.accelerator_options import (
AcceleratorOptions,
)
from docling.datamodel.base_models import Page, VlmPrediction
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options_vlm_model import (
InlineVlmOptions,
TransformersPromptStyle,
)
from docling.models.base_model import BaseVlmPageModel
from docling.models.utils.hf_model_download import (
HuggingFaceModelDownloadMixin,
)
from docling.utils.accelerator_utils import decide_device
from docling.utils.profiling import TimeRecorder
_log = logging.getLogger(__name__)
class VllmVlmModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
def __init__(
self,
enabled: bool,
artifacts_path: Optional[Path],
accelerator_options: AcceleratorOptions,
vlm_options: InlineVlmOptions,
):
self.enabled = enabled
self.vlm_options = vlm_options
if self.enabled:
from transformers import AutoProcessor
from vllm import LLM, SamplingParams
self.device = decide_device(
accelerator_options.device,
supported_devices=vlm_options.supported_devices,
)
_log.debug(f"Available device for VLM: {self.device}")
self.max_new_tokens = vlm_options.max_new_tokens
self.temperature = vlm_options.temperature
repo_cache_folder = vlm_options.repo_id.replace("/", "--")
if artifacts_path is None:
artifacts_path = self.download_models(self.vlm_options.repo_id)
elif (artifacts_path / repo_cache_folder).exists():
artifacts_path = artifacts_path / repo_cache_folder
# Initialize VLLM LLM
llm_kwargs: Dict[str, Any] = {
"model": str(artifacts_path),
"limit_mm_per_prompt": {"image": 1},
"trust_remote_code": vlm_options.trust_remote_code,
"model_impl": "transformers",
"gpu_memory_utilization": 0.3, # hardcoded for now, leaves room for ~3 different models.
}
# Add device-specific configurations
if self.device == "cpu":
llm_kwargs["device"] = "cpu"
# Add quantization if specified
if vlm_options.quantized:
if vlm_options.load_in_8bit:
llm_kwargs["quantization"] = "bitsandbytes"
self.llm = LLM(**llm_kwargs)
# Initialize processor for prompt formatting
self.processor = AutoProcessor.from_pretrained(
artifacts_path,
trust_remote_code=vlm_options.trust_remote_code,
)
# Set up sampling parameters
self.sampling_params = SamplingParams(
temperature=self.temperature,
max_tokens=self.max_new_tokens,
stop=vlm_options.stop_strings if vlm_options.stop_strings else None,
**vlm_options.extra_generation_config,
)
def __call__(
self, conv_res: ConversionResult, page_batch: Iterable[Page]
) -> Iterable[Page]:
page_list = list(page_batch)
if not page_list:
return
valid_pages = []
invalid_pages = []
for page in page_list:
assert page._backend is not None
if not page._backend.is_valid():
invalid_pages.append(page)
else:
valid_pages.append(page)
# Process valid pages in batch
if valid_pages:
with TimeRecorder(conv_res, "vlm"):
# Prepare images and prompts for batch processing
images = []
user_prompts = []
pages_with_images = []
for page in valid_pages:
assert page.size is not None
hi_res_image = page.get_image(
scale=self.vlm_options.scale, max_size=self.vlm_options.max_size
)
# Only process pages with valid images
if hi_res_image is not None:
images.append(hi_res_image)
# Define prompt structure
if callable(self.vlm_options.prompt):
user_prompt = self.vlm_options.prompt(page.parsed_page)
else:
user_prompt = self.vlm_options.prompt
user_prompts.append(user_prompt)
pages_with_images.append(page)
# Use process_images for the actual inference
if images: # Only if we have valid images
predictions = list(self.process_images(images, user_prompts))
# Attach results to pages
for page, prediction in zip(pages_with_images, predictions):
page.predictions.vlm_response = prediction
# Yield all pages (valid and invalid)
for page in invalid_pages:
yield page
for page in valid_pages:
yield page
def process_images(
self,
image_batch: Iterable[Union[Image, np.ndarray]],
prompt: Union[str, list[str]],
) -> Iterable[VlmPrediction]:
"""Process raw images without page metadata in a single batched inference call.
Args:
image_batch: Iterable of PIL Images or numpy arrays
prompt: Either:
- str: Single prompt used for all images
- list[str]: List of prompts (one per image, must match image count)
Raises:
ValueError: If prompt list length doesn't match image count.
"""
pil_images: list[Image] = []
for img in image_batch:
# Convert numpy array to PIL Image if needed
if isinstance(img, np.ndarray):
if img.ndim == 3 and img.shape[2] in [3, 4]:
from PIL import Image as PILImage
pil_img = PILImage.fromarray(img.astype(np.uint8))
elif img.ndim == 2:
from PIL import Image as PILImage
pil_img = PILImage.fromarray(img.astype(np.uint8), mode="L")
else:
raise ValueError(f"Unsupported numpy array shape: {img.shape}")
else:
pil_img = img
# Ensure image is in RGB mode (handles RGBA, L, etc.)
if pil_img.mode != "RGB":
pil_img = pil_img.convert("RGB")
pil_images.append(pil_img)
if len(pil_images) == 0:
return
# Handle prompt parameter
if isinstance(prompt, str):
# Single prompt for all images
user_prompts = [prompt] * len(pil_images)
elif isinstance(prompt, list):
# List of prompts (one per image)
if len(prompt) != len(pil_images):
raise ValueError(
f"Number of prompts ({len(prompt)}) must match number of images ({len(pil_images)})"
)
user_prompts = prompt
else:
raise ValueError(f"prompt must be str or list[str], got {type(prompt)}")
# Format prompts individually
prompts: list[str] = [
self.formulate_prompt(user_prompt) for user_prompt in user_prompts
]
# Prepare VLLM inputs
llm_inputs = []
for prompt, image in zip(prompts, pil_images):
llm_inputs.append({"prompt": prompt, "multi_modal_data": {"image": image}})
start_time = time.time()
outputs = self.llm.generate(llm_inputs, sampling_params=self.sampling_params) # type: ignore
generation_time = time.time() - start_time
# Logging tokens count for the first sample as a representative metric
if len(outputs) > 0:
num_tokens = len(outputs[0].outputs[0].token_ids)
_log.debug(
f"Generated {num_tokens} tokens in time {generation_time:.2f} seconds."
)
for output in outputs:
# Apply decode_response to the output text
decoded_text = self.vlm_options.decode_response(output.outputs[0].text)
yield VlmPrediction(text=decoded_text, generation_time=generation_time)
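
A condensed sketch of the vLLM calls the new backend is built on: one LLM instance, SamplingParams carrying temperature, token budget and stop strings, and a batched generate call where each request pairs a formatted prompt with its image. The chat-template step is assumed to match what formulate_prompt does for the CHAT style.

from PIL import Image
from transformers import AutoProcessor
from vllm import LLM, SamplingParams

repo_id = "ds4sd/SmolDocling-256M-preview"  # illustrative checkpoint
llm = LLM(model=repo_id, limit_mm_per_prompt={"image": 1})
sampling = SamplingParams(
    temperature=0.0, max_tokens=4096, stop=["</doctag>", "<end_of_utterance>"]
)

# Build a chat-formatted prompt the same way the model class does.
processor = AutoProcessor.from_pretrained(repo_id)
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": "Convert this page to docling."},
        ],
    }
]
prompt = processor.apply_chat_template(messages, add_generation_prompt=True)

requests = [
    {"prompt": prompt, "multi_modal_data": {"image": Image.open("p1.png").convert("RGB")}},
    {"prompt": prompt, "multi_modal_data": {"image": Image.open("p2.png").convert("RGB")}},
]
outputs = llm.generate(requests, sampling_params=sampling)
for out in outputs:
    print(out.outputs[0].text[:120])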

View File

@@ -194,7 +194,7 @@ class ThreadedPipelineStage:
return
self._running = True
self._thread = threading.Thread(
target=self._run, name=f"Stage-{self.name}", daemon=False
target=self._run, name=f"Stage-{self.name}", daemon=True
)
self._thread.start()

View File

@@ -103,6 +103,17 @@ class VlmPipeline(PaginatedPipeline):
vlm_options=vlm_options,
),
]
elif vlm_options.inference_framework == InferenceFramework.VLLM:
from docling.models.vlm_models_inline.vllm_model import VllmVlmModel
self.build_pipe = [
VllmVlmModel(
enabled=True, # must be always enabled for this pipeline to make sense.
artifacts_path=artifacts_path,
accelerator_options=pipeline_options.accelerator_options,
vlm_options=vlm_options,
),
]
else:
raise ValueError(
f"Could not instantiate the right type of VLM pipeline: {vlm_options.inference_framework}"
@@ -117,7 +128,9 @@ class VlmPipeline(PaginatedPipeline):
page._backend = conv_res.input._backend.load_page(page.page_no) # type: ignore
if page._backend is not None and page._backend.is_valid():
page.size = page._backend.get_size()
page.parsed_page = page._backend.get_segmented_page()
if self.force_backend_text:
page.parsed_page = page._backend.get_segmented_page()
return page

View File

@@ -239,15 +239,18 @@ class LayoutPostprocessor:
final_clusters = self._sort_clusters(
self.regular_clusters + self.special_clusters, mode="id"
)
for cluster in final_clusters:
cluster.cells = self._sort_cells(cluster.cells)
# Also sort cells in children if any
for child in cluster.children:
child.cells = self._sort_cells(child.cells)
assert self.page.parsed_page is not None
self.page.parsed_page.textline_cells = self.cells
self.page.parsed_page.has_lines = len(self.cells) > 0
# Conditionally process cells if not skipping cell assignment
if not self.options.skip_cell_assignment:
for cluster in final_clusters:
cluster.cells = self._sort_cells(cluster.cells)
# Also sort cells in children if any
for child in cluster.children:
child.cells = self._sort_cells(child.cells)
assert self.page.parsed_page is not None
self.page.parsed_page.textline_cells = self.cells
self.page.parsed_page.has_lines = len(self.cells) > 0
return final_clusters, self.cells
@@ -264,36 +267,38 @@ class LayoutPostprocessor:
if cluster.label in self.LABEL_REMAPPING:
cluster.label = self.LABEL_REMAPPING[cluster.label]
# Initial cell assignment
clusters = self._assign_cells_to_clusters(clusters)
# Conditionally assign cells to clusters
if not self.options.skip_cell_assignment:
# Initial cell assignment
clusters = self._assign_cells_to_clusters(clusters)
# Remove clusters with no cells (if keep_empty_clusters is False),
# but always keep clusters with label DocItemLabel.FORMULA
if not self.options.keep_empty_clusters:
clusters = [
cluster
for cluster in clusters
if cluster.cells or cluster.label == DocItemLabel.FORMULA
]
# Remove clusters with no cells (if keep_empty_clusters is False),
# but always keep clusters with label DocItemLabel.FORMULA
if not self.options.keep_empty_clusters:
clusters = [
cluster
for cluster in clusters
if cluster.cells or cluster.label == DocItemLabel.FORMULA
]
# Handle orphaned cells
unassigned = self._find_unassigned_cells(clusters)
if unassigned and self.options.create_orphan_clusters:
next_id = max((c.id for c in self.all_clusters), default=0) + 1
orphan_clusters = []
for i, cell in enumerate(unassigned):
conf = cell.confidence
# Handle orphaned cells
unassigned = self._find_unassigned_cells(clusters)
if unassigned and self.options.create_orphan_clusters:
next_id = max((c.id for c in self.all_clusters), default=0) + 1
orphan_clusters = []
for i, cell in enumerate(unassigned):
conf = cell.confidence
orphan_clusters.append(
Cluster(
id=next_id + i,
label=DocItemLabel.TEXT,
bbox=cell.to_bounding_box(),
confidence=conf,
cells=[cell],
orphan_clusters.append(
Cluster(
id=next_id + i,
label=DocItemLabel.TEXT,
bbox=cell.to_bounding_box(),
confidence=conf,
cells=[cell],
)
)
)
clusters.extend(orphan_clusters)
clusters.extend(orphan_clusters)
# Iterative refinement
prev_count = len(clusters) + 1
@@ -350,12 +355,15 @@ class LayoutPostprocessor:
b=max(c.bbox.b for c in contained),
)
# Collect all cells from children
all_cells = []
for child in contained:
all_cells.extend(child.cells)
special.cells = self._deduplicate_cells(all_cells)
special.cells = self._sort_cells(special.cells)
# Conditionally collect cells from children
if not self.options.skip_cell_assignment:
all_cells = []
for child in contained:
all_cells.extend(child.cells)
special.cells = self._deduplicate_cells(all_cells)
special.cells = self._sort_cells(special.cells)
else:
special.cells = []
picture_clusters = [
c for c in special_clusters if c.label == DocItemLabel.PICTURE

View File

@@ -93,6 +93,7 @@ vlm = [
'transformers (>=4.46.0,<5.0.0)',
'accelerate (>=1.2.1,<2.0.0)',
'mlx-vlm (>=0.3.0,<1.0.0) ; python_version >= "3.10" and sys_platform == "darwin" and platform_machine == "arm64"',
'vllm (>=0.10.0,<1.0.0) ; python_version >= "3.10" and sys_platform == "linux"',
]
rapidocr = [
'rapidocr-onnxruntime (>=1.4.0,<2.0.0) ; python_version < "3.13"',
@@ -252,6 +253,7 @@ module = [
"huggingface_hub.*",
"transformers.*",
"pylatexenc.*",
"vllm.*",
]
ignore_missing_imports = true

uv.lock (generated): 2064 lines changed; diff suppressed because it is too large.