Implement proper batch inference for HuggingFaceTransformersVlmModel

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
This commit is contained in:
Christoph Auer
2025-08-15 17:56:14 +02:00
parent 1aa522792a
commit f42676aab9
6 changed files with 193 additions and 95 deletions

View File

@@ -27,6 +27,7 @@ SMOLDOCLING_MLX = InlineVlmOptions(
supported_devices=[AcceleratorDevice.MPS], supported_devices=[AcceleratorDevice.MPS],
scale=2.0, scale=2.0,
temperature=0.0, temperature=0.0,
stop_strings=["</doctag>", "<end_of_utterance>"],
) )
SMOLDOCLING_TRANSFORMERS = InlineVlmOptions( SMOLDOCLING_TRANSFORMERS = InlineVlmOptions(
@@ -42,6 +43,7 @@ SMOLDOCLING_TRANSFORMERS = InlineVlmOptions(
], ],
scale=2.0, scale=2.0,
temperature=0.0, temperature=0.0,
stop_strings=["</doctag>", "<end_of_utterance>"],
) )
SMOLDOCLING_VLLM = InlineVlmOptions( SMOLDOCLING_VLLM = InlineVlmOptions(
@@ -56,8 +58,55 @@ SMOLDOCLING_VLLM = InlineVlmOptions(
], ],
scale=2.0, scale=2.0,
temperature=0.0, temperature=0.0,
stop_strings=["</doctag>", "<end_of_utterance>"],
) )
# SmolVLM-500-Instruct
SMOLVLM500_TRANSFORMERS = InlineVlmOptions(
repo_id="HuggingFaceTB/SmolVLM-500M-Instruct",
prompt="Transcribe this image to plain text.",
response_format=ResponseFormat.DOCTAGS,
inference_framework=InferenceFramework.TRANSFORMERS,
transformers_model_type=TransformersModelType.AUTOMODEL_IMAGETEXTTOTEXT,
supported_devices=[
AcceleratorDevice.CPU,
AcceleratorDevice.CUDA,
AcceleratorDevice.MPS,
],
scale=2.0,
temperature=0.0,
)
# SmolVLM-500-Instruct
SMOLVLM500_MLX = InlineVlmOptions(
repo_id="moot20/SmolVLM-500M-Instruct-MLX",
prompt="Transcribe this image to plain text.",
response_format=ResponseFormat.DOCTAGS,
inference_framework=InferenceFramework.MLX,
transformers_model_type=TransformersModelType.AUTOMODEL_IMAGETEXTTOTEXT,
supported_devices=[
AcceleratorDevice.MPS,
],
scale=2.0,
temperature=0.0,
)
SMOLVLM500_VLLM = InlineVlmOptions(
repo_id="HuggingFaceTB/SmolVLM-500M-Instruct",
prompt="Transcribe this image to plain text.",
response_format=ResponseFormat.DOCTAGS,
inference_framework=InferenceFramework.VLLM,
transformers_model_type=TransformersModelType.AUTOMODEL_IMAGETEXTTOTEXT,
supported_devices=[
AcceleratorDevice.CPU,
AcceleratorDevice.CUDA,
AcceleratorDevice.MPS,
],
scale=2.0,
temperature=0.0,
)
# GraniteVision # GraniteVision
GRANITE_VISION_TRANSFORMERS = InlineVlmOptions( GRANITE_VISION_TRANSFORMERS = InlineVlmOptions(
repo_id="ibm-granite/granite-vision-3.2-2b", repo_id="ibm-granite/granite-vision-3.2-2b",

View File

@@ -316,7 +316,7 @@ class ReadingOrderModel:
current_list = None current_list = None
new_item = out_doc.add_text( new_item = out_doc.add_text(
label=DocItemLabel.FORMULA, text="", orig=cap_text, prov=prov label=DocItemLabel.FORMULA, text=cap_text, orig=cap_text, prov=prov
) )
else: else:
current_list = None current_list = None

View File

@@ -7,6 +7,7 @@ from typing import Any, Optional, Union
import numpy as np import numpy as np
from PIL.Image import Image from PIL.Image import Image
from transformers import StoppingCriteriaList, StopStringCriteria
from docling.datamodel.accelerator_options import ( from docling.datamodel.accelerator_options import (
AcceleratorOptions, AcceleratorOptions,
@@ -227,109 +228,119 @@ class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownload
image_batch: Iterable[Union[Image, np.ndarray]], image_batch: Iterable[Union[Image, np.ndarray]],
prompt: Union[str, list[str]], prompt: Union[str, list[str]],
) -> Iterable[VlmPrediction]: ) -> Iterable[VlmPrediction]:
"""Process raw images without page metadata in a single batched inference call.
Args:
image_batch: Iterable of PIL Images or numpy arrays
prompt: Either:
- str: Single prompt used for all images
- list[str]: List of prompts (one per image, must match image count)
Raises:
ValueError: If prompt list length doesn't match image count.
""" """
Batched inference for Hugging Face Image-Text-to-Text VLMs (e.g., SmolDocling / SmolVLM).
- Lets the processor handle all padding & batching for text+images.
- Trims generated sequences per row using attention_mask (no pad-id fallbacks).
- Keeps your formulate_prompt() exactly as-is.
"""
import numpy as np
import torch
from PIL import Image as PILImage
# -- Normalize images to RGB PIL (SmolDocling & friends accept PIL/np via processor)
pil_images: list[Image] = [] pil_images: list[Image] = []
for img in image_batch: for img in image_batch:
# Convert numpy array to PIL Image if needed
if isinstance(img, np.ndarray): if isinstance(img, np.ndarray):
if img.ndim == 3 and img.shape[2] in [3, 4]: if img.ndim == 3 and img.shape[2] in (3, 4):
from PIL import Image as PILImage
pil_img = PILImage.fromarray(img.astype(np.uint8)) pil_img = PILImage.fromarray(img.astype(np.uint8))
elif img.ndim == 2: elif img.ndim == 2:
from PIL import Image as PILImage
pil_img = PILImage.fromarray(img.astype(np.uint8), mode="L") pil_img = PILImage.fromarray(img.astype(np.uint8), mode="L")
else: else:
raise ValueError(f"Unsupported numpy array shape: {img.shape}") raise ValueError(f"Unsupported numpy array shape: {img.shape}")
else: else:
pil_img = img pil_img = img
# Ensure image is in RGB mode (handles RGBA, L, etc.)
if pil_img.mode != "RGB": if pil_img.mode != "RGB":
pil_img = pil_img.convert("RGB") pil_img = pil_img.convert("RGB")
pil_images.append(pil_img) pil_images.append(pil_img)
if len(pil_images) == 0: if not pil_images:
return return
# Handle prompt parameter # -- Normalize prompts (1 per image)
if isinstance(prompt, str): if isinstance(prompt, str):
# Single prompt for all images
user_prompts = [prompt] * len(pil_images) user_prompts = [prompt] * len(pil_images)
elif isinstance(prompt, list): else:
# List of prompts (one per image)
if len(prompt) != len(pil_images): if len(prompt) != len(pil_images):
raise ValueError( raise ValueError(
f"Number of prompts ({len(prompt)}) must match number of images ({len(pil_images)})" f"Number of prompts ({len(prompt)}) must match number of images ({len(pil_images)})"
) )
user_prompts = prompt user_prompts = prompt
else:
raise ValueError(f"prompt must be str or list[str], got {type(prompt)}")
# Format prompts individually # Use your prompt formatter verbatim
prompts: list[str] = [ prompts: list[str] = [self.formulate_prompt(p) for p in user_prompts]
self.formulate_prompt(user_prompt) for user_prompt in user_prompts
]
# -- Processor performs BOTH text+image preprocessing + batch padding (recommended)
inputs = self.processor( inputs = self.processor(
text=prompts, images=pil_images, return_tensors="pt", padding=True text=prompts,
).to(self.device) images=pil_images,
return_tensors="pt",
padding=True, # pad across batch for both text and vision
# no truncation by default; match SmolDocling examples
)
inputs = {
k: (v.to(self.device) if hasattr(v, "to") else v) for k, v in inputs.items()
}
# -- Optional stopping criteria
stopping_criteria = None
if self.vlm_options.stop_strings:
stopping_criteria = StoppingCriteriaList(
[
StopStringCriteria(
stop_strings=self.vlm_options.stop_strings,
tokenizer=self.processor.tokenizer,
)
]
)
# -- Generate (Image-Text-to-Text class expects these inputs from processor)
gen_kwargs = {
**inputs,
"max_new_tokens": self.max_new_tokens,
"use_cache": self.use_cache,
"generation_config": self.generation_config,
"temperature": self.temperature,
**self.vlm_options.extra_generation_config,
}
if stopping_criteria is not None:
gen_kwargs["stopping_criteria"] = stopping_criteria
start_time = time.time() start_time = time.time()
generated_ids = self.vlm_model.generate( with torch.no_grad():
**inputs, generated_ids = self.vlm_model.generate(**gen_kwargs)
max_new_tokens=self.max_new_tokens,
use_cache=self.use_cache,
# temperature=self.temperature,
generation_config=self.generation_config,
**self.vlm_options.extra_generation_config,
)
generation_time = time.time() - start_time generation_time = time.time() - start_time
# Determine per-sample prompt lengths # -- Trim per sample using attention_mask (robust for batched prompts)
try: if "attention_mask" not in inputs:
attention_mask = inputs["attention_mask"] raise RuntimeError(
input_lengths: list[int] = attention_mask.sum(dim=1).tolist() "Processor did not return 'attention_mask'. Ensure padding=True and text tokenization are enabled."
except KeyError: )
tokenizer = ( input_lengths = inputs["attention_mask"].sum(dim=1).tolist()
self.processor.tokenizer
) # Expect tokenizer to be present when text is provided
pad_token_id = tokenizer.pad_token_id
if pad_token_id is not None:
input_lengths = (
(inputs["input_ids"] != pad_token_id).sum(dim=1).tolist()
)
else:
# Fallback: assume uniform prompt length (least accurate but preserves execution)
uniform_len = int(inputs["input_ids"].shape[1])
input_lengths = [uniform_len] * int(inputs["input_ids"].shape[0])
trimmed_sequences: list[list[int]] = [ trimmed_sequences: list[list[int]] = [
generated_ids[i, int(input_lengths[i]) :].tolist() generated_ids[i, int(input_lengths[i]) :].tolist()
for i in range(generated_ids.shape[0]) for i in range(generated_ids.shape[0])
] ]
decoded_texts: list[str] = self.processor.batch_decode(
# -- Decode with the processor/tokenizer (skip specials, keep DocTags as text)
decode_fn = getattr(self.processor, "batch_decode", None)
if decode_fn is None and getattr(self.processor, "tokenizer", None) is not None:
decode_fn = self.processor.tokenizer.batch_decode
if decode_fn is None:
raise RuntimeError(
"Neither processor.batch_decode nor tokenizer.batch_decode is available."
)
decoded_texts: list[str] = decode_fn(
trimmed_sequences, skip_special_tokens=True trimmed_sequences, skip_special_tokens=True
) )
# Logging tokens count for the first sample as a representative metric # -- Optional logging
if generated_ids.shape[0] > 0: if generated_ids.shape[0] > 0:
num_tokens = int(generated_ids[0].shape[0])
_log.debug( _log.debug(
f"Generated {num_tokens} tokens in time {generation_time:.2f} seconds." f"Generated {int(generated_ids[0].shape[0])} tokens in {generation_time:.2f}s "
f"for batch size {generated_ids.shape[0]}."
) )
for text in decoded_texts: for text in decoded_texts:

View File

@@ -142,8 +142,6 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
Raises: Raises:
ValueError: If prompt list length doesn't match image count. ValueError: If prompt list length doesn't match image count.
""" """
from mlx_vlm import generate
# Convert image batch to list for length validation # Convert image batch to list for length validation
image_list = list(image_batch) image_list = list(image_batch)
@@ -194,33 +192,67 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
self.processor, self.config, user_prompt, num_images=1 self.processor, self.config, user_prompt, num_images=1
) )
# Generate text from the image - MLX can accept PIL Images directly despite type annotations # Stream generate with stop strings support
start_time = time.time() start_time = time.time()
generated_result = generate( _log.debug("start generating ...")
tokens: list[VlmPredictionToken] = []
output = ""
# Use stream_generate for proper stop string handling
for token in self.stream_generate(
self.vlm_model, self.vlm_model,
self.processor, self.processor,
formatted_prompt, formatted_prompt,
image=image, # Pass PIL Image directly - much more efficient than disk I/O [image], # MLX stream_generate expects list of images
max_tokens=self.max_tokens,
verbose=False, verbose=False,
temp=self.temperature, temp=self.temperature,
max_tokens=self.max_tokens, ):
) # Collect token information
if len(token.logprobs.shape) == 1:
tokens.append(
VlmPredictionToken(
text=token.text,
token=token.token,
logprob=token.logprobs[token.token],
)
)
elif (
len(token.logprobs.shape) == 2 and token.logprobs.shape[0] == 1
):
tokens.append(
VlmPredictionToken(
text=token.text,
token=token.token,
logprob=token.logprobs[0, token.token],
)
)
else:
_log.warning(
f"incompatible shape for logprobs: {token.logprobs.shape}"
)
output += token.text
# Check for any configured stop strings
if self.vlm_options.stop_strings:
if any(
stop_str in output
for stop_str in self.vlm_options.stop_strings
):
_log.debug("Stopping generation due to stop string match")
break
generation_time = time.time() - start_time generation_time = time.time() - start_time
# MLX generate returns a tuple (text, info_dict), extract just the text _log.debug(
if isinstance(generated_result, tuple): f"{generation_time:.2f} seconds for {len(tokens)} tokens ({len(tokens) / generation_time:.1f} tokens/sec)."
generated_text = generated_result[0] )
_log.debug(
f"MLX generate returned tuple with additional info: {generated_result[1] if len(generated_result) > 1 else 'N/A'}"
)
else:
generated_text = generated_result
_log.debug(f"Generated text in {generation_time:.2f}s.")
yield VlmPrediction( yield VlmPrediction(
text=generated_text, text=output,
generation_time=generation_time, generation_time=generation_time,
# MLX generate doesn't expose tokens directly, so we leave it empty generated_tokens=tokens,
generated_tokens=[],
) )
_log.debug("MLX model: Released global lock") _log.debug("MLX model: Released global lock")

