docling/docling/models/vlm_models_inline/vllm_model.py
peets 6a04e27352 feat(vlm): track generated tokens and stop reasons for VLM models (#2543)
* feat: add enum StopReason and use it in VlmPrediction
* add vlm_inference time for API calls and track stop reason
* fix: rename enum to VlmStopReason
* Propagate partial success status if page reaches max tokens
* feat: pages whose generation is stopped by the loop detector produce a partial success status
* Add hint for future improvement
* fix: remove vlm_stop_reason from extracted page data, add UNSPECIFIED state as VlmStopReason to avoid null value

---------

Signed-off-by: ElHachem02 <peterelhachem02@gmail.com>
Signed-off-by: Peter El Hachem <peter.el.hachem@ibm.com>
Co-authored-by: Peter El Hachem <peter.el.hachem@ibm.com>
2025-11-04 19:39:09 +01:00
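
The commit above attaches a stop reason to every VLM prediction so the pipeline can mark pages that ran out of tokens (or were cut off by the loop detector) as only partially converted. A minimal sketch of that kind of check, using the VlmStopReason.LENGTH member and the stop_reason field that appear in the file below; the helper name is hypothetical and not part of docling:

from docling.datamodel.base_models import VlmPrediction, VlmStopReason

def is_truncated(prediction: VlmPrediction) -> bool:
    # Hypothetical helper: a prediction that stopped because the max-token
    # budget ran out is incomplete, so the page should be reported as a
    # partial success rather than a full success.
    return prediction.stop_reason == VlmStopReason.LENGTH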


import logging
import sys
import time
from collections.abc import Iterable
from pathlib import Path
from typing import Any, Dict, Optional, Union

import numpy as np
from PIL.Image import Image

from docling.datamodel.accelerator_options import AcceleratorOptions
from docling.datamodel.base_models import (
    Page,
    VlmPrediction,
    VlmPredictionToken,
    VlmStopReason,
)
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options_vlm_model import (
    InlineVlmOptions,
    TransformersPromptStyle,
)
from docling.models.base_model import BaseVlmPageModel
from docling.models.utils.hf_model_download import HuggingFaceModelDownloadMixin
from docling.utils.accelerator_utils import decide_device
from docling.utils.profiling import TimeRecorder

_log = logging.getLogger(__name__)

class VllmVlmModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
    """
    vLLM-backed vision-language model that accepts PIL images (or numpy arrays)
    via vLLM's multi_modal_data, with prompt formatting handled by formulate_prompt().
    """

    # --------- Allowlist of vLLM args ---------
    # SamplingParams (runtime generation controls)
    _VLLM_SAMPLING_KEYS = {
        # Core
        "max_tokens",
        "temperature",
        "top_p",
        "top_k",
        # Penalties
        "presence_penalty",
        "frequency_penalty",
        "repetition_penalty",
        # Stops / outputs
        "stop",
        "stop_token_ids",
        "skip_special_tokens",
        "spaces_between_special_tokens",
        # Search / length
        "n",
        "best_of",
        "length_penalty",
        "early_stopping",
        # Misc
        "logprobs",
        "prompt_logprobs",
        "min_p",
        "seed",
    }

    # LLM(...) / EngineArgs (engine/load-time controls)
    _VLLM_ENGINE_KEYS = {
        # Model/tokenizer/impl
        "tokenizer",
        "tokenizer_mode",
        "download_dir",
        # Parallelism / memory / lengths
        "tensor_parallel_size",
        "pipeline_parallel_size",
        "gpu_memory_utilization",
        "max_model_len",
        "max_num_batched_tokens",
        "kv_cache_dtype",
        "dtype",
        # Quantization (coarse switch)
        "quantization",
        # Multimodal limits
        "limit_mm_per_prompt",
        # Execution toggles
        "enforce_eager",
    }
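
    # Illustrative routing example (the values are made up): with
    #   extra_generation_config={"top_p": 0.9, "max_model_len": 8192, "foo": 1}
    # "top_p" is routed to SamplingParams, "max_model_len" to the LLM(...)
    # engine constructor, and "foo" is logged as unknown and ignored
    # (see the split performed in __init__ below).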

    def __init__(
        self,
        enabled: bool,
        artifacts_path: Optional[Path],
        accelerator_options: AcceleratorOptions,
        vlm_options: InlineVlmOptions,
    ):
        self.enabled = enabled
        self.vlm_options: InlineVlmOptions = vlm_options

        self.llm = None
        self.sampling_params = None
        self.processor = None  # used for CHAT templating in formulate_prompt()
        self.device = "cpu"
        self.max_new_tokens = vlm_options.max_new_tokens
        self.temperature = vlm_options.temperature

        if not self.enabled:
            return

        from transformers import AutoProcessor

        try:
            from vllm import LLM, SamplingParams
        except ImportError:
            if sys.version_info < (3, 14):
                raise ImportError(
                    "vllm is not installed. Please install it via `pip install vllm`."
                )
            else:
                raise ImportError(
                    "vllm is not installed. It is not yet available on Python 3.14."
                )

        # Device selection
        self.device = decide_device(
            accelerator_options.device, supported_devices=vlm_options.supported_devices
        )
        _log.debug(f"Available device for VLM: {self.device}")

        # Resolve artifacts path / cache folder
        repo_cache_folder = vlm_options.repo_id.replace("/", "--")
        if artifacts_path is None:
            artifacts_path = self.download_models(
                self.vlm_options.repo_id, revision=self.vlm_options.revision
            )
        elif (artifacts_path / repo_cache_folder).exists():
            artifacts_path = artifacts_path / repo_cache_folder

        # --------- Strict split & validation of extra_generation_config ---------
        extra_cfg = self.vlm_options.extra_generation_config
        load_cfg = {k: v for k, v in extra_cfg.items() if k in self._VLLM_ENGINE_KEYS}
        gen_cfg = {k: v for k, v in extra_cfg.items() if k in self._VLLM_SAMPLING_KEYS}
        unknown = sorted(
            k
            for k in extra_cfg.keys()
            if k not in self._VLLM_ENGINE_KEYS and k not in self._VLLM_SAMPLING_KEYS
        )
        if unknown:
            _log.warning(
                "Ignoring unknown extra_generation_config keys for vLLM: %s", unknown
            )

        # --------- Construct LLM kwargs (engine/load-time) ---------
        llm_kwargs: Dict[str, Any] = {
            "model": str(artifacts_path),
            "model_impl": "transformers",
            "limit_mm_per_prompt": {"image": 1},
            "revision": self.vlm_options.revision,
            "trust_remote_code": self.vlm_options.trust_remote_code,
            **load_cfg,
        }

        if self.device == "cpu":
            llm_kwargs.setdefault("enforce_eager", True)
        else:
            llm_kwargs.setdefault(
                "gpu_memory_utilization", 0.3
            )  # room for other models

        # Quantization (kept as-is; coarse)
        if self.vlm_options.quantized and self.vlm_options.load_in_8bit:
            llm_kwargs.setdefault("quantization", "bitsandbytes")

        # Initialize vLLM LLM
        self.llm = LLM(**llm_kwargs)

        # Initialize processor for prompt templating (needed for CHAT style)
        self.processor = AutoProcessor.from_pretrained(
            artifacts_path,
            trust_remote_code=self.vlm_options.trust_remote_code,
            revision=self.vlm_options.revision,
        )

        # --------- SamplingParams (runtime) ---------
        self.sampling_params = SamplingParams(
            temperature=self.temperature,
            max_tokens=self.max_new_tokens,
            stop=(self.vlm_options.stop_strings or None),
            **gen_cfg,
        )

    def __call__(
        self, conv_res: ConversionResult, page_batch: Iterable[Page]
    ) -> Iterable[Page]:
        # If disabled, pass-through
        if not self.enabled:
            for page in page_batch:
                yield page
            return

        page_list = list(page_batch)
        if not page_list:
            return

        # Preserve original order
        original_order = page_list[:]

        # Separate valid/invalid
        valid_pages: list[Page] = []
        invalid_pages: list[Page] = []
        for page in page_list:
            assert page._backend is not None
            if page._backend.is_valid():
                valid_pages.append(page)
            else:
                invalid_pages.append(page)

        if valid_pages:
            with TimeRecorder(conv_res, "vlm"):
                images: list[Image] = []
                user_prompts: list[str] = []
                pages_with_images: list[Page] = []

                for page in valid_pages:
                    assert page.size is not None
                    hi_res_image = page.get_image(
                        scale=self.vlm_options.scale,
                        max_size=self.vlm_options.max_size,
                    )
                    if hi_res_image is None:
                        continue

                    images.append(hi_res_image)

                    # Define prompt structure
                    user_prompt = self.vlm_options.build_prompt(page.parsed_page)
                    user_prompts.append(user_prompt)
                    pages_with_images.append(page)

                if images:
                    with TimeRecorder(conv_res, "vlm_inference"):
                        predictions = list(self.process_images(images, user_prompts))
                    for page, prediction in zip(pages_with_images, predictions):
                        page.predictions.vlm_response = prediction

        # Yield in original order
        for page in original_order:
            yield page

    def process_images(
        self,
        image_batch: Iterable[Union[Image, np.ndarray]],
        prompt: Union[str, list[str]],
    ) -> Iterable[VlmPrediction]:
        """Process images in a single batched vLLM inference call."""
        import numpy as np
        from PIL import Image as PILImage

        # -- Normalize images to RGB PIL
        pil_images: list[Image] = []
        for img in image_batch:
            if isinstance(img, np.ndarray):
                if img.ndim == 3 and img.shape[2] in (3, 4):
                    pil_img = PILImage.fromarray(img.astype(np.uint8))
                elif img.ndim == 2:
                    pil_img = PILImage.fromarray(img.astype(np.uint8), mode="L")
                else:
                    raise ValueError(f"Unsupported numpy array shape: {img.shape}")
            else:
                pil_img = img
            if pil_img.mode != "RGB":
                pil_img = pil_img.convert("RGB")
            pil_images.append(pil_img)

        if not pil_images:
            return

        # Normalize prompts
        if isinstance(prompt, str):
            user_prompts = [prompt] * len(pil_images)
        elif isinstance(prompt, list):
            if len(prompt) != len(pil_images):
                raise ValueError(
                    f"Number of prompts ({len(prompt)}) must match number of images ({len(pil_images)})"
                )
            user_prompts = prompt
        else:
            raise ValueError(f"prompt must be str or list[str], got {type(prompt)}")

        # Format prompts
        prompts: list[str] = [self.formulate_prompt(up) for up in user_prompts]

        # Build vLLM inputs
        llm_inputs = [
            {"prompt": p, "multi_modal_data": {"image": im}}
            for p, im in zip(prompts, pil_images)
        ]

        # Generate
        assert self.llm is not None and self.sampling_params is not None
        start_time = time.time()
        outputs = self.llm.generate(llm_inputs, sampling_params=self.sampling_params)  # type: ignore
        generation_time = time.time() - start_time

        # Optional debug
        if outputs:
            try:
                num_tokens_within_batch = len(outputs[0].outputs[0].token_ids)
                _log.debug(
                    f"Generated {num_tokens_within_batch} tokens for batch in {generation_time:.2f}s."
                )
            except Exception:
                num_tokens_within_batch = 0

        # Emit predictions
        for output in outputs:
            text = output.outputs[0].text if output.outputs else ""
            stop_reason = (
                VlmStopReason.END_OF_SEQUENCE
                if output.outputs[0].stop_reason
                else VlmStopReason.LENGTH
            )
            generated_tokens = (
                [VlmPredictionToken(token=int(t)) for t in output.outputs[0].token_ids]
                if self.vlm_options.track_generated_tokens
                else []
            )
            num_tokens = len(generated_tokens)
            decoded_text = self.vlm_options.decode_response(text)
            yield VlmPrediction(
                text=decoded_text,
                generation_time=generation_time,
                num_tokens=num_tokens,
                stop_reason=stop_reason,
                generated_tokens=generated_tokens,
            )
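
A hedged usage sketch, not part of vllm_model.py: given an already constructed and enabled VllmVlmModel, process_images yields one VlmPrediction per image, carrying the generation time, token count, and stop reason set above. Model construction and its options are elided here, and the prompt text is purely illustrative.

from PIL import Image

# `model` is assumed to be an enabled VllmVlmModel instance built elsewhere.
pages = [Image.open("page_1.png"), Image.open("page_2.png")]
prompts = ["Convert this page to DocTags."] * len(pages)

for pred in model.process_images(pages, prompts):
    print(pred.stop_reason, pred.num_tokens, pred.generation_time)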