feat(vlm): add num_tokens as attribtue for VlmPrediction (#2489)

* feat: add num_tokens as attribtue for VlmPrediction

* feat: implement tokens tracking for api_vlm

Signed-off-by: Peter El Hachem <peter.el.hachem@ibm.com>

* DCO Remediation Commit for ElHachem02 <peterelhachem02@gmail.com>

I, ElHachem02 <peterelhachem02@gmail.com>, hereby add my Signed-off-by to this commit: 311287f562

Signed-off-by: Peter El Hachem <peter.el.hachem@ibm.com>

* DCO Remediation Commit for ElHachem02 <peterelhachem02@gmail.com>

I, ElHachem02 <peterelhachem02@gmail.com>, hereby add my Signed-off-by to this commit: 311287f562

Signed-off-by: ElHachem02 <peterelhachem02@gmail.com>

* update return type

Signed-off-by: ElHachem02 <peterelhachem02@gmail.com>

* add time recorder for vlm inference and track generated token ids depending on config

Signed-off-by: ElHachem02 <peterelhachem02@gmail.com>

* update num_tokens to have None as value on exception

Signed-off-by: ElHachem02 <peterelhachem02@gmail.com>

* set default value of num_tokens to None

Signed-off-by: ElHachem02 <peterelhachem02@gmail.com>

---------

Signed-off-by: Peter El Hachem <peter.el.hachem@ibm.com>
Signed-off-by: ElHachem02 <peterelhachem02@gmail.com>
Signed-off-by: peets <100425207+ElHachem02@users.noreply.github.com>
Co-authored-by: Peter El Hachem <peter.el.hachem@ibm.com>
This commit is contained in:
peets
2025-10-28 17:18:44 +01:00
committed by GitHub
parent cdffb47b9a
commit b6c892b505
8 changed files with 71 additions and 20 deletions

View File

@@ -207,6 +207,8 @@ class VlmPrediction(BaseModel):
text: str = ""
generated_tokens: list[VlmPredictionToken] = []
generation_time: float = -1
num_tokens: Optional[int] = None
stop_reason: Optional[str] = None # todo define an enum for possible stop reasons
class ContainerElement(

View File

@@ -82,6 +82,7 @@ class InlineVlmOptions(BaseVlmOptions):
use_kv_cache: bool = True
max_new_tokens: int = 4096
track_generated_tokens: bool = False
@property
def repo_cache_folder(self) -> str:

View File

@@ -73,7 +73,7 @@ class ApiVlmModel(BasePageModel):
# Skip non-GenerationStopper criteria (should have been caught in validation)
# Streaming path with early abort support
page_tags = api_image_request_streaming(
page_tags, num_tokens = api_image_request_streaming(
image=hi_res_image,
prompt=prompt,
url=self.vlm_options.url,
@@ -84,7 +84,7 @@ class ApiVlmModel(BasePageModel):
)
else:
# Non-streaming fallback (existing behavior)
page_tags = api_image_request(
page_tags, num_tokens = api_image_request(
image=hi_res_image,
prompt=prompt,
url=self.vlm_options.url,
@@ -94,7 +94,9 @@ class ApiVlmModel(BasePageModel):
)
page_tags = self.vlm_options.decode_response(page_tags)
page.predictions.vlm_response = VlmPrediction(text=page_tags)
page.predictions.vlm_response = VlmPrediction(
text=page_tags, num_tokens=num_tokens
)
return page
with ThreadPoolExecutor(max_workers=self.concurrency) as executor:

View File

@@ -367,13 +367,19 @@ class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownload
decoded_texts = [text.rstrip(pad_token) for text in decoded_texts]
# -- Optional logging
num_tokens = None
if generated_ids.shape[0] > 0:
num_tokens = int(generated_ids[0].shape[0])
_log.debug(
f"Generated {int(generated_ids[0].shape[0])} tokens in {generation_time:.2f}s "
f"Generated {num_tokens} tokens in {generation_time:.2f}s "
f"for batch size {generated_ids.shape[0]}."
)
for text in decoded_texts:
# Apply decode_response to the output text
decoded_text = self.vlm_options.decode_response(text)
yield VlmPrediction(text=decoded_text, generation_time=generation_time)
yield VlmPrediction(
text=decoded_text,
generation_time=generation_time,
num_tokens=num_tokens,
)

View File

@@ -318,5 +318,6 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
text=decoded_output,
generation_time=generation_time,
generated_tokens=tokens,
num_tokens=len(tokens),
)
_log.debug("MLX model: Released global lock")

View File

@@ -282,13 +282,19 @@ class NuExtractTransformersModel(BaseVlmModel, HuggingFaceModelDownloadMixin):
)
# Optional logging
num_tokens = None
if generated_ids.shape[0] > 0: # type: ignore
num_tokens = int(generated_ids[0].shape[0])
_log.debug(
f"Generated {int(generated_ids[0].shape[0])} tokens in {generation_time:.2f}s "
f"Generated {num_tokens} tokens in {generation_time:.2f}s "
f"for batch size {generated_ids.shape[0]}." # type: ignore
)
for text in decoded_texts:
# Apply decode_response to the output text
decoded_text = self.vlm_options.decode_response(text)
yield VlmPrediction(text=decoded_text, generation_time=generation_time)
yield VlmPrediction(
text=decoded_text,
generation_time=generation_time,
num_tokens=num_tokens,
)

View File

@@ -9,7 +9,7 @@ import numpy as np
from PIL.Image import Image
from docling.datamodel.accelerator_options import AcceleratorOptions
from docling.datamodel.base_models import Page, VlmPrediction
from docling.datamodel.base_models import Page, VlmPrediction, VlmPredictionToken
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options_vlm_model import (
InlineVlmOptions,
@@ -88,7 +88,7 @@ class VllmVlmModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
vlm_options: InlineVlmOptions,
):
self.enabled = enabled
self.vlm_options = vlm_options
self.vlm_options: InlineVlmOptions = vlm_options
self.llm = None
self.sampling_params = None
@@ -234,7 +234,8 @@ class VllmVlmModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
pages_with_images.append(page)
if images:
predictions = list(self.process_images(images, user_prompts))
with TimeRecorder(conv_res, "vlm_inference"):
predictions = list(self.process_images(images, user_prompts))
for page, prediction in zip(pages_with_images, predictions):
page.predictions.vlm_response = prediction
@@ -300,13 +301,34 @@ class VllmVlmModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
# Optional debug
if outputs:
try:
num_tokens = len(outputs[0].outputs[0].token_ids)
_log.debug(f"Generated {num_tokens} tokens in {generation_time:.2f}s.")
num_tokens_within_batch = len(outputs[0].outputs[0].token_ids)
_log.debug(
f"Generated {num_tokens_within_batch} tokens for batch in {generation_time:.2f}s."
)
except Exception:
pass
num_tokens_within_batch = 0
# Emit predictions
for output in outputs:
text = output.outputs[0].text if output.outputs else ""
stop_reason = output.outputs[0].stop_reason if output.outputs else ""
generated_tokens = [
VlmPredictionToken(token=int(p)) for p in output.outputs[0].token_ids
]
num_tokens = len(generated_tokens)
decoded_text = self.vlm_options.decode_response(text)
yield VlmPrediction(text=decoded_text, generation_time=generation_time)
if self.vlm_options.track_generated_tokens:
yield VlmPrediction(
text=decoded_text,
generation_time=generation_time,
num_tokens=num_tokens,
stop_reason=stop_reason,
generated_tokens=generated_tokens,
)
else:
yield VlmPrediction(
text=decoded_text,
generation_time=generation_time,
num_tokens=num_tokens,
stop_reason=stop_reason,
)

View File

@@ -2,7 +2,7 @@ import base64
import json
import logging
from io import BytesIO
from typing import Optional
from typing import Dict, List, Optional, Tuple
import requests
from PIL import Image
@@ -21,7 +21,7 @@ def api_image_request(
timeout: float = 20,
headers: Optional[dict[str, str]] = None,
**params,
) -> str:
) -> Tuple[str, Optional[int]]:
img_io = BytesIO()
image.save(img_io, "PNG")
image_base64 = base64.b64encode(img_io.getvalue()).decode("utf-8")
@@ -60,7 +60,8 @@ def api_image_request(
api_resp = OpenAiApiResponse.model_validate_json(r.text)
generated_text = api_resp.choices[0].message.content.strip()
return generated_text
num_tokens = api_resp.usage.total_tokens
return generated_text, num_tokens
def api_image_request_streaming(
@@ -72,7 +73,7 @@ def api_image_request_streaming(
headers: Optional[dict[str, str]] = None,
generation_stoppers: list[GenerationStopper] = [],
**params,
) -> str:
) -> Tuple[str, Optional[int]]:
"""
Stream a chat completion from an OpenAI-compatible server (e.g., vLLM).
Parses SSE lines: 'data: {json}\\n\\n', terminated by 'data: [DONE]'.
@@ -150,6 +151,16 @@ def api_image_request_streaming(
_log.debug("Unexpected SSE chunk shape: %s", e)
piece = ""
# Try to extract token count
num_tokens = None
try:
if "usage" in obj:
usage = obj["usage"]
num_tokens = usage.get("total_tokens")
except Exception as e:
num_tokens = None
_log.debug("Usage key not included in response: %s", e)
if piece:
full_text.append(piece)
for stopper in generation_stoppers:
@@ -162,6 +173,6 @@ def api_image_request_streaming(
# closing the connection when we exit the 'with' block.
# vLLM/OpenAI-compatible servers will detect the client disconnect
# and abort the request server-side.
return "".join(full_text)
return "".join(full_text), num_tokens
return "".join(full_text)
return "".join(full_text), num_tokens