View File

@@ -57,6 +57,8 @@ from docling.datamodel.vlm_model_specs import (
DOLPHIN_TRANSFORMERS, DOLPHIN_TRANSFORMERS,
SMOLDOCLING_MLX, SMOLDOCLING_MLX,
SMOLDOCLING_TRANSFORMERS, SMOLDOCLING_TRANSFORMERS,
SMOLVLM500_MLX,
SMOLVLM500_TRANSFORMERS,
) )
from docling.models.layout_model import LayoutModel from docling.models.layout_model import LayoutModel
from docling.models.page_assemble_model import PageAssembleModel, PageAssembleOptions from docling.models.page_assemble_model import PageAssembleModel, PageAssembleOptions
@@ -159,21 +161,23 @@ class ThreadedMultiStageVlmPipelineOptions(PaginatedPipelineOptions):
# text_opts = DOLPHIN_TRANSFORMERS.model_copy() # text_opts = DOLPHIN_TRANSFORMERS.model_copy()
# text_opts.prompt = "<s>Read text in the image. <Answer/>" # text_opts.prompt = "<s>Read text in the image. <Answer/>"
base_model = SMOLDOCLING_TRANSFORMERS base_model = SMOLVLM500_TRANSFORMERS
formula_opts = base_model.model_copy()
formula_opts.prompt = "Convert formula to latex."
formula_opts.response_format = ResponseFormat.OTSL
code_opts = base_model.model_copy()
code_opts.prompt = "Convert code to text."
code_opts.response_format = ResponseFormat.OTSL
text_opts = base_model.model_copy() text_opts = base_model.model_copy()
text_opts.prompt = "Convert this page to docling." # text_opts.prompt = "Convert this page to docling."
text_opts.response_format = ResponseFormat.OTSL text_opts.prompt = "What does this say?"
text_opts.response_format = ResponseFormat.PLAINTEXT
table_opts = base_model.model_copy() formula_opts = base_model.model_copy()
# formula_opts.prompt = "Convert formula to latex."
formula_opts.prompt = "What does this say?"
formula_opts.response_format = ResponseFormat.PLAINTEXT
code_opts = SMOLDOCLING_TRANSFORMERS.model_copy()
code_opts.prompt = "Convert code to text."
code_opts.response_format = ResponseFormat.DOCTAGS
table_opts = SMOLDOCLING_TRANSFORMERS.model_copy()
table_opts.prompt = "Convert this table to OTSL." table_opts.prompt = "Convert this table to OTSL."
table_opts.response_format = ResponseFormat.OTSL table_opts.response_format = ResponseFormat.OTSL

View File

@@ -128,7 +128,9 @@ class VlmPipeline(PaginatedPipeline):
page._backend = conv_res.input._backend.load_page(page.page_no) # type: ignore page._backend = conv_res.input._backend.load_page(page.page_no) # type: ignore
if page._backend is not None and page._backend.is_valid(): if page._backend is not None and page._backend.is_valid():
page.size = page._backend.get_size() page.size = page._backend.get_size()
page.parsed_page = page._backend.get_segmented_page()
if self.force_backend_text:
page.parsed_page = page._backend.get_segmented_page()
return page return page