Mirror of https://github.com/DS4SD/docling.git, synced 2025-12-08 12:48:28 +00:00
feat(vlm): add num_tokens as attribute for VlmPrediction (#2489)
* feat: add num_tokens as attribute for VlmPrediction
* feat: implement tokens tracking for api_vlm

  Signed-off-by: Peter El Hachem <peter.el.hachem@ibm.com>

* DCO Remediation Commit for ElHachem02 <peterelhachem02@gmail.com>

  I, ElHachem02 <peterelhachem02@gmail.com>, hereby add my Signed-off-by to this commit: 311287f562

  Signed-off-by: Peter El Hachem <peter.el.hachem@ibm.com>

* DCO Remediation Commit for ElHachem02 <peterelhachem02@gmail.com>

  I, ElHachem02 <peterelhachem02@gmail.com>, hereby add my Signed-off-by to this commit: 311287f562

  Signed-off-by: ElHachem02 <peterelhachem02@gmail.com>

* update return type

  Signed-off-by: ElHachem02 <peterelhachem02@gmail.com>

* add time recorder for vlm inference and track generated token ids depending on config

  Signed-off-by: ElHachem02 <peterelhachem02@gmail.com>

* update num_tokens to have None as value on exception

  Signed-off-by: ElHachem02 <peterelhachem02@gmail.com>

* set default value of num_tokens to None

  Signed-off-by: ElHachem02 <peterelhachem02@gmail.com>

---------

Signed-off-by: Peter El Hachem <peter.el.hachem@ibm.com>
Signed-off-by: ElHachem02 <peterelhachem02@gmail.com>
Signed-off-by: peets <100425207+ElHachem02@users.noreply.github.com>
Co-authored-by: Peter El Hachem <peter.el.hachem@ibm.com>
@@ -207,6 +207,8 @@ class VlmPrediction(BaseModel):
     text: str = ""
     generated_tokens: list[VlmPredictionToken] = []
     generation_time: float = -1
+    num_tokens: Optional[int] = None
+    stop_reason: Optional[str] = None  # todo define an enum for possible stop reasons


 class ContainerElement(
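The two new fields are plain pydantic attributes with safe defaults, so downstream code can aggregate them directly. Below is a minimal sketch (not part of the commit) of how a caller might total token counts across page predictions; it assumes a docling build that already contains this change and relies only on the fields shown in the hunk above.

from typing import Optional

# Assumes a docling version that includes this commit (num_tokens added to VlmPrediction).
from docling.datamodel.base_models import VlmPrediction


def summarize_token_usage(
    predictions: list[VlmPrediction],
) -> tuple[int, Optional[float]]:
    """Return (total generated tokens, tokens per second) over a set of page
    predictions; num_tokens may be None when the backend did not report it."""
    total_tokens = 0
    timed_tokens = 0
    timed_seconds = 0.0
    for pred in predictions:
        if pred.num_tokens is None:
            continue
        total_tokens += pred.num_tokens
        if pred.generation_time > 0:  # generation_time defaults to -1 when not measured
            timed_tokens += pred.num_tokens
            timed_seconds += pred.generation_time
    rate = timed_tokens / timed_seconds if timed_seconds > 0 else None
    return total_tokens, rate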
@@ -82,6 +82,7 @@ class InlineVlmOptions(BaseVlmOptions):

     use_kv_cache: bool = True
     max_new_tokens: int = 4096
+    track_generated_tokens: bool = False

     @property
     def repo_cache_folder(self) -> str:
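The new track_generated_tokens switch defaults to False, so existing configurations keep their current behavior. A hedged sketch of turning it on by copying an existing InlineVlmOptions object; the preset name is an assumption (it ships with recent docling releases) and is not part of this commit.

# Sketch only: SMOLDOCLING_TRANSFORMERS is assumed to be an InlineVlmOptions preset
# in docling.datamodel.vlm_model_specs; any other InlineVlmOptions instance works.
from docling.datamodel import vlm_model_specs

tracked_options = vlm_model_specs.SMOLDOCLING_TRANSFORMERS.model_copy(
    update={"track_generated_tokens": True}  # flag added in the hunk above (used by the vLLM path)
)
print(tracked_options.track_generated_tokens, tracked_options.max_new_tokens)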
@@ -73,7 +73,7 @@ class ApiVlmModel(BasePageModel):
                     # Skip non-GenerationStopper criteria (should have been caught in validation)

                     # Streaming path with early abort support
-                    page_tags = api_image_request_streaming(
+                    page_tags, num_tokens = api_image_request_streaming(
                         image=hi_res_image,
                         prompt=prompt,
                         url=self.vlm_options.url,
@@ -84,7 +84,7 @@ class ApiVlmModel(BasePageModel):
                     )
                 else:
                     # Non-streaming fallback (existing behavior)
-                    page_tags = api_image_request(
+                    page_tags, num_tokens = api_image_request(
                         image=hi_res_image,
                         prompt=prompt,
                         url=self.vlm_options.url,
@@ -94,7 +94,9 @@ class ApiVlmModel(BasePageModel):
                     )

                 page_tags = self.vlm_options.decode_response(page_tags)
-                page.predictions.vlm_response = VlmPrediction(text=page_tags)
+                page.predictions.vlm_response = VlmPrediction(
+                    text=page_tags, num_tokens=num_tokens
+                )
             return page

         with ThreadPoolExecutor(max_workers=self.concurrency) as executor:
@@ -367,13 +367,19 @@ class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
             decoded_texts = [text.rstrip(pad_token) for text in decoded_texts]

         # -- Optional logging
+        num_tokens = None
         if generated_ids.shape[0] > 0:
+            num_tokens = int(generated_ids[0].shape[0])
             _log.debug(
-                f"Generated {int(generated_ids[0].shape[0])} tokens in {generation_time:.2f}s "
+                f"Generated {num_tokens} tokens in {generation_time:.2f}s "
                 f"for batch size {generated_ids.shape[0]}."
             )

         for text in decoded_texts:
             # Apply decode_response to the output text
             decoded_text = self.vlm_options.decode_response(text)
-            yield VlmPrediction(text=decoded_text, generation_time=generation_time)
+            yield VlmPrediction(
+                text=decoded_text,
+                generation_time=generation_time,
+                num_tokens=num_tokens,
+            )
@@ -318,5 +318,6 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
                     text=decoded_output,
                     generation_time=generation_time,
                     generated_tokens=tokens,
+                    num_tokens=len(tokens),
                 )
             _log.debug("MLX model: Released global lock")
@@ -282,13 +282,19 @@ class NuExtractTransformersModel(BaseVlmModel, HuggingFaceModelDownloadMixin):
         )

         # Optional logging
+        num_tokens = None
         if generated_ids.shape[0] > 0:  # type: ignore
+            num_tokens = int(generated_ids[0].shape[0])
             _log.debug(
-                f"Generated {int(generated_ids[0].shape[0])} tokens in {generation_time:.2f}s "
+                f"Generated {num_tokens} tokens in {generation_time:.2f}s "
                 f"for batch size {generated_ids.shape[0]}."  # type: ignore
             )

         for text in decoded_texts:
             # Apply decode_response to the output text
             decoded_text = self.vlm_options.decode_response(text)
-            yield VlmPrediction(text=decoded_text, generation_time=generation_time)
+            yield VlmPrediction(
+                text=decoded_text,
+                generation_time=generation_time,
+                num_tokens=num_tokens,
+            )
@@ -9,7 +9,7 @@ import numpy as np
 from PIL.Image import Image

 from docling.datamodel.accelerator_options import AcceleratorOptions
-from docling.datamodel.base_models import Page, VlmPrediction
+from docling.datamodel.base_models import Page, VlmPrediction, VlmPredictionToken
 from docling.datamodel.document import ConversionResult
 from docling.datamodel.pipeline_options_vlm_model import (
     InlineVlmOptions,
@@ -88,7 +88,7 @@ class VllmVlmModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
         vlm_options: InlineVlmOptions,
     ):
         self.enabled = enabled
-        self.vlm_options = vlm_options
+        self.vlm_options: InlineVlmOptions = vlm_options

         self.llm = None
         self.sampling_params = None
@@ -234,6 +234,7 @@ class VllmVlmModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
                 pages_with_images.append(page)

         if images:
-            predictions = list(self.process_images(images, user_prompts))
+            with TimeRecorder(conv_res, "vlm_inference"):
+                predictions = list(self.process_images(images, user_prompts))
             for page, prediction in zip(pages_with_images, predictions):
                 page.predictions.vlm_response = prediction
@@ -300,13 +301,34 @@ class VllmVlmModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
         # Optional debug
         if outputs:
             try:
-                num_tokens = len(outputs[0].outputs[0].token_ids)
-                _log.debug(f"Generated {num_tokens} tokens in {generation_time:.2f}s.")
+                num_tokens_within_batch = len(outputs[0].outputs[0].token_ids)
+                _log.debug(
+                    f"Generated {num_tokens_within_batch} tokens for batch in {generation_time:.2f}s."
+                )
             except Exception:
-                pass
+                num_tokens_within_batch = 0

         # Emit predictions
         for output in outputs:
             text = output.outputs[0].text if output.outputs else ""
+            stop_reason = output.outputs[0].stop_reason if output.outputs else ""
+            generated_tokens = [
+                VlmPredictionToken(token=int(p)) for p in output.outputs[0].token_ids
+            ]
+            num_tokens = len(generated_tokens)
             decoded_text = self.vlm_options.decode_response(text)
-            yield VlmPrediction(text=decoded_text, generation_time=generation_time)
+            if self.vlm_options.track_generated_tokens:
+                yield VlmPrediction(
+                    text=decoded_text,
+                    generation_time=generation_time,
+                    num_tokens=num_tokens,
+                    stop_reason=stop_reason,
+                    generated_tokens=generated_tokens,
+                )
+            else:
+                yield VlmPrediction(
+                    text=decoded_text,
+                    generation_time=generation_time,
+                    num_tokens=num_tokens,
+                    stop_reason=stop_reason,
+                )
@@ -2,7 +2,7 @@ import base64
 import json
 import logging
 from io import BytesIO
-from typing import Optional
+from typing import Dict, List, Optional, Tuple

 import requests
 from PIL import Image
@@ -21,7 +21,7 @@ def api_image_request(
     timeout: float = 20,
     headers: Optional[dict[str, str]] = None,
     **params,
-) -> str:
+) -> Tuple[str, Optional[int]]:
     img_io = BytesIO()
     image.save(img_io, "PNG")
     image_base64 = base64.b64encode(img_io.getvalue()).decode("utf-8")
@@ -60,7 +60,8 @@ def api_image_request(

     api_resp = OpenAiApiResponse.model_validate_json(r.text)
     generated_text = api_resp.choices[0].message.content.strip()
-    return generated_text
+    num_tokens = api_resp.usage.total_tokens
+    return generated_text, num_tokens


 def api_image_request_streaming(
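In the non-streaming path the token count now comes from the usage block of an OpenAI-compatible chat-completions response (api_resp.usage.total_tokens in the hunk above). For reference, a small self-contained illustration of that payload shape and the same lookup; the JSON below is fabricated sample data, not output from any real server.

import json

# Fabricated example of an OpenAI-compatible /v1/chat/completions response body.
sample_body = json.dumps(
    {
        "choices": [
            {"message": {"role": "assistant", "content": " <doctag>...</doctag> "}}
        ],
        "usage": {"prompt_tokens": 812, "completion_tokens": 431, "total_tokens": 1243},
    }
)

payload = json.loads(sample_body)
generated_text = payload["choices"][0]["message"]["content"].strip()
num_tokens = payload.get("usage", {}).get("total_tokens")  # same field the diff reads
print(generated_text, num_tokens)  # <doctag>...</doctag> 1243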
@@ -72,7 +73,7 @@ def api_image_request_streaming(
     headers: Optional[dict[str, str]] = None,
     generation_stoppers: list[GenerationStopper] = [],
     **params,
-) -> str:
+) -> Tuple[str, Optional[int]]:
     """
     Stream a chat completion from an OpenAI-compatible server (e.g., vLLM).
     Parses SSE lines: 'data: {json}\\n\\n', terminated by 'data: [DONE]'.
@@ -150,6 +151,16 @@ def api_image_request_streaming(
                     _log.debug("Unexpected SSE chunk shape: %s", e)
                     piece = ""

+                # Try to extract token count
+                num_tokens = None
+                try:
+                    if "usage" in obj:
+                        usage = obj["usage"]
+                        num_tokens = usage.get("total_tokens")
+                except Exception as e:
+                    num_tokens = None
+                    _log.debug("Usage key not included in response: %s", e)
+
                 if piece:
                     full_text.append(piece)
                     for stopper in generation_stoppers:
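The streaming path checks each SSE chunk for a usage object and keeps the last total_tokens it sees; servers following the OpenAI streaming format typically attach usage only to the final data chunk (for the OpenAI API itself this usually requires enabling include_usage via stream_options). A minimal, self-contained sketch of that parsing loop over fabricated chunks:

import json

# Fabricated SSE lines in the OpenAI-compatible streaming format.
sse_lines = [
    'data: {"choices": [{"delta": {"content": "<doctag>"}}]}',
    'data: {"choices": [{"delta": {"content": "..."}}]}',
    'data: {"choices": [{"delta": {"content": "</doctag>"}}], "usage": {"total_tokens": 1243}}',
    "data: [DONE]",
]

full_text: list[str] = []
num_tokens = None
for line in sse_lines:
    data = line[len("data: "):]
    if data == "[DONE]":
        break
    obj = json.loads(data)
    piece = obj["choices"][0].get("delta", {}).get("content") or ""
    if "usage" in obj and obj["usage"]:
        num_tokens = obj["usage"].get("total_tokens")  # same lookup as the diff
    if piece:
        full_text.append(piece)

print("".join(full_text), num_tokens)  # <doctag>...</doctag> 1243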
@@ -162,6 +173,6 @@ def api_image_request_streaming(
                         # closing the connection when we exit the 'with' block.
                         # vLLM/OpenAI-compatible servers will detect the client disconnect
                         # and abort the request server-side.
-                        return "".join(full_text)
+                        return "".join(full_text), num_tokens

-    return "".join(full_text)
+    return "".join(full_text), num_tokens