mirror of
https://github.com/DS4SD/docling.git
synced 2025-12-08 12:48:28 +00:00
Implement proper batch inference for HuggingFaceTransformersVlmModel
Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
This commit is contained in:
@@ -27,6 +27,7 @@ SMOLDOCLING_MLX = InlineVlmOptions(
|
||||
supported_devices=[AcceleratorDevice.MPS],
|
||||
scale=2.0,
|
||||
temperature=0.0,
|
||||
stop_strings=["</doctag>", "<end_of_utterance>"],
|
||||
)
|
||||
|
||||
SMOLDOCLING_TRANSFORMERS = InlineVlmOptions(
|
||||
@@ -42,6 +43,7 @@ SMOLDOCLING_TRANSFORMERS = InlineVlmOptions(
|
||||
],
|
||||
scale=2.0,
|
||||
temperature=0.0,
|
||||
stop_strings=["</doctag>", "<end_of_utterance>"],
|
||||
)
|
||||
|
||||
SMOLDOCLING_VLLM = InlineVlmOptions(
|
||||
@@ -56,8 +58,55 @@ SMOLDOCLING_VLLM = InlineVlmOptions(
|
||||
],
|
||||
scale=2.0,
|
||||
temperature=0.0,
|
||||
stop_strings=["</doctag>", "<end_of_utterance>"],
|
||||
)
|
||||
|
||||
# SmolVLM-500-Instruct
|
||||
SMOLVLM500_TRANSFORMERS = InlineVlmOptions(
|
||||
repo_id="HuggingFaceTB/SmolVLM-500M-Instruct",
|
||||
prompt="Transcribe this image to plain text.",
|
||||
response_format=ResponseFormat.DOCTAGS,
|
||||
inference_framework=InferenceFramework.TRANSFORMERS,
|
||||
transformers_model_type=TransformersModelType.AUTOMODEL_IMAGETEXTTOTEXT,
|
||||
supported_devices=[
|
||||
AcceleratorDevice.CPU,
|
||||
AcceleratorDevice.CUDA,
|
||||
AcceleratorDevice.MPS,
|
||||
],
|
||||
scale=2.0,
|
||||
temperature=0.0,
|
||||
)
|
||||
|
||||
# SmolVLM-500-Instruct
|
||||
SMOLVLM500_MLX = InlineVlmOptions(
|
||||
repo_id="moot20/SmolVLM-500M-Instruct-MLX",
|
||||
prompt="Transcribe this image to plain text.",
|
||||
response_format=ResponseFormat.DOCTAGS,
|
||||
inference_framework=InferenceFramework.MLX,
|
||||
transformers_model_type=TransformersModelType.AUTOMODEL_IMAGETEXTTOTEXT,
|
||||
supported_devices=[
|
||||
AcceleratorDevice.MPS,
|
||||
],
|
||||
scale=2.0,
|
||||
temperature=0.0,
|
||||
)
|
||||
|
||||
SMOLVLM500_VLLM = InlineVlmOptions(
|
||||
repo_id="HuggingFaceTB/SmolVLM-500M-Instruct",
|
||||
prompt="Transcribe this image to plain text.",
|
||||
response_format=ResponseFormat.DOCTAGS,
|
||||
inference_framework=InferenceFramework.VLLM,
|
||||
transformers_model_type=TransformersModelType.AUTOMODEL_IMAGETEXTTOTEXT,
|
||||
supported_devices=[
|
||||
AcceleratorDevice.CPU,
|
||||
AcceleratorDevice.CUDA,
|
||||
AcceleratorDevice.MPS,
|
||||
],
|
||||
scale=2.0,
|
||||
temperature=0.0,
|
||||
)
|
||||
|
||||
|
||||
# GraniteVision
|
||||
GRANITE_VISION_TRANSFORMERS = InlineVlmOptions(
|
||||
repo_id="ibm-granite/granite-vision-3.2-2b",
|
||||
|
||||
@@ -316,7 +316,7 @@ class ReadingOrderModel:
|
||||
current_list = None
|
||||
|
||||
new_item = out_doc.add_text(
|
||||
label=DocItemLabel.FORMULA, text="", orig=cap_text, prov=prov
|
||||
label=DocItemLabel.FORMULA, text=cap_text, orig=cap_text, prov=prov
|
||||
)
|
||||
else:
|
||||
current_list = None
|
||||
|
||||
@@ -7,6 +7,7 @@ from typing import Any, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from PIL.Image import Image
|
||||
from transformers import StoppingCriteriaList, StopStringCriteria
|
||||
|
||||
from docling.datamodel.accelerator_options import (
|
||||
AcceleratorOptions,
|
||||
@@ -227,109 +228,119 @@ class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownload
|
||||
image_batch: Iterable[Union[Image, np.ndarray]],
|
||||
prompt: Union[str, list[str]],
|
||||
) -> Iterable[VlmPrediction]:
|
||||
"""Process raw images without page metadata in a single batched inference call.
|
||||
|
||||
Args:
|
||||
image_batch: Iterable of PIL Images or numpy arrays
|
||||
prompt: Either:
|
||||
- str: Single prompt used for all images
|
||||
- list[str]: List of prompts (one per image, must match image count)
|
||||
|
||||
Raises:
|
||||
ValueError: If prompt list length doesn't match image count.
|
||||
"""
|
||||
Batched inference for Hugging Face Image-Text-to-Text VLMs (e.g., SmolDocling / SmolVLM).
|
||||
- Lets the processor handle all padding & batching for text+images.
|
||||
- Trims generated sequences per row using attention_mask (no pad-id fallbacks).
|
||||
- Keeps your formulate_prompt() exactly as-is.
|
||||
"""
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image as PILImage
|
||||
|
||||
# -- Normalize images to RGB PIL (SmolDocling & friends accept PIL/np via processor)
|
||||
pil_images: list[Image] = []
|
||||
|
||||
for img in image_batch:
|
||||
# Convert numpy array to PIL Image if needed
|
||||
if isinstance(img, np.ndarray):
|
||||
if img.ndim == 3 and img.shape[2] in [3, 4]:
|
||||
from PIL import Image as PILImage
|
||||
|
||||
if img.ndim == 3 and img.shape[2] in (3, 4):
|
||||
pil_img = PILImage.fromarray(img.astype(np.uint8))
|
||||
elif img.ndim == 2:
|
||||
from PIL import Image as PILImage
|
||||
|
||||
pil_img = PILImage.fromarray(img.astype(np.uint8), mode="L")
|
||||
else:
|
||||
raise ValueError(f"Unsupported numpy array shape: {img.shape}")
|
||||
else:
|
||||
pil_img = img
|
||||
|
||||
# Ensure image is in RGB mode (handles RGBA, L, etc.)
|
||||
if pil_img.mode != "RGB":
|
||||
pil_img = pil_img.convert("RGB")
|
||||
|
||||
pil_images.append(pil_img)
|
||||
|
||||
if len(pil_images) == 0:
|
||||
if not pil_images:
|
||||
return
|
||||
|
||||
# Handle prompt parameter
|
||||
# -- Normalize prompts (1 per image)
|
||||
if isinstance(prompt, str):
|
||||
# Single prompt for all images
|
||||
user_prompts = [prompt] * len(pil_images)
|
||||
elif isinstance(prompt, list):
|
||||
# List of prompts (one per image)
|
||||
else:
|
||||
if len(prompt) != len(pil_images):
|
||||
raise ValueError(
|
||||
f"Number of prompts ({len(prompt)}) must match number of images ({len(pil_images)})"
|
||||
)
|
||||
user_prompts = prompt
|
||||
else:
|
||||
raise ValueError(f"prompt must be str or list[str], got {type(prompt)}")
|
||||
|
||||
# Format prompts individually
|
||||
prompts: list[str] = [
|
||||
self.formulate_prompt(user_prompt) for user_prompt in user_prompts
|
||||
]
|
||||
# Use your prompt formatter verbatim
|
||||
prompts: list[str] = [self.formulate_prompt(p) for p in user_prompts]
|
||||
|
||||
# -- Processor performs BOTH text+image preprocessing + batch padding (recommended)
|
||||
inputs = self.processor(
|
||||
text=prompts, images=pil_images, return_tensors="pt", padding=True
|
||||
).to(self.device)
|
||||
text=prompts,
|
||||
images=pil_images,
|
||||
return_tensors="pt",
|
||||
padding=True, # pad across batch for both text and vision
|
||||
# no truncation by default; match SmolDocling examples
|
||||
)
|
||||
inputs = {
|
||||
k: (v.to(self.device) if hasattr(v, "to") else v) for k, v in inputs.items()
|
||||
}
|
||||
|
||||
# -- Optional stopping criteria
|
||||
stopping_criteria = None
|
||||
if self.vlm_options.stop_strings:
|
||||
stopping_criteria = StoppingCriteriaList(
|
||||
[
|
||||
StopStringCriteria(
|
||||
stop_strings=self.vlm_options.stop_strings,
|
||||
tokenizer=self.processor.tokenizer,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
# -- Generate (Image-Text-to-Text class expects these inputs from processor)
|
||||
gen_kwargs = {
|
||||
**inputs,
|
||||
"max_new_tokens": self.max_new_tokens,
|
||||
"use_cache": self.use_cache,
|
||||
"generation_config": self.generation_config,
|
||||
"temperature": self.temperature,
|
||||
**self.vlm_options.extra_generation_config,
|
||||
}
|
||||
if stopping_criteria is not None:
|
||||
gen_kwargs["stopping_criteria"] = stopping_criteria
|
||||
|
||||
start_time = time.time()
|
||||
generated_ids = self.vlm_model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=self.max_new_tokens,
|
||||
use_cache=self.use_cache,
|
||||
# temperature=self.temperature,
|
||||
generation_config=self.generation_config,
|
||||
**self.vlm_options.extra_generation_config,
|
||||
)
|
||||
with torch.no_grad():
|
||||
generated_ids = self.vlm_model.generate(**gen_kwargs)
|
||||
generation_time = time.time() - start_time
|
||||
|
||||
# Determine per-sample prompt lengths
|
||||
try:
|
||||
attention_mask = inputs["attention_mask"]
|
||||
input_lengths: list[int] = attention_mask.sum(dim=1).tolist()
|
||||
except KeyError:
|
||||
tokenizer = (
|
||||
self.processor.tokenizer
|
||||
) # Expect tokenizer to be present when text is provided
|
||||
pad_token_id = tokenizer.pad_token_id
|
||||
if pad_token_id is not None:
|
||||
input_lengths = (
|
||||
(inputs["input_ids"] != pad_token_id).sum(dim=1).tolist()
|
||||
)
|
||||
else:
|
||||
# Fallback: assume uniform prompt length (least accurate but preserves execution)
|
||||
uniform_len = int(inputs["input_ids"].shape[1])
|
||||
input_lengths = [uniform_len] * int(inputs["input_ids"].shape[0])
|
||||
# -- Trim per sample using attention_mask (robust for batched prompts)
|
||||
if "attention_mask" not in inputs:
|
||||
raise RuntimeError(
|
||||
"Processor did not return 'attention_mask'. Ensure padding=True and text tokenization are enabled."
|
||||
)
|
||||
input_lengths = inputs["attention_mask"].sum(dim=1).tolist()
|
||||
|
||||
trimmed_sequences: list[list[int]] = [
|
||||
generated_ids[i, int(input_lengths[i]) :].tolist()
|
||||
for i in range(generated_ids.shape[0])
|
||||
]
|
||||
decoded_texts: list[str] = self.processor.batch_decode(
|
||||
|
||||
# -- Decode with the processor/tokenizer (skip specials, keep DocTags as text)
|
||||
decode_fn = getattr(self.processor, "batch_decode", None)
|
||||
if decode_fn is None and getattr(self.processor, "tokenizer", None) is not None:
|
||||
decode_fn = self.processor.tokenizer.batch_decode
|
||||
if decode_fn is None:
|
||||
raise RuntimeError(
|
||||
"Neither processor.batch_decode nor tokenizer.batch_decode is available."
|
||||
)
|
||||
|
||||
decoded_texts: list[str] = decode_fn(
|
||||
trimmed_sequences, skip_special_tokens=True
|
||||
)
|
||||
|
||||
# Logging tokens count for the first sample as a representative metric
|
||||
# -- Optional logging
|
||||
if generated_ids.shape[0] > 0:
|
||||
num_tokens = int(generated_ids[0].shape[0])
|
||||
_log.debug(
|
||||
f"Generated {num_tokens} tokens in time {generation_time:.2f} seconds."
|
||||
f"Generated {int(generated_ids[0].shape[0])} tokens in {generation_time:.2f}s "
|
||||
f"for batch size {generated_ids.shape[0]}."
|
||||
)
|
||||
|
||||
for text in decoded_texts:
|
||||
|
||||
@@ -142,8 +142,6 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
|
||||
Raises:
|
||||
ValueError: If prompt list length doesn't match image count.
|
||||
"""
|
||||
from mlx_vlm import generate
|
||||
|
||||
# Convert image batch to list for length validation
|
||||
image_list = list(image_batch)
|
||||
|
||||
@@ -194,33 +192,67 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
|
||||
self.processor, self.config, user_prompt, num_images=1
|
||||
)
|
||||
|
||||
# Generate text from the image - MLX can accept PIL Images directly despite type annotations
|
||||
# Stream generate with stop strings support
|
||||
start_time = time.time()
|
||||
generated_result = generate(
|
||||
_log.debug("start generating ...")
|
||||
|
||||
tokens: list[VlmPredictionToken] = []
|
||||
output = ""
|
||||
|
||||
# Use stream_generate for proper stop string handling
|
||||
for token in self.stream_generate(
|
||||
self.vlm_model,
|
||||
self.processor,
|
||||
formatted_prompt,
|
||||
image=image, # Pass PIL Image directly - much more efficient than disk I/O
|
||||
[image], # MLX stream_generate expects list of images
|
||||
max_tokens=self.max_tokens,
|
||||
verbose=False,
|
||||
temp=self.temperature,
|
||||
max_tokens=self.max_tokens,
|
||||
)
|
||||
):
|
||||
# Collect token information
|
||||
if len(token.logprobs.shape) == 1:
|
||||
tokens.append(
|
||||
VlmPredictionToken(
|
||||
text=token.text,
|
||||
token=token.token,
|
||||
logprob=token.logprobs[token.token],
|
||||
)
|
||||
)
|
||||
elif (
|
||||
len(token.logprobs.shape) == 2 and token.logprobs.shape[0] == 1
|
||||
):
|
||||
tokens.append(
|
||||
VlmPredictionToken(
|
||||
text=token.text,
|
||||
token=token.token,
|
||||
logprob=token.logprobs[0, token.token],
|
||||
)
|
||||
)
|
||||
else:
|
||||
_log.warning(
|
||||
f"incompatible shape for logprobs: {token.logprobs.shape}"
|
||||
)
|
||||
|
||||
output += token.text
|
||||
|
||||
# Check for any configured stop strings
|
||||
if self.vlm_options.stop_strings:
|
||||
if any(
|
||||
stop_str in output
|
||||
for stop_str in self.vlm_options.stop_strings
|
||||
):
|
||||
_log.debug("Stopping generation due to stop string match")
|
||||
break
|
||||
|
||||
generation_time = time.time() - start_time
|
||||
|
||||
# MLX generate returns a tuple (text, info_dict), extract just the text
|
||||
if isinstance(generated_result, tuple):
|
||||
generated_text = generated_result[0]
|
||||
_log.debug(
|
||||
f"MLX generate returned tuple with additional info: {generated_result[1] if len(generated_result) > 1 else 'N/A'}"
|
||||
)
|
||||
else:
|
||||
generated_text = generated_result
|
||||
_log.debug(
|
||||
f"{generation_time:.2f} seconds for {len(tokens)} tokens ({len(tokens) / generation_time:.1f} tokens/sec)."
|
||||
)
|
||||
|
||||
_log.debug(f"Generated text in {generation_time:.2f}s.")
|
||||
yield VlmPrediction(
|
||||
text=generated_text,
|
||||
text=output,
|
||||
generation_time=generation_time,
|
||||
# MLX generate doesn't expose tokens directly, so we leave it empty
|
||||
generated_tokens=[],
|
||||
generated_tokens=tokens,
|
||||
)
|
||||
_log.debug("MLX model: Released global lock")
|
||||
|
||||
@@ -57,6 +57,8 @@ from docling.datamodel.vlm_model_specs import (
|
||||
DOLPHIN_TRANSFORMERS,
|
||||
SMOLDOCLING_MLX,
|
||||
SMOLDOCLING_TRANSFORMERS,
|
||||
SMOLVLM500_MLX,
|
||||
SMOLVLM500_TRANSFORMERS,
|
||||
)
|
||||
from docling.models.layout_model import LayoutModel
|
||||
from docling.models.page_assemble_model import PageAssembleModel, PageAssembleOptions
|
||||
@@ -159,21 +161,23 @@ class ThreadedMultiStageVlmPipelineOptions(PaginatedPipelineOptions):
|
||||
# text_opts = DOLPHIN_TRANSFORMERS.model_copy()
|
||||
# text_opts.prompt = "<s>Read text in the image. <Answer/>"
|
||||
|
||||
base_model = SMOLDOCLING_TRANSFORMERS
|
||||
|
||||
formula_opts = base_model.model_copy()
|
||||
formula_opts.prompt = "Convert formula to latex."
|
||||
formula_opts.response_format = ResponseFormat.OTSL
|
||||
|
||||
code_opts = base_model.model_copy()
|
||||
code_opts.prompt = "Convert code to text."
|
||||
code_opts.response_format = ResponseFormat.OTSL
|
||||
base_model = SMOLVLM500_TRANSFORMERS
|
||||
|
||||
text_opts = base_model.model_copy()
|
||||
text_opts.prompt = "Convert this page to docling."
|
||||
text_opts.response_format = ResponseFormat.OTSL
|
||||
# text_opts.prompt = "Convert this page to docling."
|
||||
text_opts.prompt = "What does this say?"
|
||||
text_opts.response_format = ResponseFormat.PLAINTEXT
|
||||
|
||||
table_opts = base_model.model_copy()
|
||||
formula_opts = base_model.model_copy()
|
||||
# formula_opts.prompt = "Convert formula to latex."
|
||||
formula_opts.prompt = "What does this say?"
|
||||
formula_opts.response_format = ResponseFormat.PLAINTEXT
|
||||
|
||||
code_opts = SMOLDOCLING_TRANSFORMERS.model_copy()
|
||||
code_opts.prompt = "Convert code to text."
|
||||
code_opts.response_format = ResponseFormat.DOCTAGS
|
||||
|
||||
table_opts = SMOLDOCLING_TRANSFORMERS.model_copy()
|
||||
table_opts.prompt = "Convert this table to OTSL."
|
||||
table_opts.response_format = ResponseFormat.OTSL
|
||||
|
||||
|
||||
@@ -128,7 +128,9 @@ class VlmPipeline(PaginatedPipeline):
|
||||
page._backend = conv_res.input._backend.load_page(page.page_no) # type: ignore
|
||||
if page._backend is not None and page._backend.is_valid():
|
||||
page.size = page._backend.get_size()
|
||||
page.parsed_page = page._backend.get_segmented_page()
|
||||
|
||||
if self.force_backend_text:
|
||||
page.parsed_page = page._backend.get_segmented_page()
|
||||
|
||||
return page
|
||||
|
||||
|
||||
Reference in New Issue
Block a user