feat: add num_tokens as attribute for VlmPrediction
@@ -192,6 +192,7 @@ class VlmPrediction(BaseModel):
     text: str = ""
     generated_tokens: list[VlmPredictionToken] = []
     generation_time: float = -1
+    num_tokens: int = 0
 
 
 class ContainerElement(
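The field defaults to 0, so callers can read it even from backends that never set it. A minimal sketch, not part of this commit, of how downstream code might turn the new attribute into a throughput figure; it assumes pydantic is installed and reduces VlmPrediction to the three fields used here.

from pydantic import BaseModel


class VlmPrediction(BaseModel):
    # Reduced stand-in for docling's VlmPrediction with only the fields relevant here.
    text: str = ""
    generation_time: float = -1
    num_tokens: int = 0


def tokens_per_second(pred: VlmPrediction) -> float:
    # Guard against the defaults (-1 time, 0 tokens) left when a backend has no data.
    if pred.generation_time <= 0 or pred.num_tokens <= 0:
        return 0.0
    return pred.num_tokens / pred.generation_time


print(tokens_per_second(VlmPrediction(text="...", generation_time=2.0, num_tokens=128)))  # 64.0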
@@ -363,13 +363,15 @@ class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
         decoded_texts = [text.rstrip(pad_token) for text in decoded_texts]
 
         # -- Optional logging
+        num_tokens = 0
         if generated_ids.shape[0] > 0:
+            num_tokens = int(generated_ids[0].shape[0])
             _log.debug(
-                f"Generated {int(generated_ids[0].shape[0])} tokens in {generation_time:.2f}s "
+                f"Generated {num_tokens} tokens in {generation_time:.2f}s "
                 f"for batch size {generated_ids.shape[0]}."
             )
 
         for text in decoded_texts:
             # Apply decode_response to the output text
             decoded_text = self.vlm_options.decode_response(text)
-            yield VlmPrediction(text=decoded_text, generation_time=generation_time)
+            yield VlmPrediction(text=decoded_text, generation_time=generation_time, num_tokens=num_tokens)
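A hedged sketch of what the added counting lines compute, assuming generated_ids is a torch tensor of shape (batch, seq_len) as the indexing above implies; torch is the only dependency.

import torch

# Fake batch of generated token ids: 2 sequences, 17 tokens each.
generated_ids = torch.randint(0, 32_000, (2, 17))

num_tokens = 0
if generated_ids.shape[0] > 0:
    # Token count of the first sample; shape[0] of the full tensor is the batch size.
    num_tokens = int(generated_ids[0].shape[0])

print(num_tokens)  # 17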
@@ -313,5 +313,6 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
                     text=decoded_output,
                     generation_time=generation_time,
                     generated_tokens=tokens,
+                    num_tokens=len(tokens),
                 )
                 _log.debug("MLX model: Released global lock")
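In the MLX path the per-token records are already collected in a list, so the count is simply its length and num_tokens ends up equal to len(generated_tokens). A tiny stand-alone sketch; FakeToken is a hypothetical stand-in for VlmPredictionToken, whose real definition is not part of this diff.

from dataclasses import dataclass


@dataclass
class FakeToken:
    # Hypothetical stand-in for VlmPredictionToken.
    text: str
    logprob: float


tokens = [FakeToken("Hello", -0.1), FakeToken(" world", -0.3)]
num_tokens = len(tokens)
print(num_tokens)  # 2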
@@ -278,13 +278,15 @@ class NuExtractTransformersModel(BaseVlmModel, HuggingFaceModelDownloadMixin):
             )
 
             # Optional logging
+            num_tokens = 0
             if generated_ids.shape[0] > 0:  # type: ignore
+                num_tokens = int(generated_ids[0].shape[0])
                 _log.debug(
-                    f"Generated {int(generated_ids[0].shape[0])} tokens in {generation_time:.2f}s "
+                    f"Generated {num_tokens} tokens in {generation_time:.2f}s "
                     f"for batch size {generated_ids.shape[0]}."  # type: ignore
                 )
 
             for text in decoded_texts:
                 # Apply decode_response to the output text
                 decoded_text = self.vlm_options.decode_response(text)
-                yield VlmPrediction(text=decoded_text, generation_time=generation_time)
+                yield VlmPrediction(text=decoded_text, generation_time=generation_time, num_tokens=num_tokens)
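Same pattern as the transformers hunk above; initializing num_tokens to 0 before the guard is what keeps the value well defined when the batch is empty. A hedged illustration of that edge case, with numpy standing in for the actual tensor type.

import numpy as np

# Empty batch: no sequences were generated.
generated_ids = np.empty((0, 0), dtype=np.int64)

num_tokens = 0
if generated_ids.shape[0] > 0:
    num_tokens = int(generated_ids[0].shape[0])

print(num_tokens)  # 0, the default survives because the guard never fires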
@@ -291,10 +291,11 @@ class VllmVlmModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
             num_tokens = len(outputs[0].outputs[0].token_ids)
             _log.debug(f"Generated {num_tokens} tokens in {generation_time:.2f}s.")
         except Exception:
+            num_tokens = 0
             pass
 
         # Emit predictions
         for output in outputs:
             text = output.outputs[0].text if output.outputs else ""
             decoded_text = self.vlm_options.decode_response(text)
-            yield VlmPrediction(text=decoded_text, generation_time=generation_time)
+            yield VlmPrediction(text=decoded_text, generation_time=generation_time, num_tokens=num_tokens)
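The vLLM counter sits in a try/except, so missing or malformed outputs fall back to 0 instead of aborting the page. A hedged sketch of that fallback; the objects here are stand-ins, not the real vLLM RequestOutput types.

outputs: list = []  # e.g. an empty result list, or objects without token_ids

try:
    num_tokens = len(outputs[0].outputs[0].token_ids)
except Exception:
    num_tokens = 0

print(num_tokens)  # 0 when nothing usable came back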