mirror of
https://github.com/DS4SD/docling.git
synced 2025-12-08 12:48:28 +00:00
feat(vlm): add num_tokens as attribtue for VlmPrediction (#2489)
* feat: add num_tokens as attribtue for VlmPrediction * feat: implement tokens tracking for api_vlm Signed-off-by: Peter El Hachem <peter.el.hachem@ibm.com> * DCO Remediation Commit for ElHachem02 <peterelhachem02@gmail.com> I, ElHachem02 <peterelhachem02@gmail.com>, hereby add my Signed-off-by to this commit:311287f562Signed-off-by: Peter El Hachem <peter.el.hachem@ibm.com> * DCO Remediation Commit for ElHachem02 <peterelhachem02@gmail.com> I, ElHachem02 <peterelhachem02@gmail.com>, hereby add my Signed-off-by to this commit:311287f562Signed-off-by: ElHachem02 <peterelhachem02@gmail.com> * update return type Signed-off-by: ElHachem02 <peterelhachem02@gmail.com> * add time recorder for vlm inference and track generated token ids depending on config Signed-off-by: ElHachem02 <peterelhachem02@gmail.com> * update num_tokens to have None as value on exception Signed-off-by: ElHachem02 <peterelhachem02@gmail.com> * set default value of num_tokens to None Signed-off-by: ElHachem02 <peterelhachem02@gmail.com> --------- Signed-off-by: Peter El Hachem <peter.el.hachem@ibm.com> Signed-off-by: ElHachem02 <peterelhachem02@gmail.com> Signed-off-by: peets <100425207+ElHachem02@users.noreply.github.com> Co-authored-by: Peter El Hachem <peter.el.hachem@ibm.com>
This commit is contained in:
@@ -207,6 +207,8 @@ class VlmPrediction(BaseModel):
|
||||
text: str = ""
|
||||
generated_tokens: list[VlmPredictionToken] = []
|
||||
generation_time: float = -1
|
||||
num_tokens: Optional[int] = None
|
||||
stop_reason: Optional[str] = None # todo define an enum for possible stop reasons
|
||||
|
||||
|
||||
class ContainerElement(
|
||||
|
||||
@@ -82,6 +82,7 @@ class InlineVlmOptions(BaseVlmOptions):
|
||||
|
||||
use_kv_cache: bool = True
|
||||
max_new_tokens: int = 4096
|
||||
track_generated_tokens: bool = False
|
||||
|
||||
@property
|
||||
def repo_cache_folder(self) -> str:
|
||||
|
||||
@@ -73,7 +73,7 @@ class ApiVlmModel(BasePageModel):
|
||||
# Skip non-GenerationStopper criteria (should have been caught in validation)
|
||||
|
||||
# Streaming path with early abort support
|
||||
page_tags = api_image_request_streaming(
|
||||
page_tags, num_tokens = api_image_request_streaming(
|
||||
image=hi_res_image,
|
||||
prompt=prompt,
|
||||
url=self.vlm_options.url,
|
||||
@@ -84,7 +84,7 @@ class ApiVlmModel(BasePageModel):
|
||||
)
|
||||
else:
|
||||
# Non-streaming fallback (existing behavior)
|
||||
page_tags = api_image_request(
|
||||
page_tags, num_tokens = api_image_request(
|
||||
image=hi_res_image,
|
||||
prompt=prompt,
|
||||
url=self.vlm_options.url,
|
||||
@@ -94,7 +94,9 @@ class ApiVlmModel(BasePageModel):
|
||||
)
|
||||
|
||||
page_tags = self.vlm_options.decode_response(page_tags)
|
||||
page.predictions.vlm_response = VlmPrediction(text=page_tags)
|
||||
page.predictions.vlm_response = VlmPrediction(
|
||||
text=page_tags, num_tokens=num_tokens
|
||||
)
|
||||
return page
|
||||
|
||||
with ThreadPoolExecutor(max_workers=self.concurrency) as executor:
|
||||
|
||||
@@ -367,13 +367,19 @@ class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownload
|
||||
decoded_texts = [text.rstrip(pad_token) for text in decoded_texts]
|
||||
|
||||
# -- Optional logging
|
||||
num_tokens = None
|
||||
if generated_ids.shape[0] > 0:
|
||||
num_tokens = int(generated_ids[0].shape[0])
|
||||
_log.debug(
|
||||
f"Generated {int(generated_ids[0].shape[0])} tokens in {generation_time:.2f}s "
|
||||
f"Generated {num_tokens} tokens in {generation_time:.2f}s "
|
||||
f"for batch size {generated_ids.shape[0]}."
|
||||
)
|
||||
|
||||
for text in decoded_texts:
|
||||
# Apply decode_response to the output text
|
||||
decoded_text = self.vlm_options.decode_response(text)
|
||||
yield VlmPrediction(text=decoded_text, generation_time=generation_time)
|
||||
yield VlmPrediction(
|
||||
text=decoded_text,
|
||||
generation_time=generation_time,
|
||||
num_tokens=num_tokens,
|
||||
)
|
||||
|
||||
@@ -318,5 +318,6 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
|
||||
text=decoded_output,
|
||||
generation_time=generation_time,
|
||||
generated_tokens=tokens,
|
||||
num_tokens=len(tokens),
|
||||
)
|
||||
_log.debug("MLX model: Released global lock")
|
||||
|
||||
@@ -282,13 +282,19 @@ class NuExtractTransformersModel(BaseVlmModel, HuggingFaceModelDownloadMixin):
|
||||
)
|
||||
|
||||
# Optional logging
|
||||
num_tokens = None
|
||||
if generated_ids.shape[0] > 0: # type: ignore
|
||||
num_tokens = int(generated_ids[0].shape[0])
|
||||
_log.debug(
|
||||
f"Generated {int(generated_ids[0].shape[0])} tokens in {generation_time:.2f}s "
|
||||
f"Generated {num_tokens} tokens in {generation_time:.2f}s "
|
||||
f"for batch size {generated_ids.shape[0]}." # type: ignore
|
||||
)
|
||||
|
||||
for text in decoded_texts:
|
||||
# Apply decode_response to the output text
|
||||
decoded_text = self.vlm_options.decode_response(text)
|
||||
yield VlmPrediction(text=decoded_text, generation_time=generation_time)
|
||||
yield VlmPrediction(
|
||||
text=decoded_text,
|
||||
generation_time=generation_time,
|
||||
num_tokens=num_tokens,
|
||||
)
|
||||
|
||||
@@ -9,7 +9,7 @@ import numpy as np
|
||||
from PIL.Image import Image
|
||||
|
||||
from docling.datamodel.accelerator_options import AcceleratorOptions
|
||||
from docling.datamodel.base_models import Page, VlmPrediction
|
||||
from docling.datamodel.base_models import Page, VlmPrediction, VlmPredictionToken
|
||||
from docling.datamodel.document import ConversionResult
|
||||
from docling.datamodel.pipeline_options_vlm_model import (
|
||||
InlineVlmOptions,
|
||||
@@ -88,7 +88,7 @@ class VllmVlmModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
|
||||
vlm_options: InlineVlmOptions,
|
||||
):
|
||||
self.enabled = enabled
|
||||
self.vlm_options = vlm_options
|
||||
self.vlm_options: InlineVlmOptions = vlm_options
|
||||
|
||||
self.llm = None
|
||||
self.sampling_params = None
|
||||
@@ -234,7 +234,8 @@ class VllmVlmModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
|
||||
pages_with_images.append(page)
|
||||
|
||||
if images:
|
||||
predictions = list(self.process_images(images, user_prompts))
|
||||
with TimeRecorder(conv_res, "vlm_inference"):
|
||||
predictions = list(self.process_images(images, user_prompts))
|
||||
for page, prediction in zip(pages_with_images, predictions):
|
||||
page.predictions.vlm_response = prediction
|
||||
|
||||
@@ -300,13 +301,34 @@ class VllmVlmModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
|
||||
# Optional debug
|
||||
if outputs:
|
||||
try:
|
||||
num_tokens = len(outputs[0].outputs[0].token_ids)
|
||||
_log.debug(f"Generated {num_tokens} tokens in {generation_time:.2f}s.")
|
||||
num_tokens_within_batch = len(outputs[0].outputs[0].token_ids)
|
||||
_log.debug(
|
||||
f"Generated {num_tokens_within_batch} tokens for batch in {generation_time:.2f}s."
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
num_tokens_within_batch = 0
|
||||
|
||||
# Emit predictions
|
||||
for output in outputs:
|
||||
text = output.outputs[0].text if output.outputs else ""
|
||||
stop_reason = output.outputs[0].stop_reason if output.outputs else ""
|
||||
generated_tokens = [
|
||||
VlmPredictionToken(token=int(p)) for p in output.outputs[0].token_ids
|
||||
]
|
||||
num_tokens = len(generated_tokens)
|
||||
decoded_text = self.vlm_options.decode_response(text)
|
||||
yield VlmPrediction(text=decoded_text, generation_time=generation_time)
|
||||
if self.vlm_options.track_generated_tokens:
|
||||
yield VlmPrediction(
|
||||
text=decoded_text,
|
||||
generation_time=generation_time,
|
||||
num_tokens=num_tokens,
|
||||
stop_reason=stop_reason,
|
||||
generated_tokens=generated_tokens,
|
||||
)
|
||||
else:
|
||||
yield VlmPrediction(
|
||||
text=decoded_text,
|
||||
generation_time=generation_time,
|
||||
num_tokens=num_tokens,
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
|
||||
@@ -2,7 +2,7 @@ import base64
|
||||
import json
|
||||
import logging
|
||||
from io import BytesIO
|
||||
from typing import Optional
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import requests
|
||||
from PIL import Image
|
||||
@@ -21,7 +21,7 @@ def api_image_request(
|
||||
timeout: float = 20,
|
||||
headers: Optional[dict[str, str]] = None,
|
||||
**params,
|
||||
) -> str:
|
||||
) -> Tuple[str, Optional[int]]:
|
||||
img_io = BytesIO()
|
||||
image.save(img_io, "PNG")
|
||||
image_base64 = base64.b64encode(img_io.getvalue()).decode("utf-8")
|
||||
@@ -60,7 +60,8 @@ def api_image_request(
|
||||
|
||||
api_resp = OpenAiApiResponse.model_validate_json(r.text)
|
||||
generated_text = api_resp.choices[0].message.content.strip()
|
||||
return generated_text
|
||||
num_tokens = api_resp.usage.total_tokens
|
||||
return generated_text, num_tokens
|
||||
|
||||
|
||||
def api_image_request_streaming(
|
||||
@@ -72,7 +73,7 @@ def api_image_request_streaming(
|
||||
headers: Optional[dict[str, str]] = None,
|
||||
generation_stoppers: list[GenerationStopper] = [],
|
||||
**params,
|
||||
) -> str:
|
||||
) -> Tuple[str, Optional[int]]:
|
||||
"""
|
||||
Stream a chat completion from an OpenAI-compatible server (e.g., vLLM).
|
||||
Parses SSE lines: 'data: {json}\\n\\n', terminated by 'data: [DONE]'.
|
||||
@@ -150,6 +151,16 @@ def api_image_request_streaming(
|
||||
_log.debug("Unexpected SSE chunk shape: %s", e)
|
||||
piece = ""
|
||||
|
||||
# Try to extract token count
|
||||
num_tokens = None
|
||||
try:
|
||||
if "usage" in obj:
|
||||
usage = obj["usage"]
|
||||
num_tokens = usage.get("total_tokens")
|
||||
except Exception as e:
|
||||
num_tokens = None
|
||||
_log.debug("Usage key not included in response: %s", e)
|
||||
|
||||
if piece:
|
||||
full_text.append(piece)
|
||||
for stopper in generation_stoppers:
|
||||
@@ -162,6 +173,6 @@ def api_image_request_streaming(
|
||||
# closing the connection when we exit the 'with' block.
|
||||
# vLLM/OpenAI-compatible servers will detect the client disconnect
|
||||
# and abort the request server-side.
|
||||
return "".join(full_text)
|
||||
return "".join(full_text), num_tokens
|
||||
|
||||
return "".join(full_text)
|
||||
return "".join(full_text), num_tokens
|
||||
|
||||
Reference in New Issue
Block a user