feat: add num_tokens as attribute for VlmPrediction
@@ -192,6 +192,7 @@ class VlmPrediction(BaseModel):
     text: str = ""
     generated_tokens: list[VlmPredictionToken] = []
     generation_time: float = -1
+    num_tokens: int = 0
 
 
 class ContainerElement(
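For context, a minimal standalone sketch of the prediction model after this change. The field names are taken from the hunk above; the token record's fields and the pydantic import are assumptions for illustration, not the exact docling definitions.

from pydantic import BaseModel


class VlmPredictionToken(BaseModel):
    # Simplified stand-in; the real docling token record may carry more fields.
    text: str = ""
    token: int = -1
    logprob: float = -1


class VlmPrediction(BaseModel):
    text: str = ""
    generated_tokens: list[VlmPredictionToken] = []
    generation_time: float = -1
    num_tokens: int = 0  # new: count of generated tokens for this prediction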
@@ -363,13 +363,15 @@ class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
         decoded_texts = [text.rstrip(pad_token) for text in decoded_texts]
 
         # -- Optional logging
+        num_tokens = 0
         if generated_ids.shape[0] > 0:
+            num_tokens = int(generated_ids[0].shape[0])
             _log.debug(
-                f"Generated {int(generated_ids[0].shape[0])} tokens in {generation_time:.2f}s "
+                f"Generated {num_tokens} tokens in {generation_time:.2f}s "
                 f"for batch size {generated_ids.shape[0]}."
             )
 
         for text in decoded_texts:
             # Apply decode_response to the output text
             decoded_text = self.vlm_options.decode_response(text)
-            yield VlmPrediction(text=decoded_text, generation_time=generation_time)
+            yield VlmPrediction(text=decoded_text, generation_time=generation_time, num_tokens=num_tokens)
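As a rough, self-contained illustration of the counting logic introduced above, using a dummy tensor in place of a real generate() result. Whether prompt tokens are included in generated_ids depends on how the ids are post-processed upstream, which is not shown here.

import torch

# Dummy stand-in for the ids returned by a Hugging Face generate() call: (batch, seq_len)
generated_ids = torch.zeros((2, 17), dtype=torch.long)

num_tokens = 0
if generated_ids.shape[0] > 0:
    # Same pattern as the hunk above: length of the first sequence in the batch
    num_tokens = int(generated_ids[0].shape[0])

print(num_tokens)  # -> 17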
@@ -313,5 +313,6 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
             text=decoded_output,
             generation_time=generation_time,
             generated_tokens=tokens,
+            num_tokens=len(tokens),
         )
         _log.debug("MLX model: Released global lock")
@@ -278,13 +278,15 @@ class NuExtractTransformersModel(BaseVlmModel, HuggingFaceModelDownloadMixin):
         )
 
         # Optional logging
+        num_tokens = 0
         if generated_ids.shape[0] > 0:  # type: ignore
+            num_tokens = int(generated_ids[0].shape[0])
             _log.debug(
-                f"Generated {int(generated_ids[0].shape[0])} tokens in {generation_time:.2f}s "
+                f"Generated {num_tokens} tokens in {generation_time:.2f}s "
                 f"for batch size {generated_ids.shape[0]}."  # type: ignore
             )
 
         for text in decoded_texts:
             # Apply decode_response to the output text
             decoded_text = self.vlm_options.decode_response(text)
-            yield VlmPrediction(text=decoded_text, generation_time=generation_time)
+            yield VlmPrediction(text=decoded_text, generation_time=generation_time, num_tokens=num_tokens)
@@ -291,10 +291,11 @@ class VllmVlmModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
             num_tokens = len(outputs[0].outputs[0].token_ids)
             _log.debug(f"Generated {num_tokens} tokens in {generation_time:.2f}s.")
         except Exception:
+            num_tokens = 0
             pass
 
         # Emit predictions
         for output in outputs:
             text = output.outputs[0].text if output.outputs else ""
             decoded_text = self.vlm_options.decode_response(text)
-            yield VlmPrediction(text=decoded_text, generation_time=generation_time)
+            yield VlmPrediction(text=decoded_text, generation_time=generation_time, num_tokens=num_tokens)
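Downstream, the new field makes simple per-prediction throughput reporting possible. A sketch of a helper a caller might write; the function is illustrative only and not part of this commit.

def tokens_per_second(pred: "VlmPrediction") -> float:
    # Guard against the defaults (generation_time = -1, num_tokens = 0)
    if pred.generation_time > 0 and pred.num_tokens > 0:
        return pred.num_tokens / pred.generation_time
    return 0.0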