feat(vlm): add num_tokens as attribute for VlmPrediction (#2489)

* feat: add num_tokens as attribute for VlmPrediction

* feat: implement token tracking for api_vlm

Signed-off-by: Peter El Hachem <peter.el.hachem@ibm.com>

* DCO Remediation Commit for ElHachem02 <peterelhachem02@gmail.com>

I, ElHachem02 <peterelhachem02@gmail.com>, hereby add my Signed-off-by to this commit: 311287f562

Signed-off-by: Peter El Hachem <peter.el.hachem@ibm.com>

* DCO Remediation Commit for ElHachem02 <peterelhachem02@gmail.com>

I, ElHachem02 <peterelhachem02@gmail.com>, hereby add my Signed-off-by to this commit: 311287f562

Signed-off-by: ElHachem02 <peterelhachem02@gmail.com>

* update return type

Signed-off-by: ElHachem02 <peterelhachem02@gmail.com>

* add time recorder for vlm inference and track generated token ids depending on config

Signed-off-by: ElHachem02 <peterelhachem02@gmail.com>

* set num_tokens to None on exception

Signed-off-by: ElHachem02 <peterelhachem02@gmail.com>

* set default value of num_tokens to None

Signed-off-by: ElHachem02 <peterelhachem02@gmail.com>

---------

Signed-off-by: Peter El Hachem <peter.el.hachem@ibm.com>
Signed-off-by: ElHachem02 <peterelhachem02@gmail.com>
Signed-off-by: peets <100425207+ElHachem02@users.noreply.github.com>
Co-authored-by: Peter El Hachem <peter.el.hachem@ibm.com>
peets
2025-10-28 17:18:44 +01:00
committed by GitHub
parent cdffb47b9a
commit b6c892b505
8 changed files with 71 additions and 20 deletions
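
For orientation, a minimal sketch (not part of the commit) of how a caller could read the new field after a VLM pipeline run; page.predictions.vlm_response comes from the existing docling models, while the helper function itself and the printed format are illustrative assumptions:

from docling.datamodel.document import ConversionResult


def report_token_usage(conv_res: ConversionResult) -> None:
    # Print the generated-token count per page, when the backend reported one.
    for page in conv_res.pages:
        pred = page.predictions.vlm_response
        if pred is not None and pred.num_tokens is not None:
            print(
                f"page {page.page_no}: {pred.num_tokens} tokens "
                f"in {pred.generation_time:.2f}s (stop_reason={pred.stop_reason})"
            )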

View File

@@ -207,6 +207,8 @@ class VlmPrediction(BaseModel):
     text: str = ""
     generated_tokens: list[VlmPredictionToken] = []
     generation_time: float = -1
+    num_tokens: Optional[int] = None
+    stop_reason: Optional[str] = None  # todo define an enum for possible stop reasons


 class ContainerElement(
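
Both new fields are optional and default to None, so existing constructions keep validating; a quick illustrative sketch of the intended semantics (field names taken from the hunk above, values made up):

from docling.datamodel.base_models import VlmPrediction

# Old-style construction still works; the new fields simply stay None.
legacy = VlmPrediction(text="<doctag>...</doctag>", generation_time=1.2)
assert legacy.num_tokens is None and legacy.stop_reason is None

# Backends that can report usage populate them explicitly.
tracked = VlmPrediction(
    text="<doctag>...</doctag>",
    generation_time=1.2,
    num_tokens=512,
    stop_reason="stop",
)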

View File

@@ -82,6 +82,7 @@ class InlineVlmOptions(BaseVlmOptions):
     use_kv_cache: bool = True
     max_new_tokens: int = 4096
+    track_generated_tokens: bool = False

     @property
     def repo_cache_folder(self) -> str:
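
The new flag defaults to False; a small sketch of switching it on, assuming an already-configured InlineVlmOptions instance (the remaining constructor arguments are out of scope here):

from docling.datamodel.pipeline_options_vlm_model import InlineVlmOptions


def enable_token_tracking(opts: InlineVlmOptions) -> InlineVlmOptions:
    # With tracking enabled, the vLLM backend also attaches per-token
    # VlmPredictionToken entries to every VlmPrediction it yields.
    opts.track_generated_tokens = True
    return opts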

View File

@@ -73,7 +73,7 @@ class ApiVlmModel(BasePageModel):
                    # Skip non-GenerationStopper criteria (should have been caught in validation)

                    # Streaming path with early abort support
-                    page_tags = api_image_request_streaming(
+                    page_tags, num_tokens = api_image_request_streaming(
                        image=hi_res_image,
                        prompt=prompt,
                        url=self.vlm_options.url,
@@ -84,7 +84,7 @@ class ApiVlmModel(BasePageModel):
                    )
                else:
                    # Non-streaming fallback (existing behavior)
-                    page_tags = api_image_request(
+                    page_tags, num_tokens = api_image_request(
                        image=hi_res_image,
                        prompt=prompt,
                        url=self.vlm_options.url,
@@ -94,7 +94,9 @@ class ApiVlmModel(BasePageModel):
                    )

                page_tags = self.vlm_options.decode_response(page_tags)
-                page.predictions.vlm_response = VlmPrediction(text=page_tags)
+                page.predictions.vlm_response = VlmPrediction(
+                    text=page_tags, num_tokens=num_tokens
+                )
                return page

        with ThreadPoolExecutor(max_workers=self.concurrency) as executor:
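
Both helpers now return a (text, num_tokens) tuple, so any other call site has to unpack two values; a hedged sketch of the pattern (the import path is an assumption, and num_tokens may be None when the server reports no usage):

from typing import Optional, Tuple

from PIL.Image import Image

from docling.utils.api_image_request import api_image_request  # assumed module path


def request_page_tags(image: Image, prompt: str, url: str) -> Tuple[str, Optional[int]]:
    # num_tokens is None when the response carries no usage information.
    page_tags, num_tokens = api_image_request(image=image, prompt=prompt, url=url)
    return page_tags, num_tokens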

View File

@@ -367,13 +367,19 @@ class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownload
             decoded_texts = [text.rstrip(pad_token) for text in decoded_texts]

         # -- Optional logging
+        num_tokens = None
         if generated_ids.shape[0] > 0:
+            num_tokens = int(generated_ids[0].shape[0])
             _log.debug(
-                f"Generated {int(generated_ids[0].shape[0])} tokens in {generation_time:.2f}s "
+                f"Generated {num_tokens} tokens in {generation_time:.2f}s "
                 f"for batch size {generated_ids.shape[0]}."
             )

         for text in decoded_texts:
             # Apply decode_response to the output text
             decoded_text = self.vlm_options.decode_response(text)
-            yield VlmPrediction(text=decoded_text, generation_time=generation_time)
+            yield VlmPrediction(
+                text=decoded_text,
+                generation_time=generation_time,
+                num_tokens=num_tokens,
+            )
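
The count here is simply the sequence length of the first row of generated_ids; a standalone sketch of the same computation on a dummy tensor (torch only, shapes made up):

import torch

# Mimics the (batch_size, sequence_length) tensor returned by model.generate().
generated_ids = torch.randint(low=0, high=32000, size=(2, 128))

num_tokens = None
if generated_ids.shape[0] > 0:
    # All rows share this (possibly padded) length, so the value is a
    # batch-level approximation rather than an exact per-page count.
    num_tokens = int(generated_ids[0].shape[0])

print(num_tokens)  # 128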

View File

@@ -318,5 +318,6 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
                 text=decoded_output,
                 generation_time=generation_time,
                 generated_tokens=tokens,
+                num_tokens=len(tokens),
             )
         _log.debug("MLX model: Released global lock")

View File

@@ -282,13 +282,19 @@ class NuExtractTransformersModel(BaseVlmModel, HuggingFaceModelDownloadMixin):
             )

         # Optional logging
+        num_tokens = None
         if generated_ids.shape[0] > 0:  # type: ignore
+            num_tokens = int(generated_ids[0].shape[0])
             _log.debug(
-                f"Generated {int(generated_ids[0].shape[0])} tokens in {generation_time:.2f}s "
+                f"Generated {num_tokens} tokens in {generation_time:.2f}s "
                 f"for batch size {generated_ids.shape[0]}."  # type: ignore
             )

         for text in decoded_texts:
             # Apply decode_response to the output text
             decoded_text = self.vlm_options.decode_response(text)
-            yield VlmPrediction(text=decoded_text, generation_time=generation_time)
+            yield VlmPrediction(
+                text=decoded_text,
+                generation_time=generation_time,
+                num_tokens=num_tokens,
+            )

View File

@@ -9,7 +9,7 @@ import numpy as np
 from PIL.Image import Image

 from docling.datamodel.accelerator_options import AcceleratorOptions
-from docling.datamodel.base_models import Page, VlmPrediction
+from docling.datamodel.base_models import Page, VlmPrediction, VlmPredictionToken
 from docling.datamodel.document import ConversionResult
 from docling.datamodel.pipeline_options_vlm_model import (
     InlineVlmOptions,
@@ -88,7 +88,7 @@ class VllmVlmModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
         vlm_options: InlineVlmOptions,
     ):
         self.enabled = enabled
-        self.vlm_options = vlm_options
+        self.vlm_options: InlineVlmOptions = vlm_options
         self.llm = None
         self.sampling_params = None
@@ -234,7 +234,8 @@ class VllmVlmModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
                 pages_with_images.append(page)

         if images:
-            predictions = list(self.process_images(images, user_prompts))
+            with TimeRecorder(conv_res, "vlm_inference"):
+                predictions = list(self.process_images(images, user_prompts))

             for page, prediction in zip(pages_with_images, predictions):
                 page.predictions.vlm_response = prediction
@@ -300,13 +301,34 @@ class VllmVlmModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
         # Optional debug
         if outputs:
             try:
-                num_tokens = len(outputs[0].outputs[0].token_ids)
-                _log.debug(f"Generated {num_tokens} tokens in {generation_time:.2f}s.")
+                num_tokens_within_batch = len(outputs[0].outputs[0].token_ids)
+                _log.debug(
+                    f"Generated {num_tokens_within_batch} tokens for batch in {generation_time:.2f}s."
+                )
             except Exception:
-                pass
+                num_tokens_within_batch = 0

         # Emit predictions
         for output in outputs:
             text = output.outputs[0].text if output.outputs else ""
+            stop_reason = output.outputs[0].stop_reason if output.outputs else ""
+            generated_tokens = [
+                VlmPredictionToken(token=int(p)) for p in output.outputs[0].token_ids
+            ]
+            num_tokens = len(generated_tokens)
             decoded_text = self.vlm_options.decode_response(text)
-            yield VlmPrediction(text=decoded_text, generation_time=generation_time)
+            if self.vlm_options.track_generated_tokens:
+                yield VlmPrediction(
+                    text=decoded_text,
+                    generation_time=generation_time,
+                    num_tokens=num_tokens,
+                    stop_reason=stop_reason,
+                    generated_tokens=generated_tokens,
+                )
+            else:
+                yield VlmPrediction(
+                    text=decoded_text,
+                    generation_time=generation_time,
+                    num_tokens=num_tokens,
+                    stop_reason=stop_reason,
+                )
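
On the consumer side, the per-token list is only populated when track_generated_tokens is enabled, while num_tokens and stop_reason are always set on this path; a small illustrative sketch:

from docling.datamodel.base_models import VlmPrediction


def summarize(prediction: VlmPrediction) -> str:
    if prediction.generated_tokens:
        # Present only when track_generated_tokens=True on the options.
        token_ids = [t.token for t in prediction.generated_tokens]
        return f"{len(token_ids)} tracked tokens, stop_reason={prediction.stop_reason}"
    return f"{prediction.num_tokens} tokens (ids not tracked), stop_reason={prediction.stop_reason}"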

View File

@@ -2,7 +2,7 @@ import base64
 import json
 import logging
 from io import BytesIO
-from typing import Optional
+from typing import Dict, List, Optional, Tuple

 import requests
 from PIL import Image
@@ -21,7 +21,7 @@ def api_image_request(
     timeout: float = 20,
     headers: Optional[dict[str, str]] = None,
     **params,
-) -> str:
+) -> Tuple[str, Optional[int]]:
     img_io = BytesIO()
     image.save(img_io, "PNG")
     image_base64 = base64.b64encode(img_io.getvalue()).decode("utf-8")
@@ -60,7 +60,8 @@ def api_image_request(
     api_resp = OpenAiApiResponse.model_validate_json(r.text)
     generated_text = api_resp.choices[0].message.content.strip()
-    return generated_text
+    num_tokens = api_resp.usage.total_tokens
+    return generated_text, num_tokens


 def api_image_request_streaming(
@@ -72,7 +73,7 @@ def api_image_request_streaming(
     headers: Optional[dict[str, str]] = None,
     generation_stoppers: list[GenerationStopper] = [],
     **params,
-) -> str:
+) -> Tuple[str, Optional[int]]:
     """
     Stream a chat completion from an OpenAI-compatible server (e.g., vLLM).
     Parses SSE lines: 'data: {json}\\n\\n', terminated by 'data: [DONE]'.
@@ -150,6 +151,16 @@ def api_image_request_streaming(
_log.debug("Unexpected SSE chunk shape: %s", e) _log.debug("Unexpected SSE chunk shape: %s", e)
piece = "" piece = ""
# Try to extract token count
num_tokens = None
try:
if "usage" in obj:
usage = obj["usage"]
num_tokens = usage.get("total_tokens")
except Exception as e:
num_tokens = None
_log.debug("Usage key not included in response: %s", e)
if piece: if piece:
full_text.append(piece) full_text.append(piece)
for stopper in generation_stoppers: for stopper in generation_stoppers:
@@ -162,6 +173,6 @@ def api_image_request_streaming(
                         # closing the connection when we exit the 'with' block.
                         # vLLM/OpenAI-compatible servers will detect the client disconnect
                         # and abort the request server-side.
-                        return "".join(full_text)
+                        return "".join(full_text), num_tokens

-    return "".join(full_text)
+    return "".join(full_text), num_tokens