feat: batching support for VLMs in transformers backend, add initial VLLM backend (#2094)

* Prepare existing code for use with the new multi-stage VLM pipeline

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Add multithreaded VLM pipeline

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Add VLM task interpreters

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Add VLM task interpreters

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Remove prints

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Fix KeyboardInterrupt behaviour

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Add VLLM backend support, optimize process_images

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
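
For context, offline multimodal generation with vLLM follows the pattern sketched below; this is a generic illustration of the API such a backend builds on, not the PR's actual wiring, and the model id, prompt, and sampling values are placeholders.

    from PIL import Image
    from vllm import LLM, SamplingParams

    # Placeholder model id and prompt; a real backend would derive both from the
    # configured VLM options.
    llm = LLM(model="HuggingFaceTB/SmolVLM-256M-Instruct", limit_mm_per_prompt={"image": 1})
    sampling = SamplingParams(temperature=0.0, max_tokens=4096, stop=["</doctag>"])

    pages = [Image.open("page1.png"), Image.open("page2.png")]  # placeholder page images
    prompt = "<|im_start|>User:<image>Convert this page to docling.<end_of_utterance>\nAssistant:"

    # vLLM batches the whole request list internally - one call covers all pages.
    outputs = llm.generate(
        [{"prompt": prompt, "multi_modal_data": {"image": im}} for im in pages],
        sampling,
    )
    texts = [out.outputs[0].text for out in outputs]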

* Tweak defaults

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Implement proper batch inference for HuggingFaceTransformersVlmModel

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
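
The batched path presumably follows the standard transformers recipe: run the processor over all images and prompts at once with padding, then issue a single generate() call for the whole batch. A rough, self-contained sketch (model id, prompts, and file names are placeholders, not docling's configuration):

    import torch
    from PIL import Image
    from transformers import AutoModelForVision2Seq, AutoProcessor

    repo_id = "HuggingFaceTB/SmolVLM-256M-Instruct"  # placeholder model id
    processor = AutoProcessor.from_pretrained(repo_id)
    model = AutoModelForVision2Seq.from_pretrained(repo_id, device_map="auto")

    images = [Image.open(p).convert("RGB") for p in ("page1.png", "page2.png")]  # placeholders
    chat = [
        {
            "role": "user",
            "content": [{"type": "image"}, {"type": "text", "text": "Convert this page to docling."}],
        }
    ]
    prompts = [processor.apply_chat_template(chat, add_generation_prompt=True) for _ in images]

    # One padded batch and one generate() call instead of a Python loop over pages.
    inputs = processor(
        text=prompts, images=[[im] for im in images], return_tensors="pt", padding=True
    ).to(model.device)
    with torch.inference_mode():
        generated = model.generate(**inputs, max_new_tokens=4096, do_sample=False)
    # Note: decoded strings still contain the prompt prefix; trim it if needed.
    texts = processor.batch_decode(generated, skip_special_tokens=True)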

* Small fixes

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Cleanup hf_transformers_model batching impl

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Adjust example instantiation of multi-stage VLM pipeline

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Add GoT OCR 2.0

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Factor out changes without multi-stage pipeline

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Reset defaults for generation

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Cleanup

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Add torch.compile, fix temperature setting in gen_kwargs

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
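
On the generation side these two changes typically look like the sketch below: wrap the loaded model with torch.compile, and only pass temperature to generate() when sampling is enabled, since transformers ignores it (with a warning) under greedy decoding. The helper name is illustrative, not the PR's code.

    import torch


    def prepare_model_and_gen_kwargs(
        model: torch.nn.Module, temperature: float, max_new_tokens: int
    ):
        """Illustrative helper: compile the model and build generation kwargs."""
        compiled = torch.compile(model)

        gen_kwargs = {"max_new_tokens": max_new_tokens}
        if temperature > 0:
            # temperature is only honoured when sampling is switched on
            gen_kwargs.update(do_sample=True, temperature=temperature)
        else:
            gen_kwargs.update(do_sample=False)
        return compiled, gen_kwargs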

* Expose page_batch_size on CLI

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
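
The same knob has a programmatic counterpart in docling's global performance settings, which the new CLI option presumably maps onto; a small sketch, assuming the attribute path below matches your docling version:

    from docling.datamodel.settings import settings
    from docling.document_converter import DocumentConverter

    # Larger batches hand more pages to the VLM per inference call.
    settings.perf.page_batch_size = 16

    converter = DocumentConverter()
    result = converter.convert("report.pdf")  # placeholder input path
    print(result.document.export_to_markdown())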

* Add torch_dtype bfloat16 to SMOLDOCLING and SMOLVLM model spec

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
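
In the model spec this presumably amounts to forwarding the dtype to from_pretrained; a minimal sketch with a placeholder repo id:

    import torch
    from transformers import AutoModelForVision2Seq

    # bfloat16 halves memory versus float32 while keeping float32's exponent range.
    model = AutoModelForVision2Seq.from_pretrained(
        "HuggingFaceTB/SmolVLM-256M-Instruct",  # placeholder repo id
        torch_dtype=torch.bfloat16,
    )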

* Clip off pad_token

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
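
Batched generation pads shorter sequences, so decoded text can end in repeated pad tokens; clipping them is a plain string trim. An illustrative helper, not the PR's exact code:

    from typing import Optional


    def clip_pad_token(decoded_text: str, pad_token: Optional[str]) -> str:
        """Strip trailing pad tokens left over from batched, padded generation."""
        if not pad_token:
            return decoded_text
        while decoded_text.endswith(pad_token):
            decoded_text = decoded_text[: -len(pad_token)]
        return decoded_text


    # clip_pad_token("<doctag>...</doctag><pad><pad>", "<pad>") -> "<doctag>...</doctag>"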

---------

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
Christoph Auer authored on 2025-08-22 13:17:33 +02:00, committed by GitHub
parent 3f03709885
commit 3c660c0511
17 changed files with 2837 additions and 319 deletions


@@ -1,8 +1,12 @@
 import logging
+import threading
 import time
 from collections.abc import Iterable
 from pathlib import Path
-from typing import Optional
+from typing import Optional, Union
+
+import numpy as np
+from PIL.Image import Image
 
 from docling.datamodel.accelerator_options import (
     AcceleratorOptions,
@@ -10,7 +14,7 @@ from docling.datamodel.accelerator_options import (
 from docling.datamodel.base_models import Page, VlmPrediction, VlmPredictionToken
 from docling.datamodel.document import ConversionResult
 from docling.datamodel.pipeline_options_vlm_model import InlineVlmOptions
-from docling.models.base_model import BasePageModel
+from docling.models.base_model import BaseVlmPageModel
 from docling.models.utils.hf_model_download import (
     HuggingFaceModelDownloadMixin,
 )
@@ -18,8 +22,12 @@ from docling.utils.profiling import TimeRecorder
 
 _log = logging.getLogger(__name__)
 
+# Global lock for MLX model calls - MLX models are not thread-safe
+# All MLX models share this lock to prevent concurrent MLX operations
+_MLX_GLOBAL_LOCK = threading.Lock()
 
-class HuggingFaceMlxModel(BasePageModel, HuggingFaceModelDownloadMixin):
+
+class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
     def __init__(
         self,
         enabled: bool,
@@ -63,87 +71,190 @@ class HuggingFaceMlxModel(BasePageModel, HuggingFaceModelDownloadMixin):
     def __call__(
         self, conv_res: ConversionResult, page_batch: Iterable[Page]
     ) -> Iterable[Page]:
-        for page in page_batch:
+        page_list = list(page_batch)
+        if not page_list:
+            return
+
+        valid_pages = []
+        invalid_pages = []
+
+        for page in page_list:
             assert page._backend is not None
             if not page._backend.is_valid():
-                yield page
+                invalid_pages.append(page)
             else:
-                with TimeRecorder(conv_res, f"vlm-mlx-{self.vlm_options.repo_id}"):
-                    assert page.size is not None
+                valid_pages.append(page)
+
+        # Process valid pages in batch
+        if valid_pages:
+            with TimeRecorder(conv_res, f"vlm-mlx-{self.vlm_options.repo_id}"):
+                # Prepare images and prompts for batch processing
+                images = []
+                user_prompts = []
+                pages_with_images = []
+
+                for page in valid_pages:
+                    assert page.size is not None
                     hi_res_image = page.get_image(
                         scale=self.vlm_options.scale, max_size=self.vlm_options.max_size
                     )
+
+                    # Only process pages with valid images
                     if hi_res_image is not None:
-                        im_width, im_height = hi_res_image.size
-
-                    # populate page_tags with predicted doc tags
-                    page_tags = ""
-
-                    if hi_res_image:
-                        if hi_res_image.mode != "RGB":
-                            hi_res_image = hi_res_image.convert("RGB")
-
-                    user_prompt = self.vlm_options.build_prompt(page.parsed_page)
-                    prompt = self.apply_chat_template(
-                        self.processor, self.config, user_prompt, num_images=1
-                    )
-                    start_time = time.time()
-                    _log.debug("start generating ...")
-
-                    # Call model to generate:
-                    tokens: list[VlmPredictionToken] = []
-
-                    output = ""
-                    for token in self.stream_generate(
-                        self.vlm_model,
-                        self.processor,
-                        prompt,
-                        [hi_res_image],
-                        max_tokens=self.max_tokens,
-                        verbose=False,
-                        temp=self.temperature,
-                    ):
-                        if len(token.logprobs.shape) == 1:
-                            tokens.append(
-                                VlmPredictionToken(
-                                    text=token.text,
-                                    token=token.token,
-                                    logprob=token.logprobs[token.token],
-                                )
-                            )
-                        elif (
-                            len(token.logprobs.shape) == 2
-                            and token.logprobs.shape[0] == 1
-                        ):
-                            tokens.append(
-                                VlmPredictionToken(
-                                    text=token.text,
-                                    token=token.token,
-                                    logprob=token.logprobs[0, token.token],
-                                )
-                            )
-                        else:
-                            _log.warning(
-                                f"incompatible shape for logprobs: {token.logprobs.shape}"
-                            )
-
-                        output += token.text
-                        if "</doctag>" in token.text:
-                            break
-
-                    generation_time = time.time() - start_time
-                    page_tags = output
-
-                    _log.debug(
-                        f"{generation_time:.2f} seconds for {len(tokens)} tokens ({len(tokens) / generation_time} tokens/sec)."
-                    )
-                    page_tags = self.vlm_options.decode_response(page_tags)
-                    page.predictions.vlm_response = VlmPrediction(
-                        text=page_tags,
-                        generation_time=generation_time,
-                        generated_tokens=tokens,
-                    )
-
-            yield page
+                        images.append(hi_res_image)
+
+                        # Define prompt structure
+                        if callable(self.vlm_options.prompt):
+                            user_prompt = self.vlm_options.prompt(page.parsed_page)
+                        else:
+                            user_prompt = self.vlm_options.prompt
+
+                        user_prompts.append(user_prompt)
+                        pages_with_images.append(page)
+
+                # Use process_images for the actual inference
+                if images:  # Only if we have valid images
+                    predictions = list(self.process_images(images, user_prompts))
+
+                    # Attach results to pages
+                    for page, prediction in zip(pages_with_images, predictions):
+                        page.predictions.vlm_response = prediction
+
+        # Yield all pages (valid and invalid)
+        for page in invalid_pages:
+            yield page
+        for page in valid_pages:
+            yield page
+
+    def process_images(
+        self,
+        image_batch: Iterable[Union[Image, np.ndarray]],
+        prompt: Union[str, list[str]],
+    ) -> Iterable[VlmPrediction]:
+        """Process raw images without page metadata.
+
+        Args:
+            image_batch: Iterable of PIL Images or numpy arrays
+            prompt: Either:
+                - str: Single prompt used for all images
+                - list[str]: List of prompts (one per image, must match image count)
+
+        Raises:
+            ValueError: If prompt list length doesn't match image count.
+        """
+        # Convert image batch to list for length validation
+        image_list = list(image_batch)
+        if len(image_list) == 0:
+            return
+
+        # Handle prompt parameter
+        if isinstance(prompt, str):
+            # Single prompt for all images
+            user_prompts = [prompt] * len(image_list)
+        elif isinstance(prompt, list):
+            # List of prompts (one per image)
+            if len(prompt) != len(image_list):
+                raise ValueError(
+                    f"Number of prompts ({len(prompt)}) must match number of images ({len(image_list)})"
+                )
+            user_prompts = prompt
+        else:
+            raise ValueError(f"prompt must be str or list[str], got {type(prompt)}")
+
+        # MLX models are not thread-safe - use global lock to serialize access
+        with _MLX_GLOBAL_LOCK:
+            _log.debug("MLX model: Acquired global lock for thread safety")
+            for image, user_prompt in zip(image_list, user_prompts):
+                # Convert numpy array to PIL Image if needed
+                if isinstance(image, np.ndarray):
+                    if image.ndim == 3 and image.shape[2] in [3, 4]:
+                        # RGB or RGBA array
+                        from PIL import Image as PILImage
+
+                        image = PILImage.fromarray(image.astype(np.uint8))
+                    elif image.ndim == 2:
+                        # Grayscale array
+                        from PIL import Image as PILImage
+
+                        image = PILImage.fromarray(image.astype(np.uint8), mode="L")
+                    else:
+                        raise ValueError(
+                            f"Unsupported numpy array shape: {image.shape}"
+                        )
+
+                # Ensure image is in RGB mode (handles RGBA, L, etc.)
+                if image.mode != "RGB":
+                    image = image.convert("RGB")
+
+                # Use the MLX chat template approach like in the __call__ method
+                formatted_prompt = self.apply_chat_template(
+                    self.processor, self.config, user_prompt, num_images=1
+                )
+
+                # Stream generate with stop strings support
+                start_time = time.time()
+                _log.debug("start generating ...")
+
+                tokens: list[VlmPredictionToken] = []
+                output = ""
+
+                # Use stream_generate for proper stop string handling
+                for token in self.stream_generate(
+                    self.vlm_model,
+                    self.processor,
+                    formatted_prompt,
+                    [image],  # MLX stream_generate expects list of images
+                    max_tokens=self.max_tokens,
+                    verbose=False,
+                    temp=self.temperature,
+                ):
+                    # Collect token information
+                    if len(token.logprobs.shape) == 1:
+                        tokens.append(
+                            VlmPredictionToken(
+                                text=token.text,
+                                token=token.token,
+                                logprob=token.logprobs[token.token],
+                            )
+                        )
+                    elif (
+                        len(token.logprobs.shape) == 2 and token.logprobs.shape[0] == 1
+                    ):
+                        tokens.append(
+                            VlmPredictionToken(
+                                text=token.text,
+                                token=token.token,
+                                logprob=token.logprobs[0, token.token],
+                            )
+                        )
+                    else:
+                        _log.warning(
+                            f"incompatible shape for logprobs: {token.logprobs.shape}"
+                        )
+
+                    output += token.text
+
+                    # Check for any configured stop strings
+                    if self.vlm_options.stop_strings:
+                        if any(
+                            stop_str in output
+                            for stop_str in self.vlm_options.stop_strings
+                        ):
+                            _log.debug("Stopping generation due to stop string match")
+                            break
+
+                generation_time = time.time() - start_time
+
+                _log.debug(
+                    f"{generation_time:.2f} seconds for {len(tokens)} tokens ({len(tokens) / generation_time:.1f} tokens/sec)."
+                )
+
+                # Apply decode_response to the output before yielding
+                decoded_output = self.vlm_options.decode_response(output)
+                yield VlmPrediction(
+                    text=decoded_output,
+                    generation_time=generation_time,
+                    generated_tokens=tokens,
+                )
+
+            _log.debug("MLX model: Released global lock")