Add VLLM backend support, optimize process_images

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
Christoph Auer
2025-08-15 13:18:02 +02:00
parent 18b1a43744
commit 16fea9cd8b
10 changed files with 489 additions and 161 deletions

View File

@@ -27,6 +27,7 @@ class ResponseFormat(str, Enum):
class InferenceFramework(str, Enum):
MLX = "mlx"
TRANSFORMERS = "transformers"
VLLM = "vllm"
class TransformersModelType(str, Enum):

View File

@@ -44,6 +44,20 @@ SMOLDOCLING_TRANSFORMERS = InlineVlmOptions(
temperature=0.0,
)
SMOLDOCLING_VLLM = InlineVlmOptions(
repo_id="ds4sd/SmolDocling-256M-preview",
prompt="Convert this page to docling.",
response_format=ResponseFormat.DOCTAGS,
inference_framework=InferenceFramework.VLLM,
transformers_model_type=TransformersModelType.AUTOMODEL_VISION2SEQ,
supported_devices=[
AcceleratorDevice.CPU,
AcceleratorDevice.CUDA,
],
scale=2.0,
temperature=0.0,
)
# GraniteVision
GRANITE_VISION_TRANSFORMERS = InlineVlmOptions(
repo_id="ibm-granite/granite-vision-3.2-2b",
@@ -60,6 +74,20 @@ GRANITE_VISION_TRANSFORMERS = InlineVlmOptions(
temperature=0.0,
)
GRANITE_VISION_VLLM = InlineVlmOptions(
repo_id="ibm-granite/granite-vision-3.2-2b",
prompt="Convert this page to markdown. Do not miss any text and only output the bare markdown!",
response_format=ResponseFormat.MARKDOWN,
inference_framework=InferenceFramework.VLLM,
transformers_model_type=TransformersModelType.AUTOMODEL_VISION2SEQ,
supported_devices=[
AcceleratorDevice.CPU,
AcceleratorDevice.CUDA,
],
scale=2.0,
temperature=0.0,
)
GRANITE_VISION_OLLAMA = ApiVlmOptions(
url=AnyUrl("http://localhost:11434/v1/chat/completions"),
params={"model": "granite3.2-vision:2b"},
@@ -158,5 +186,7 @@ DOLPHIN_TRANSFORMERS = InlineVlmOptions(
class VlmModelType(str, Enum):
SMOLDOCLING = "smoldocling"
SMOLDOCLING_VLLM = "smoldocling_vllm"
GRANITE_VISION = "granite_vision"
GRANITE_VISION_VLLM = "granite_vision_vllm"
GRANITE_VISION_OLLAMA = "granite_vision_ollama"
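The two new VLLM presets mirror their transformers counterparts and only swap the inference framework; a quick sanity-check sketch, assuming the module paths used elsewhere in docling:

from docling.datamodel import vlm_model_specs
from docling.datamodel.pipeline_options_vlm_model import InferenceFramework, ResponseFormat

# Same checkpoints and prompts as the transformers presets, different backend.
assert vlm_model_specs.SMOLDOCLING_VLLM.inference_framework == InferenceFramework.VLLM
assert vlm_model_specs.SMOLDOCLING_VLLM.repo_id == "ds4sd/SmolDocling-256M-preview"
assert vlm_model_specs.GRANITE_VISION_VLLM.response_format == ResponseFormat.MARKDOWN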

View File

@@ -37,9 +37,21 @@ class BaseVlmModel(ABC):
@abstractmethod
def process_images(
self, image_batch: Iterable[Union[Image, np.ndarray]]
self,
image_batch: Iterable[Union[Image, np.ndarray]],
prompt: Union[str, list[str]],
) -> Iterable[VlmPrediction]:
"""Process raw images without page metadata."""
"""Process raw images without page metadata.
Args:
image_batch: Iterable of PIL Images or numpy arrays
prompt: Either:
- str: Single prompt used for all images
- list[str]: List of prompts (one per image, must match image count)
Raises:
ValueError: If prompt list length doesn't match image count.
"""
class BaseVlmPageModel(BasePageModel, BaseVlmModel):
@@ -55,23 +67,6 @@ class BaseVlmPageModel(BasePageModel, BaseVlmModel):
) -> Iterable[Page]:
"""Extract images from pages, process them, and attach results back."""
@abstractmethod
def process_images(
self,
image_batch: Iterable[Union[Image, np.ndarray]],
prompt: Optional[str] = None,
) -> Iterable[VlmPrediction]:
"""Process raw images without page metadata.
Args:
image_batch: Iterable of PIL Images or numpy arrays
prompt: Optional prompt string. If None, uses vlm_options.prompt if it's a string.
If vlm_options.prompt is callable and no prompt is provided, raises ValueError.
Raises:
ValueError: If vlm_options.prompt is callable and no prompt parameter is provided.
"""
EnrichElementT = TypeVar("EnrichElementT", default=NodeItem)
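A short usage sketch of the new signature (hypothetical: model stands for any already-initialized BaseVlmModel implementation):

from PIL import Image

images = [Image.open("page_1.png"), Image.open("page_2.png")]

# A single prompt is broadcast to every image ...
predictions = list(model.process_images(images, "Convert this page to docling."))

# ... or pass one prompt per image; a length mismatch raises ValueError.
predictions = list(
    model.process_images(
        images,
        ["Convert this page to docling.", "Convert this page to markdown."],
    )
)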

View File

@@ -125,55 +125,59 @@ class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownload
def __call__(
self, conv_res: ConversionResult, page_batch: Iterable[Page]
) -> Iterable[Page]:
for page in page_batch:
page_list = list(page_batch)
if not page_list:
return
valid_pages = []
invalid_pages = []
for page in page_list:
assert page._backend is not None
if not page._backend.is_valid():
yield page
invalid_pages.append(page)
else:
with TimeRecorder(conv_res, "vlm"):
assert page.size is not None
valid_pages.append(page)
# Process valid pages in batch
if valid_pages:
with TimeRecorder(conv_res, "vlm"):
# Prepare images and prompts for batch processing
images = []
user_prompts = []
pages_with_images = []
for page in valid_pages:
assert page.size is not None
hi_res_image = page.get_image(
scale=self.vlm_options.scale, max_size=self.vlm_options.max_size
)
# Only process pages with valid images
if hi_res_image is not None:
images.append(hi_res_image)
# Define prompt structure
if callable(self.vlm_options.prompt):
user_prompt = self.vlm_options.prompt(page.parsed_page)
else:
user_prompt = self.vlm_options.prompt
prompt = self.formulate_prompt(user_prompt)
inputs = self.processor(
text=prompt, images=[hi_res_image], return_tensors="pt"
).to(self.device)
user_prompts.append(user_prompt)
pages_with_images.append(page)
start_time = time.time()
# Call model to generate:
generated_ids = self.vlm_model.generate(
**inputs,
max_new_tokens=self.max_new_tokens,
use_cache=self.use_cache,
temperature=self.temperature,
generation_config=self.generation_config,
**self.vlm_options.extra_generation_config,
)
# Use process_images for the actual inference
if images: # Only if we have valid images
predictions = list(self.process_images(images, user_prompts))
generation_time = time.time() - start_time
generated_texts = self.processor.batch_decode(
generated_ids[:, inputs["input_ids"].shape[1] :],
skip_special_tokens=True,
)[0]
num_tokens = len(generated_ids[0])
_log.debug(
f"Generated {num_tokens} tokens in time {generation_time:.2f} seconds."
)
page.predictions.vlm_response = VlmPrediction(
text=generated_texts,
generation_time=generation_time,
)
# Attach results to pages
for page, prediction in zip(pages_with_images, predictions):
page.predictions.vlm_response = prediction
# Yield all pages (valid and invalid)
for page in invalid_pages:
yield page
for page in valid_pages:
yield page
def formulate_prompt(self, user_prompt: str) -> str:
@@ -221,9 +225,19 @@ class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownload
def process_images(
self,
image_batch: Iterable[Union[Image, np.ndarray]],
prompt: Optional[str] = None,
prompt: Union[str, list[str]],
) -> Iterable[VlmPrediction]:
"""Process raw images without page metadata in a single batched inference call."""
"""Process raw images without page metadata in a single batched inference call.
Args:
image_batch: Iterable of PIL Images or numpy arrays
prompt: Either:
- str: Single prompt used for all images
- list[str]: List of prompts (one per image, must match image count)
Raises:
ValueError: If prompt list length doesn't match image count.
"""
pil_images: list[Image] = []
for img in image_batch:
@@ -251,19 +265,24 @@ class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownload
if len(pil_images) == 0:
return
# Handle prompt with priority: parameter > vlm_options.prompt > error
if prompt is not None:
user_prompt = prompt
elif not callable(self.vlm_options.prompt):
user_prompt = self.vlm_options.prompt
else:
# Handle prompt parameter
if isinstance(prompt, str):
# Single prompt for all images
user_prompts = [prompt] * len(pil_images)
elif isinstance(prompt, list):
# List of prompts (one per image)
if len(prompt) != len(pil_images):
raise ValueError(
"vlm_options.prompt is callable but no prompt parameter provided to process_images. "
"Please provide a prompt parameter when calling process_images directly."
f"Number of prompts ({len(prompt)}) must match number of images ({len(pil_images)})"
)
user_prompts = prompt
else:
raise ValueError(f"prompt must be str or list[str], got {type(prompt)}")
formatted_prompt = self.formulate_prompt(user_prompt)
prompts: list[str] = [formatted_prompt] * len(pil_images)
# Format prompts individually
prompts: list[str] = [
self.formulate_prompt(user_prompt) for user_prompt in user_prompts
]
inputs = self.processor(
text=prompts, images=pil_images, return_tensors="pt", padding=True
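The same prompt-handling rules are inlined in each backend; as a standalone sketch (hypothetical helper, not part of this commit), the logic reduces to:

from typing import Union

def normalize_prompts(prompt: Union[str, list[str]], n_images: int) -> list[str]:
    # Broadcast a single string, validate per-image lists, reject anything else.
    if isinstance(prompt, str):
        return [prompt] * n_images
    if isinstance(prompt, list):
        if len(prompt) != n_images:
            raise ValueError(
                f"Number of prompts ({len(prompt)}) must match number of images ({n_images})"
            )
        return prompt
    raise ValueError(f"prompt must be str or list[str], got {type(prompt)}")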

View File

@@ -71,110 +71,103 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
def __call__(
self, conv_res: ConversionResult, page_batch: Iterable[Page]
) -> Iterable[Page]:
for page in page_batch:
page_list = list(page_batch)
if not page_list:
return
valid_pages = []
invalid_pages = []
for page in page_list:
assert page._backend is not None
if not page._backend.is_valid():
yield page
invalid_pages.append(page)
else:
with TimeRecorder(conv_res, f"vlm-mlx-{self.vlm_options.repo_id}"):
assert page.size is not None
valid_pages.append(page)
# Process valid pages in batch
if valid_pages:
with TimeRecorder(conv_res, f"vlm-mlx-{self.vlm_options.repo_id}"):
# Prepare images and prompts for batch processing
images = []
user_prompts = []
pages_with_images = []
for page in valid_pages:
assert page.size is not None
hi_res_image = page.get_image(
scale=self.vlm_options.scale, max_size=self.vlm_options.max_size
)
# Only process pages with valid images
if hi_res_image is not None:
im_width, im_height = hi_res_image.size
# populate page_tags with predicted doc tags
page_tags = ""
if hi_res_image:
if hi_res_image.mode != "RGB":
hi_res_image = hi_res_image.convert("RGB")
images.append(hi_res_image)
# Define prompt structure
if callable(self.vlm_options.prompt):
user_prompt = self.vlm_options.prompt(page.parsed_page)
else:
user_prompt = self.vlm_options.prompt
prompt = self.apply_chat_template(
self.processor, self.config, user_prompt, num_images=1
)
# MLX models are not thread-safe - use global lock to serialize access
with _MLX_GLOBAL_LOCK:
_log.debug(
"MLX model: Acquired global lock for __call__ method"
)
start_time = time.time()
_log.debug("start generating ...")
user_prompts.append(user_prompt)
pages_with_images.append(page)
# Call model to generate:
tokens: list[VlmPredictionToken] = []
# Use process_images for the actual inference
if images: # Only if we have valid images
predictions = list(self.process_images(images, user_prompts))
output = ""
for token in self.stream_generate(
self.vlm_model,
self.processor,
prompt,
[hi_res_image],
max_tokens=self.max_tokens,
verbose=False,
temp=self.temperature,
):
if len(token.logprobs.shape) == 1:
tokens.append(
VlmPredictionToken(
text=token.text,
token=token.token,
logprob=token.logprobs[token.token],
)
)
elif (
len(token.logprobs.shape) == 2
and token.logprobs.shape[0] == 1
):
tokens.append(
VlmPredictionToken(
text=token.text,
token=token.token,
logprob=token.logprobs[0, token.token],
)
)
else:
_log.warning(
f"incompatible shape for logprobs: {token.logprobs.shape}"
)
output += token.text
if "</doctag>" in token.text:
break
generation_time = time.time() - start_time
_log.debug("MLX model: Released global lock")
page_tags = output
_log.debug(
f"{generation_time:.2f} seconds for {len(tokens)} tokens ({len(tokens) / generation_time} tokens/sec)."
)
page.predictions.vlm_response = VlmPrediction(
text=page_tags,
generation_time=generation_time,
generated_tokens=tokens,
)
# Attach results to pages
for page, prediction in zip(pages_with_images, predictions):
page.predictions.vlm_response = prediction
# Yield all pages (valid and invalid)
for page in invalid_pages:
yield page
for page in valid_pages:
yield page
def process_images(
self,
image_batch: Iterable[Union[Image, np.ndarray]],
prompt: Optional[str] = None,
prompt: Union[str, list[str]],
) -> Iterable[VlmPrediction]:
"""Process raw images without page metadata.
Args:
image_batch: Iterable of PIL Images or numpy arrays
prompt: Either:
- str: Single prompt used for all images
- list[str]: List of prompts (one per image, must match image count)
Raises:
ValueError: If prompt list length doesn't match image count.
"""
from mlx_vlm import generate
# Convert image batch to list for length validation
image_list = list(image_batch)
if len(image_list) == 0:
return
# Handle prompt parameter
if isinstance(prompt, str):
# Single prompt for all images
user_prompts = [prompt] * len(image_list)
elif isinstance(prompt, list):
# List of prompts (one per image)
if len(prompt) != len(image_list):
raise ValueError(
f"Number of prompts ({len(prompt)}) must match number of images ({len(image_list)})"
)
user_prompts = prompt
else:
raise ValueError(f"prompt must be str or list[str], got {type(prompt)}")
# MLX models are not thread-safe - use global lock to serialize access
with _MLX_GLOBAL_LOCK:
_log.debug("MLX model: Acquired global lock for thread safety")
for image in image_batch:
for image, user_prompt in zip(image_list, user_prompts):
# Convert numpy array to PIL Image if needed
if isinstance(image, np.ndarray):
if image.ndim == 3 and image.shape[2] in [3, 4]:
@@ -196,17 +189,6 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
if image.mode != "RGB":
image = image.convert("RGB")
# Handle prompt with priority: parameter > vlm_options.prompt > error
if prompt is not None:
user_prompt = prompt
elif not callable(self.vlm_options.prompt):
user_prompt = self.vlm_options.prompt
else:
raise ValueError(
"vlm_options.prompt is callable but no prompt parameter provided to process_images. "
"Please provide a prompt parameter when calling process_images directly."
)
# Use the MLX chat template approach like in the __call__ method
formatted_prompt = self.apply_chat_template(
self.processor, self.config, user_prompt, num_images=1
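All inline backends normalize incoming images the same way before inference; a standalone sketch of that conversion (hypothetical helper name, logic mirrors the code above):

import numpy as np
from PIL import Image as PILImage

def to_rgb_pil(img):
    # Accept RGB/RGBA or grayscale numpy arrays, or PIL images; return an RGB PIL image.
    if isinstance(img, np.ndarray):
        if img.ndim == 3 and img.shape[2] in (3, 4):
            img = PILImage.fromarray(img.astype(np.uint8))
        elif img.ndim == 2:
            img = PILImage.fromarray(img.astype(np.uint8), mode="L")
        else:
            raise ValueError(f"Unsupported numpy array shape: {img.shape}")
    return img.convert("RGB") if img.mode != "RGB" else img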

View File

@@ -0,0 +1,277 @@
import logging
import time
from collections.abc import Iterable
from pathlib import Path
from typing import Any, Optional, Union
import numpy as np
from PIL.Image import Image
from docling.datamodel.accelerator_options import (
AcceleratorOptions,
)
from docling.datamodel.base_models import Page, VlmPrediction
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options_vlm_model import (
InlineVlmOptions,
TransformersPromptStyle,
)
from docling.models.base_model import BaseVlmPageModel
from docling.models.utils.hf_model_download import (
HuggingFaceModelDownloadMixin,
)
from docling.utils.accelerator_utils import decide_device
from docling.utils.profiling import TimeRecorder
_log = logging.getLogger(__name__)
class VllmVlmModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
def __init__(
self,
enabled: bool,
artifacts_path: Optional[Path],
accelerator_options: AcceleratorOptions,
vlm_options: InlineVlmOptions,
):
self.enabled = enabled
self.vlm_options = vlm_options
if self.enabled:
from transformers import AutoProcessor
from vllm import LLM, SamplingParams
self.device = decide_device(
accelerator_options.device,
supported_devices=vlm_options.supported_devices,
)
_log.debug(f"Available device for VLM: {self.device}")
self.max_new_tokens = vlm_options.max_new_tokens
self.temperature = vlm_options.temperature
repo_cache_folder = vlm_options.repo_id.replace("/", "--")
if artifacts_path is None:
artifacts_path = self.download_models(self.vlm_options.repo_id)
elif (artifacts_path / repo_cache_folder).exists():
artifacts_path = artifacts_path / repo_cache_folder
# Initialize VLLM LLM
llm_kwargs = {
"model": str(artifacts_path),
"model_impl": "transformers",
"limit_mm_per_prompt": {"image": 1},
"trust_remote_code": vlm_options.trust_remote_code,
}
# Add device-specific configurations
if self.device.startswith("cuda"):
# VLLM automatically detects GPU
pass
elif self.device == "cpu":
llm_kwargs["device"] = "cpu"
# Add quantization if specified
if vlm_options.quantized:
if vlm_options.load_in_8bit:
llm_kwargs["quantization"] = "bitsandbytes"
self.llm = LLM(**llm_kwargs)
# Initialize processor for prompt formatting
self.processor = AutoProcessor.from_pretrained(
artifacts_path,
trust_remote_code=vlm_options.trust_remote_code,
)
# Set up sampling parameters
self.sampling_params = SamplingParams(
temperature=self.temperature,
max_tokens=self.max_new_tokens,
stop=vlm_options.stop_strings if vlm_options.stop_strings else None,
**vlm_options.extra_generation_config,
)
def __call__(
self, conv_res: ConversionResult, page_batch: Iterable[Page]
) -> Iterable[Page]:
page_list = list(page_batch)
if not page_list:
return
valid_pages = []
invalid_pages = []
for page in page_list:
assert page._backend is not None
if not page._backend.is_valid():
invalid_pages.append(page)
else:
valid_pages.append(page)
# Process valid pages in batch
if valid_pages:
with TimeRecorder(conv_res, "vlm"):
# Prepare images and prompts for batch processing
images = []
user_prompts = []
pages_with_images = []
for page in valid_pages:
assert page.size is not None
hi_res_image = page.get_image(
scale=self.vlm_options.scale, max_size=self.vlm_options.max_size
)
# Only process pages with valid images
if hi_res_image is not None:
images.append(hi_res_image)
# Define prompt structure
if callable(self.vlm_options.prompt):
user_prompt = self.vlm_options.prompt(page.parsed_page)
else:
user_prompt = self.vlm_options.prompt
user_prompts.append(user_prompt)
pages_with_images.append(page)
# Use process_images for the actual inference
if images: # Only if we have valid images
predictions = list(self.process_images(images, user_prompts))
# Attach results to pages
for page, prediction in zip(pages_with_images, predictions):
page.predictions.vlm_response = prediction
# Yield all pages (valid and invalid)
for page in invalid_pages:
yield page
for page in valid_pages:
yield page
def formulate_prompt(self, user_prompt: str) -> str:
"""Formulate a prompt for the VLM."""
if self.vlm_options.transformers_prompt_style == TransformersPromptStyle.RAW:
return user_prompt
elif self.vlm_options.repo_id == "microsoft/Phi-4-multimodal-instruct":
_log.debug("Using specialized prompt for Phi-4")
# Note: This might need adjustment for VLLM vs transformers
user_prompt_prefix = "<|user|>"
assistant_prompt = "<|assistant|>"
prompt_suffix = "<|end|>"
prompt = f"{user_prompt_prefix}<|image_1|>{user_prompt}{prompt_suffix}{assistant_prompt}"
_log.debug(f"prompt for {self.vlm_options.repo_id}: {prompt}")
return prompt
elif self.vlm_options.transformers_prompt_style == TransformersPromptStyle.CHAT:
messages = [
{
"role": "user",
"content": [
{
"type": "text",
"text": "This is a page from a document.",
},
{"type": "image"},
{"type": "text", "text": user_prompt},
],
}
]
prompt = self.processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
return prompt
raise RuntimeError(
f"Unknown prompt style `{self.vlm_options.transformers_prompt_style}`. Valid values are {', '.join(s.value for s in TransformersPromptStyle)}."
)
def process_images(
self,
image_batch: Iterable[Union[Image, np.ndarray]],
prompt: Union[str, list[str]],
) -> Iterable[VlmPrediction]:
"""Process raw images without page metadata in a single batched inference call.
Args:
image_batch: Iterable of PIL Images or numpy arrays
prompt: Either:
- str: Single prompt used for all images
- list[str]: List of prompts (one per image, must match image count)
Raises:
ValueError: If prompt list length doesn't match image count.
"""
pil_images: list[Image] = []
for img in image_batch:
# Convert numpy array to PIL Image if needed
if isinstance(img, np.ndarray):
if img.ndim == 3 and img.shape[2] in [3, 4]:
from PIL import Image as PILImage
pil_img = PILImage.fromarray(img.astype(np.uint8))
elif img.ndim == 2:
from PIL import Image as PILImage
pil_img = PILImage.fromarray(img.astype(np.uint8), mode="L")
else:
raise ValueError(f"Unsupported numpy array shape: {img.shape}")
else:
pil_img = img
# Ensure image is in RGB mode (handles RGBA, L, etc.)
if pil_img.mode != "RGB":
pil_img = pil_img.convert("RGB")
pil_images.append(pil_img)
if len(pil_images) == 0:
return
# Handle prompt parameter
if isinstance(prompt, str):
# Single prompt for all images
user_prompts = [prompt] * len(pil_images)
elif isinstance(prompt, list):
# List of prompts (one per image)
if len(prompt) != len(pil_images):
raise ValueError(
f"Number of prompts ({len(prompt)}) must match number of images ({len(pil_images)})"
)
user_prompts = prompt
else:
raise ValueError(f"prompt must be str or list[str], got {type(prompt)}")
# Format prompts individually
prompts: list[str] = [
self.formulate_prompt(user_prompt) for user_prompt in user_prompts
]
# Prepare VLLM inputs
llm_inputs = []
for prompt, image in zip(prompts, pil_images):
llm_inputs.append({"prompt": prompt, "multi_modal_data": {"image": image}})
start_time = time.time()
outputs = self.llm.generate(llm_inputs, sampling_params=self.sampling_params)
generation_time = time.time() - start_time
# Log the token count for the first sample as a representative metric
if len(outputs) > 0:
num_tokens = len(outputs[0].outputs[0].token_ids)
_log.debug(
f"Generated {num_tokens} tokens in time {generation_time:.2f} seconds."
)
for output in outputs:
yield VlmPrediction(
text=output.outputs[0].text, generation_time=generation_time
)
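Outside the pipeline, the same request shape can be exercised against vLLM's offline API directly; a minimal sketch, assuming a downloaded SmolDocling checkpoint and skipping the chat-template formatting that formulate_prompt applies:

from PIL import Image
from vllm import LLM, SamplingParams

llm = LLM(model="ds4sd/SmolDocling-256M-preview", limit_mm_per_prompt={"image": 1})
sampling = SamplingParams(temperature=0.0, max_tokens=4096)

image = Image.open("page_1.png").convert("RGB")
# One request dict per image, matching the llm_inputs built in process_images above.
outputs = llm.generate(
    [{"prompt": "Convert this page to docling.", "multi_modal_data": {"image": image}}],
    sampling_params=sampling,
)
print(outputs[0].outputs[0].text)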

View File

@@ -693,6 +693,17 @@ class ThreadedMultiStageVlmPipeline(BasePipeline):
accelerator_options=self.pipeline_options.accelerator_options,
vlm_options=vlm_options,
)
elif vlm_options.inference_framework == InferenceFramework.VLLM:
from docling.models.vlm_models_inline.vllm_model import (
VllmVlmModel,
)
model = VllmVlmModel(
enabled=True,
artifacts_path=art_path,
accelerator_options=self.pipeline_options.accelerator_options,
vlm_options=vlm_options,
)
else:
raise ValueError(
f"Unsupported inference framework: {vlm_options.inference_framework}"

View File

@@ -103,6 +103,17 @@ class VlmPipeline(PaginatedPipeline):
vlm_options=vlm_options,
),
]
elif vlm_options.inference_framework == InferenceFramework.VLLM:
from docling.models.vlm_models_inline.vllm_model import VllmVlmModel
self.build_pipe = [
VllmVlmModel(
enabled=True,  # must always be enabled for this pipeline to make sense.
artifacts_path=artifacts_path,
accelerator_options=pipeline_options.accelerator_options,
vlm_options=vlm_options,
),
]
else:
raise ValueError(
f"Could not instantiate the right type of VLM pipeline: {vlm_options.inference_framework}"

View File

@@ -255,6 +255,7 @@ module = [
"huggingface_hub.*",
"transformers.*",
"pylatexenc.*",
"vllm.*",
]
ignore_missing_imports = true