Implement proper batch inference for HuggingFaceTransformersVlmModel

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
This commit is contained in:
Christoph Auer
2025-08-15 17:56:14 +02:00
parent 1aa522792a
commit f42676aab9
6 changed files with 193 additions and 95 deletions

View File

@@ -27,6 +27,7 @@ SMOLDOCLING_MLX = InlineVlmOptions(
supported_devices=[AcceleratorDevice.MPS],
scale=2.0,
temperature=0.0,
stop_strings=["</doctag>", "<end_of_utterance>"],
)
SMOLDOCLING_TRANSFORMERS = InlineVlmOptions(
@@ -42,6 +43,7 @@ SMOLDOCLING_TRANSFORMERS = InlineVlmOptions(
],
scale=2.0,
temperature=0.0,
stop_strings=["</doctag>", "<end_of_utterance>"],
)
SMOLDOCLING_VLLM = InlineVlmOptions(
@@ -56,8 +58,55 @@ SMOLDOCLING_VLLM = InlineVlmOptions(
],
scale=2.0,
temperature=0.0,
stop_strings=["</doctag>", "<end_of_utterance>"],
)
# SmolVLM-500-Instruct
SMOLVLM500_TRANSFORMERS = InlineVlmOptions(
repo_id="HuggingFaceTB/SmolVLM-500M-Instruct",
prompt="Transcribe this image to plain text.",
response_format=ResponseFormat.DOCTAGS,
inference_framework=InferenceFramework.TRANSFORMERS,
transformers_model_type=TransformersModelType.AUTOMODEL_IMAGETEXTTOTEXT,
supported_devices=[
AcceleratorDevice.CPU,
AcceleratorDevice.CUDA,
AcceleratorDevice.MPS,
],
scale=2.0,
temperature=0.0,
)
# SmolVLM-500-Instruct
SMOLVLM500_MLX = InlineVlmOptions(
repo_id="moot20/SmolVLM-500M-Instruct-MLX",
prompt="Transcribe this image to plain text.",
response_format=ResponseFormat.DOCTAGS,
inference_framework=InferenceFramework.MLX,
transformers_model_type=TransformersModelType.AUTOMODEL_IMAGETEXTTOTEXT,
supported_devices=[
AcceleratorDevice.MPS,
],
scale=2.0,
temperature=0.0,
)
SMOLVLM500_VLLM = InlineVlmOptions(
repo_id="HuggingFaceTB/SmolVLM-500M-Instruct",
prompt="Transcribe this image to plain text.",
response_format=ResponseFormat.DOCTAGS,
inference_framework=InferenceFramework.VLLM,
transformers_model_type=TransformersModelType.AUTOMODEL_IMAGETEXTTOTEXT,
supported_devices=[
AcceleratorDevice.CPU,
AcceleratorDevice.CUDA,
AcceleratorDevice.MPS,
],
scale=2.0,
temperature=0.0,
)
# GraniteVision
GRANITE_VISION_TRANSFORMERS = InlineVlmOptions(
repo_id="ibm-granite/granite-vision-3.2-2b",

View File

@@ -316,7 +316,7 @@ class ReadingOrderModel:
current_list = None
new_item = out_doc.add_text(
label=DocItemLabel.FORMULA, text="", orig=cap_text, prov=prov
label=DocItemLabel.FORMULA, text=cap_text, orig=cap_text, prov=prov
)
else:
current_list = None

View File

@@ -7,6 +7,7 @@ from typing import Any, Optional, Union
import numpy as np
from PIL.Image import Image
from transformers import StoppingCriteriaList, StopStringCriteria
from docling.datamodel.accelerator_options import (
AcceleratorOptions,
@@ -227,109 +228,119 @@ class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownload
image_batch: Iterable[Union[Image, np.ndarray]],
prompt: Union[str, list[str]],
) -> Iterable[VlmPrediction]:
"""Process raw images without page metadata in a single batched inference call.
Args:
image_batch: Iterable of PIL Images or numpy arrays
prompt: Either:
- str: Single prompt used for all images
- list[str]: List of prompts (one per image, must match image count)
Raises:
ValueError: If prompt list length doesn't match image count.
"""
pil_images: list[Image] = []
for img in image_batch:
# Convert numpy array to PIL Image if needed
if isinstance(img, np.ndarray):
if img.ndim == 3 and img.shape[2] in [3, 4]:
Batched inference for Hugging Face Image-Text-to-Text VLMs (e.g., SmolDocling / SmolVLM).
- Lets the processor handle all padding & batching for text+images.
- Trims generated sequences per row using attention_mask (no pad-id fallbacks).
- Keeps your formulate_prompt() exactly as-is.
"""
import numpy as np
import torch
from PIL import Image as PILImage
# -- Normalize images to RGB PIL (SmolDocling & friends accept PIL/np via processor)
pil_images: list[Image] = []
for img in image_batch:
if isinstance(img, np.ndarray):
if img.ndim == 3 and img.shape[2] in (3, 4):
pil_img = PILImage.fromarray(img.astype(np.uint8))
elif img.ndim == 2:
from PIL import Image as PILImage
pil_img = PILImage.fromarray(img.astype(np.uint8), mode="L")
else:
raise ValueError(f"Unsupported numpy array shape: {img.shape}")
else:
pil_img = img
# Ensure image is in RGB mode (handles RGBA, L, etc.)
if pil_img.mode != "RGB":
pil_img = pil_img.convert("RGB")
pil_images.append(pil_img)
if len(pil_images) == 0:
if not pil_images:
return
# Handle prompt parameter
# -- Normalize prompts (1 per image)
if isinstance(prompt, str):
# Single prompt for all images
user_prompts = [prompt] * len(pil_images)
elif isinstance(prompt, list):
# List of prompts (one per image)
else:
if len(prompt) != len(pil_images):
raise ValueError(
f"Number of prompts ({len(prompt)}) must match number of images ({len(pil_images)})"
)
user_prompts = prompt
else:
raise ValueError(f"prompt must be str or list[str], got {type(prompt)}")
# Format prompts individually
prompts: list[str] = [
self.formulate_prompt(user_prompt) for user_prompt in user_prompts
]
# Use your prompt formatter verbatim
prompts: list[str] = [self.formulate_prompt(p) for p in user_prompts]
# -- Processor performs BOTH text+image preprocessing + batch padding (recommended)
inputs = self.processor(
text=prompts, images=pil_images, return_tensors="pt", padding=True
).to(self.device)
text=prompts,
images=pil_images,
return_tensors="pt",
padding=True, # pad across batch for both text and vision
# no truncation by default; match SmolDocling examples
)
inputs = {
k: (v.to(self.device) if hasattr(v, "to") else v) for k, v in inputs.items()
}
# -- Optional stopping criteria
stopping_criteria = None
if self.vlm_options.stop_strings:
stopping_criteria = StoppingCriteriaList(
[
StopStringCriteria(
stop_strings=self.vlm_options.stop_strings,
tokenizer=self.processor.tokenizer,
)
]
)
# -- Generate (Image-Text-to-Text class expects these inputs from processor)
gen_kwargs = {
**inputs,
"max_new_tokens": self.max_new_tokens,
"use_cache": self.use_cache,
"generation_config": self.generation_config,
"temperature": self.temperature,
**self.vlm_options.extra_generation_config,
}
if stopping_criteria is not None:
gen_kwargs["stopping_criteria"] = stopping_criteria
start_time = time.time()
generated_ids = self.vlm_model.generate(
**inputs,
max_new_tokens=self.max_new_tokens,
use_cache=self.use_cache,
# temperature=self.temperature,
generation_config=self.generation_config,
**self.vlm_options.extra_generation_config,
)
with torch.no_grad():
generated_ids = self.vlm_model.generate(**gen_kwargs)
generation_time = time.time() - start_time
# Determine per-sample prompt lengths
try:
attention_mask = inputs["attention_mask"]
input_lengths: list[int] = attention_mask.sum(dim=1).tolist()
except KeyError:
tokenizer = (
self.processor.tokenizer
) # Expect tokenizer to be present when text is provided
pad_token_id = tokenizer.pad_token_id
if pad_token_id is not None:
input_lengths = (
(inputs["input_ids"] != pad_token_id).sum(dim=1).tolist()
# -- Trim per sample using attention_mask (robust for batched prompts)
if "attention_mask" not in inputs:
raise RuntimeError(
"Processor did not return 'attention_mask'. Ensure padding=True and text tokenization are enabled."
)
else:
# Fallback: assume uniform prompt length (least accurate but preserves execution)
uniform_len = int(inputs["input_ids"].shape[1])
input_lengths = [uniform_len] * int(inputs["input_ids"].shape[0])
input_lengths = inputs["attention_mask"].sum(dim=1).tolist()
trimmed_sequences: list[list[int]] = [
generated_ids[i, int(input_lengths[i]) :].tolist()
for i in range(generated_ids.shape[0])
]
decoded_texts: list[str] = self.processor.batch_decode(
# -- Decode with the processor/tokenizer (skip specials, keep DocTags as text)
decode_fn = getattr(self.processor, "batch_decode", None)
if decode_fn is None and getattr(self.processor, "tokenizer", None) is not None:
decode_fn = self.processor.tokenizer.batch_decode
if decode_fn is None:
raise RuntimeError(
"Neither processor.batch_decode nor tokenizer.batch_decode is available."
)
decoded_texts: list[str] = decode_fn(
trimmed_sequences, skip_special_tokens=True
)
# Logging tokens count for the first sample as a representative metric
# -- Optional logging
if generated_ids.shape[0] > 0:
num_tokens = int(generated_ids[0].shape[0])
_log.debug(
f"Generated {num_tokens} tokens in time {generation_time:.2f} seconds."
f"Generated {int(generated_ids[0].shape[0])} tokens in {generation_time:.2f}s "
f"for batch size {generated_ids.shape[0]}."
)
for text in decoded_texts:

View File

@@ -142,8 +142,6 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
Raises:
ValueError: If prompt list length doesn't match image count.
"""
from mlx_vlm import generate
# Convert image batch to list for length validation
image_list = list(image_batch)
@@ -194,33 +192,67 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
self.processor, self.config, user_prompt, num_images=1
)
# Generate text from the image - MLX can accept PIL Images directly despite type annotations
# Stream generate with stop strings support
start_time = time.time()
generated_result = generate(
_log.debug("start generating ...")
tokens: list[VlmPredictionToken] = []
output = ""
# Use stream_generate for proper stop string handling
for token in self.stream_generate(
self.vlm_model,
self.processor,
formatted_prompt,
image=image, # Pass PIL Image directly - much more efficient than disk I/O
[image], # MLX stream_generate expects list of images
max_tokens=self.max_tokens,
verbose=False,
temp=self.temperature,
max_tokens=self.max_tokens,
):
# Collect token information
if len(token.logprobs.shape) == 1:
tokens.append(
VlmPredictionToken(
text=token.text,
token=token.token,
logprob=token.logprobs[token.token],
)
)
elif (
len(token.logprobs.shape) == 2 and token.logprobs.shape[0] == 1
):
tokens.append(
VlmPredictionToken(
text=token.text,
token=token.token,
logprob=token.logprobs[0, token.token],
)
generation_time = time.time() - start_time
# MLX generate returns a tuple (text, info_dict), extract just the text
if isinstance(generated_result, tuple):
generated_text = generated_result[0]
_log.debug(
f"MLX generate returned tuple with additional info: {generated_result[1] if len(generated_result) > 1 else 'N/A'}"
)
else:
generated_text = generated_result
_log.warning(
f"incompatible shape for logprobs: {token.logprobs.shape}"
)
output += token.text
# Check for any configured stop strings
if self.vlm_options.stop_strings:
if any(
stop_str in output
for stop_str in self.vlm_options.stop_strings
):
_log.debug("Stopping generation due to stop string match")
break
generation_time = time.time() - start_time
_log.debug(
f"{generation_time:.2f} seconds for {len(tokens)} tokens ({len(tokens) / generation_time:.1f} tokens/sec)."
)
_log.debug(f"Generated text in {generation_time:.2f}s.")
yield VlmPrediction(
text=generated_text,
text=output,
generation_time=generation_time,
# MLX generate doesn't expose tokens directly, so we leave it empty
generated_tokens=[],
generated_tokens=tokens,
)
_log.debug("MLX model: Released global lock")

View File

@@ -57,6 +57,8 @@ from docling.datamodel.vlm_model_specs import (
DOLPHIN_TRANSFORMERS,
SMOLDOCLING_MLX,
SMOLDOCLING_TRANSFORMERS,
SMOLVLM500_MLX,
SMOLVLM500_TRANSFORMERS,
)
from docling.models.layout_model import LayoutModel
from docling.models.page_assemble_model import PageAssembleModel, PageAssembleOptions
@@ -159,21 +161,23 @@ class ThreadedMultiStageVlmPipelineOptions(PaginatedPipelineOptions):
# text_opts = DOLPHIN_TRANSFORMERS.model_copy()
# text_opts.prompt = "<s>Read text in the image. <Answer/>"
base_model = SMOLDOCLING_TRANSFORMERS
formula_opts = base_model.model_copy()
formula_opts.prompt = "Convert formula to latex."
formula_opts.response_format = ResponseFormat.OTSL
code_opts = base_model.model_copy()
code_opts.prompt = "Convert code to text."
code_opts.response_format = ResponseFormat.OTSL
base_model = SMOLVLM500_TRANSFORMERS
text_opts = base_model.model_copy()
text_opts.prompt = "Convert this page to docling."
text_opts.response_format = ResponseFormat.OTSL
# text_opts.prompt = "Convert this page to docling."
text_opts.prompt = "What does this say?"
text_opts.response_format = ResponseFormat.PLAINTEXT
table_opts = base_model.model_copy()
formula_opts = base_model.model_copy()
# formula_opts.prompt = "Convert formula to latex."
formula_opts.prompt = "What does this say?"
formula_opts.response_format = ResponseFormat.PLAINTEXT
code_opts = SMOLDOCLING_TRANSFORMERS.model_copy()
code_opts.prompt = "Convert code to text."
code_opts.response_format = ResponseFormat.DOCTAGS
table_opts = SMOLDOCLING_TRANSFORMERS.model_copy()
table_opts.prompt = "Convert this table to OTSL."
table_opts.response_format = ResponseFormat.OTSL

View File

@@ -128,6 +128,8 @@ class VlmPipeline(PaginatedPipeline):
page._backend = conv_res.input._backend.load_page(page.page_no) # type: ignore
if page._backend is not None and page._backend.is_valid():
page.size = page._backend.get_size()
if self.force_backend_text:
page.parsed_page = page._backend.get_segmented_page()
return page