feat: add num_tokens as attribute for VlmPrediction

ElHachem02
2025-10-17 16:41:35 +02:00
parent dd03b53117
commit 311287f562
5 changed files with 12 additions and 5 deletions
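For context, a minimal sketch of the VlmPrediction model after this change, using only the fields visible in the diff below and assuming the usual pydantic BaseModel; VlmPredictionToken is the existing token-metadata model from the same module and is only stubbed here:

    from pydantic import BaseModel


    class VlmPredictionToken(BaseModel):
        ...  # existing token-level metadata, unchanged by this commit


    class VlmPrediction(BaseModel):
        text: str = ""
        generated_tokens: list[VlmPredictionToken] = []
        generation_time: float = -1
        num_tokens: int = 0  # new: number of generated tokens, 0 when unknown

Each VLM backend touched below now fills num_tokens alongside generation_time when it yields a prediction.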

@@ -192,6 +192,7 @@ class VlmPrediction(BaseModel):
     text: str = ""
     generated_tokens: list[VlmPredictionToken] = []
     generation_time: float = -1
+    num_tokens: int = 0
 
 
 class ContainerElement(

@@ -363,13 +363,15 @@ class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
         decoded_texts = [text.rstrip(pad_token) for text in decoded_texts]
 
         # -- Optional logging
+        num_tokens = 0
         if generated_ids.shape[0] > 0:
+            num_tokens = int(generated_ids[0].shape[0])
             _log.debug(
-                f"Generated {int(generated_ids[0].shape[0])} tokens in {generation_time:.2f}s "
+                f"Generated {num_tokens} tokens in {generation_time:.2f}s "
                 f"for batch size {generated_ids.shape[0]}."
             )
 
         for text in decoded_texts:
             # Apply decode_response to the output text
             decoded_text = self.vlm_options.decode_response(text)
-            yield VlmPrediction(text=decoded_text, generation_time=generation_time)
+            yield VlmPrediction(text=decoded_text, generation_time=generation_time, num_tokens=num_tokens)

@@ -313,5 +313,6 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
                     text=decoded_output,
                     generation_time=generation_time,
                     generated_tokens=tokens,
+                    num_tokens=len(tokens),
                 )
             _log.debug("MLX model: Released global lock")

@@ -278,13 +278,15 @@ class NuExtractTransformersModel(BaseVlmModel, HuggingFaceModelDownloadMixin):
         )
 
         # Optional logging
+        num_tokens = 0
         if generated_ids.shape[0] > 0:  # type: ignore
+            num_tokens = int(generated_ids[0].shape[0])
             _log.debug(
-                f"Generated {int(generated_ids[0].shape[0])} tokens in {generation_time:.2f}s "
+                f"Generated {num_tokens} tokens in {generation_time:.2f}s "
                 f"for batch size {generated_ids.shape[0]}."  # type: ignore
             )
 
         for text in decoded_texts:
             # Apply decode_response to the output text
             decoded_text = self.vlm_options.decode_response(text)
-            yield VlmPrediction(text=decoded_text, generation_time=generation_time)
+            yield VlmPrediction(text=decoded_text, generation_time=generation_time, num_tokens=num_tokens)

@@ -291,10 +291,11 @@ class VllmVlmModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
             num_tokens = len(outputs[0].outputs[0].token_ids)
             _log.debug(f"Generated {num_tokens} tokens in {generation_time:.2f}s.")
         except Exception:
+            num_tokens = 0
             pass
 
         # Emit predictions
         for output in outputs:
             text = output.outputs[0].text if output.outputs else ""
             decoded_text = self.vlm_options.decode_response(text)
-            yield VlmPrediction(text=decoded_text, generation_time=generation_time)
+            yield VlmPrediction(text=decoded_text, generation_time=generation_time, num_tokens=num_tokens)
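As a purely illustrative usage sketch (not part of this commit), downstream code could combine the new field with the existing generation_time to report throughput; the helper name and output format below are assumptions:

    def log_throughput(prediction: VlmPrediction) -> None:
        # Both fields keep sentinel defaults (num_tokens=0, generation_time=-1)
        # when a backend cannot determine them, so guard before dividing.
        if prediction.num_tokens > 0 and prediction.generation_time > 0:
            tokens_per_second = prediction.num_tokens / prediction.generation_time
            print(
                f"{prediction.num_tokens} tokens in {prediction.generation_time:.2f}s "
                f"({tokens_per_second:.1f} tok/s)"
            )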