diff --git a/docling/datamodel/base_models.py b/docling/datamodel/base_models.py
index 2b206314..ae2f48cc 100644
--- a/docling/datamodel/base_models.py
+++ b/docling/datamodel/base_models.py
@@ -207,6 +207,8 @@ class VlmPrediction(BaseModel):
     text: str = ""
     generated_tokens: list[VlmPredictionToken] = []
     generation_time: float = -1
+    num_tokens: Optional[int] = None
+    stop_reason: Optional[str] = None  # todo define an enum for possible stop reasons
 
 
 class ContainerElement(
diff --git a/docling/datamodel/pipeline_options_vlm_model.py b/docling/datamodel/pipeline_options_vlm_model.py
index 02014c34..9b03d58a 100644
--- a/docling/datamodel/pipeline_options_vlm_model.py
+++ b/docling/datamodel/pipeline_options_vlm_model.py
@@ -82,6 +82,7 @@ class InlineVlmOptions(BaseVlmOptions):
 
     use_kv_cache: bool = True
     max_new_tokens: int = 4096
+    track_generated_tokens: bool = False
 
     @property
     def repo_cache_folder(self) -> str:
diff --git a/docling/models/api_vlm_model.py b/docling/models/api_vlm_model.py
index bcdb97f5..baa7fe44 100644
--- a/docling/models/api_vlm_model.py
+++ b/docling/models/api_vlm_model.py
@@ -73,7 +73,7 @@ class ApiVlmModel(BasePageModel):
                 # Skip non-GenerationStopper criteria (should have been caught in validation)
 
                 # Streaming path with early abort support
-                page_tags = api_image_request_streaming(
+                page_tags, num_tokens = api_image_request_streaming(
                     image=hi_res_image,
                     prompt=prompt,
                     url=self.vlm_options.url,
@@ -84,7 +84,7 @@
                 )
             else:
                 # Non-streaming fallback (existing behavior)
-                page_tags = api_image_request(
+                page_tags, num_tokens = api_image_request(
                     image=hi_res_image,
                     prompt=prompt,
                     url=self.vlm_options.url,
@@ -94,7 +94,9 @@
                 )
 
             page_tags = self.vlm_options.decode_response(page_tags)
-            page.predictions.vlm_response = VlmPrediction(text=page_tags)
+            page.predictions.vlm_response = VlmPrediction(
+                text=page_tags, num_tokens=num_tokens
+            )
             return page
 
         with ThreadPoolExecutor(max_workers=self.concurrency) as executor:
diff --git a/docling/models/vlm_models_inline/hf_transformers_model.py b/docling/models/vlm_models_inline/hf_transformers_model.py
index 482e2a14..6a19be23 100644
--- a/docling/models/vlm_models_inline/hf_transformers_model.py
+++ b/docling/models/vlm_models_inline/hf_transformers_model.py
@@ -367,13 +367,19 @@ class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownload
             decoded_texts = [text.rstrip(pad_token) for text in decoded_texts]
 
         # -- Optional logging
+        num_tokens = None
         if generated_ids.shape[0] > 0:
+            num_tokens = int(generated_ids[0].shape[0])
             _log.debug(
-                f"Generated {int(generated_ids[0].shape[0])} tokens in {generation_time:.2f}s "
+                f"Generated {num_tokens} tokens in {generation_time:.2f}s "
                 f"for batch size {generated_ids.shape[0]}."
             )
 
         for text in decoded_texts:
             # Apply decode_response to the output text
             decoded_text = self.vlm_options.decode_response(text)
-            yield VlmPrediction(text=decoded_text, generation_time=generation_time)
+            yield VlmPrediction(
+                text=decoded_text,
+                generation_time=generation_time,
+                num_tokens=num_tokens,
+            )
diff --git a/docling/models/vlm_models_inline/mlx_model.py b/docling/models/vlm_models_inline/mlx_model.py
index 2aa7b48a..b2516355 100644
--- a/docling/models/vlm_models_inline/mlx_model.py
+++ b/docling/models/vlm_models_inline/mlx_model.py
@@ -318,5 +318,6 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
                 text=decoded_output,
                 generation_time=generation_time,
                 generated_tokens=tokens,
+                num_tokens=len(tokens),
             )
         _log.debug("MLX model: Released global lock")
diff --git a/docling/models/vlm_models_inline/nuextract_transformers_model.py b/docling/models/vlm_models_inline/nuextract_transformers_model.py
index b573c2bb..5559093d 100644
--- a/docling/models/vlm_models_inline/nuextract_transformers_model.py
+++ b/docling/models/vlm_models_inline/nuextract_transformers_model.py
@@ -282,13 +282,19 @@ class NuExtractTransformersModel(BaseVlmModel, HuggingFaceModelDownloadMixin):
         )
 
         # Optional logging
+        num_tokens = None
         if generated_ids.shape[0] > 0:  # type: ignore
+            num_tokens = int(generated_ids[0].shape[0])
             _log.debug(
-                f"Generated {int(generated_ids[0].shape[0])} tokens in {generation_time:.2f}s "
+                f"Generated {num_tokens} tokens in {generation_time:.2f}s "
                 f"for batch size {generated_ids.shape[0]}."  # type: ignore
             )
 
         for text in decoded_texts:
             # Apply decode_response to the output text
             decoded_text = self.vlm_options.decode_response(text)
-            yield VlmPrediction(text=decoded_text, generation_time=generation_time)
+            yield VlmPrediction(
+                text=decoded_text,
+                generation_time=generation_time,
+                num_tokens=num_tokens,
+            )
diff --git a/docling/models/vlm_models_inline/vllm_model.py b/docling/models/vlm_models_inline/vllm_model.py
index e89c2f4c..e373f0c5 100644
--- a/docling/models/vlm_models_inline/vllm_model.py
+++ b/docling/models/vlm_models_inline/vllm_model.py
@@ -9,7 +9,7 @@ import numpy as np
 from PIL.Image import Image
 
 from docling.datamodel.accelerator_options import AcceleratorOptions
-from docling.datamodel.base_models import Page, VlmPrediction
+from docling.datamodel.base_models import Page, VlmPrediction, VlmPredictionToken
 from docling.datamodel.document import ConversionResult
 from docling.datamodel.pipeline_options_vlm_model import (
     InlineVlmOptions,
@@ -88,7 +88,7 @@ class VllmVlmModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
         vlm_options: InlineVlmOptions,
     ):
         self.enabled = enabled
-        self.vlm_options = vlm_options
+        self.vlm_options: InlineVlmOptions = vlm_options
 
         self.llm = None
         self.sampling_params = None
@@ -234,7 +234,8 @@ class VllmVlmModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
                 pages_with_images.append(page)
 
         if images:
-            predictions = list(self.process_images(images, user_prompts))
+            with TimeRecorder(conv_res, "vlm_inference"):
+                predictions = list(self.process_images(images, user_prompts))
             for page, prediction in zip(pages_with_images, predictions):
                 page.predictions.vlm_response = prediction
 
@@ -300,13 +301,34 @@ class VllmVlmModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
         # Optional debug
         if outputs:
             try:
-                num_tokens = len(outputs[0].outputs[0].token_ids)
-                _log.debug(f"Generated {num_tokens} tokens in {generation_time:.2f}s.")
+                num_tokens_within_batch = len(outputs[0].outputs[0].token_ids)
+                _log.debug(
+                    f"Generated {num_tokens_within_batch} tokens for batch in {generation_time:.2f}s."
+                )
             except Exception:
-                pass
+                num_tokens_within_batch = 0
 
         # Emit predictions
         for output in outputs:
             text = output.outputs[0].text if output.outputs else ""
+            stop_reason = output.outputs[0].stop_reason if output.outputs else ""
+            generated_tokens = [
+                VlmPredictionToken(token=int(p)) for p in output.outputs[0].token_ids
+            ]
+            num_tokens = len(generated_tokens)
             decoded_text = self.vlm_options.decode_response(text)
-            yield VlmPrediction(text=decoded_text, generation_time=generation_time)
+            if self.vlm_options.track_generated_tokens:
+                yield VlmPrediction(
+                    text=decoded_text,
+                    generation_time=generation_time,
+                    num_tokens=num_tokens,
+                    stop_reason=stop_reason,
+                    generated_tokens=generated_tokens,
+                )
+            else:
+                yield VlmPrediction(
+                    text=decoded_text,
+                    generation_time=generation_time,
+                    num_tokens=num_tokens,
+                    stop_reason=stop_reason,
+                )
diff --git a/docling/utils/api_image_request.py b/docling/utils/api_image_request.py
index c8a58e25..a1e55f7a 100644
--- a/docling/utils/api_image_request.py
+++ b/docling/utils/api_image_request.py
@@ -2,7 +2,7 @@ import base64
 import json
 import logging
 from io import BytesIO
-from typing import Optional
+from typing import Dict, List, Optional, Tuple
 
 import requests
 from PIL import Image
@@ -21,7 +21,7 @@ def api_image_request(
     timeout: float = 20,
     headers: Optional[dict[str, str]] = None,
     **params,
-) -> str:
+) -> Tuple[str, Optional[int]]:
     img_io = BytesIO()
     image.save(img_io, "PNG")
     image_base64 = base64.b64encode(img_io.getvalue()).decode("utf-8")
