From 311287f562923854bd056b57ec94d1eacf17cec0 Mon Sep 17 00:00:00 2001
From: ElHachem02
Date: Fri, 17 Oct 2025 16:41:35 +0200
Subject: [PATCH] feat: add num_tokens attribute to VlmPrediction

---
 docling/datamodel/base_models.py | 1 +
 docling/models/vlm_models_inline/hf_transformers_model.py | 6 ++++--
 docling/models/vlm_models_inline/mlx_model.py | 1 +
 .../vlm_models_inline/nuextract_transformers_model.py | 6 ++++--
 docling/models/vlm_models_inline/vllm_model.py | 3 ++-
 5 files changed, 12 insertions(+), 5 deletions(-)

diff --git a/docling/datamodel/base_models.py b/docling/datamodel/base_models.py
index 627ecf5f..1f2d9e04 100644
--- a/docling/datamodel/base_models.py
+++ b/docling/datamodel/base_models.py
@@ -192,6 +192,7 @@ class VlmPrediction(BaseModel):
     text: str = ""
     generated_tokens: list[VlmPredictionToken] = []
     generation_time: float = -1
+    num_tokens: int = 0
 
 
 class ContainerElement(
diff --git a/docling/models/vlm_models_inline/hf_transformers_model.py b/docling/models/vlm_models_inline/hf_transformers_model.py
index 1f4d752c..fe9ed4b9 100644
--- a/docling/models/vlm_models_inline/hf_transformers_model.py
+++ b/docling/models/vlm_models_inline/hf_transformers_model.py
@@ -363,13 +363,15 @@ class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownload
             decoded_texts = [text.rstrip(pad_token) for text in decoded_texts]
 
         # -- Optional logging
+        num_tokens = 0
         if generated_ids.shape[0] > 0:
+            num_tokens = int(generated_ids[0].shape[0])
             _log.debug(
-                f"Generated {int(generated_ids[0].shape[0])} tokens in {generation_time:.2f}s "
+                f"Generated {num_tokens} tokens in {generation_time:.2f}s "
                 f"for batch size {generated_ids.shape[0]}."
             )
 
         for text in decoded_texts:
             # Apply decode_response to the output text
             decoded_text = self.vlm_options.decode_response(text)
-            yield VlmPrediction(text=decoded_text, generation_time=generation_time)
+            yield VlmPrediction(text=decoded_text, generation_time=generation_time, num_tokens=num_tokens)
diff --git a/docling/models/vlm_models_inline/mlx_model.py b/docling/models/vlm_models_inline/mlx_model.py
index ac4cf9c8..04bfa548 100644
--- a/docling/models/vlm_models_inline/mlx_model.py
+++ b/docling/models/vlm_models_inline/mlx_model.py
@@ -313,5 +313,6 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
                 text=decoded_output,
                 generation_time=generation_time,
                 generated_tokens=tokens,
+                num_tokens=len(tokens),
             )
         _log.debug("MLX model: Released global lock")
diff --git a/docling/models/vlm_models_inline/nuextract_transformers_model.py b/docling/models/vlm_models_inline/nuextract_transformers_model.py
index 3eb64d49..0e47f35d 100644
--- a/docling/models/vlm_models_inline/nuextract_transformers_model.py
+++ b/docling/models/vlm_models_inline/nuextract_transformers_model.py
@@ -278,13 +278,15 @@ class NuExtractTransformersModel(BaseVlmModel, HuggingFaceModelDownloadMixin):
         )
 
         # Optional logging
+        num_tokens = 0
         if generated_ids.shape[0] > 0:  # type: ignore
+            num_tokens = int(generated_ids[0].shape[0])
             _log.debug(
-                f"Generated {int(generated_ids[0].shape[0])} tokens in {generation_time:.2f}s "
+                f"Generated {num_tokens} tokens in {generation_time:.2f}s "
                 f"for batch size {generated_ids.shape[0]}."  # type: ignore
             )
 
         for text in decoded_texts:
             # Apply decode_response to the output text
             decoded_text = self.vlm_options.decode_response(text)
-            yield VlmPrediction(text=decoded_text, generation_time=generation_time)
+            yield VlmPrediction(text=decoded_text, generation_time=generation_time, num_tokens=num_tokens)
diff --git a/docling/models/vlm_models_inline/vllm_model.py b/docling/models/vlm_models_inline/vllm_model.py
index 8f019b96..f839acf3 100644
--- a/docling/models/vlm_models_inline/vllm_model.py
+++ b/docling/models/vlm_models_inline/vllm_model.py
@@ -291,10 +291,11 @@ class VllmVlmModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
             num_tokens = len(outputs[0].outputs[0].token_ids)
             _log.debug(f"Generated {num_tokens} tokens in {generation_time:.2f}s.")
         except Exception:
+            num_tokens = 0
             pass
 
         # Emit predictions
         for output in outputs:
             text = output.outputs[0].text if output.outputs else ""
             decoded_text = self.vlm_options.decode_response(text)
-            yield VlmPrediction(text=decoded_text, generation_time=generation_time)
+            yield VlmPrediction(text=decoded_text, generation_time=generation_time, num_tokens=num_tokens)
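
A minimal usage sketch of the new field, not part of the patch: the directly
constructed VlmPrediction below is purely illustrative; in practice instances
are produced by the VLM models touched above.

    # Illustrative only: VlmPrediction is normally created inside the VLM
    # models patched above; building one by hand just demonstrates the field.
    from docling.datamodel.base_models import VlmPrediction

    prediction = VlmPrediction(text="...", generation_time=2.5, num_tokens=128)

    # num_tokens defaults to 0 and generation_time to -1, so guard both
    # before deriving a throughput figure.
    if prediction.num_tokens > 0 and prediction.generation_time > 0:
        print(f"{prediction.num_tokens / prediction.generation_time:.1f} tokens/s")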