@@ -60,7 +60,8 @@ def api_image_request(
 
     api_resp = OpenAiApiResponse.model_validate_json(r.text)
     generated_text = api_resp.choices[0].message.content.strip()
-    return generated_text
+    num_tokens = api_resp.usage.total_tokens
+    return generated_text, num_tokens
 
 
 def api_image_request_streaming(
@@ -72,7 +73,7 @@ def api_image_request_streaming(
     headers: Optional[dict[str, str]] = None,
     generation_stoppers: list[GenerationStopper] = [],
     **params,
-) -> str:
+) -> Tuple[str, Optional[int]]:
     """
     Stream a chat completion from an OpenAI-compatible server (e.g., vLLM).
     Parses SSE lines: 'data: {json}\\n\\n', terminated by 'data: [DONE]'.
@@ -150,6 +151,16 @@ def api_image_request_streaming(
                 _log.debug("Unexpected SSE chunk shape: %s", e)
                 piece = ""
 
+            # Try to extract token count
+            num_tokens = None
+            try:
+                if "usage" in obj:
+                    usage = obj["usage"]
+                    num_tokens = usage.get("total_tokens")
+            except Exception as e:
+                num_tokens = None
+                _log.debug("Usage key not included in response: %s", e)
+
             if piece:
                 full_text.append(piece)
                 for stopper in generation_stoppers:
@@ -162,6 +173,6 @@ def api_image_request_streaming(
                         # closing the connection when we exit the 'with' block.
                         # vLLM/OpenAI-compatible servers will detect the client disconnect
                         # and abort the request server-side.
-                        return "".join(full_text)
+                        return "".join(full_text), num_tokens
 
-    return "".join(full_text)
+    return "".join(full_text), num_tokens