feat(vlm): track generated tokens and stop reasons for VLM models (#2543)
* feat: add enum StopReason and use it in VlmPrediction
* add vlm_inference time for api calls and track stop reason
* fix: rename enum to VlmStopReason
* Propagate partial success status if page reaches max tokens
* feat: page with generation stopped by loop detector creates partial success status
* Add hint for future improvement
* fix: remove vlm_stop_reason from extracted page data, add UNSPECIFIED state as VlmStopReason to avoid null value

Signed-off-by: ElHachem02 <peterelhachem02@gmail.com>
Signed-off-by: Peter El Hachem <peter.el.hachem@ibm.com>
Co-authored-by: Peter El Hachem <peter.el.hachem@ibm.com>
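The diff below adds a VlmStopReason enum in docling.datamodel.base_models, stores it on VlmPrediction.stop_reason, threads it through the API, transformers, MLX, NuExtract and vLLM backends, and lets the extraction pipeline downgrade a page to PARTIAL_SUCCESS when generation was cut off. A minimal sketch of how downstream code might read the new field; only the import path, field names and enum members come from this diff, the concrete values are illustrative:

from docling.datamodel.base_models import VlmPrediction, VlmStopReason

# Illustrative prediction; in practice this object is built by a VLM backend
# and attached to page.predictions.vlm_response (see the ApiVlmModel hunks below).
pred = VlmPrediction(
    text="<doctag>...</doctag>",
    num_tokens=512,
    stop_reason=VlmStopReason.LENGTH,  # e.g. the max token budget was hit
)

if pred.stop_reason in (VlmStopReason.LENGTH, VlmStopReason.STOP_SEQUENCE):
    print("generation was cut short; the page content may be incomplete")
elif pred.stop_reason == VlmStopReason.UNSPECIFIED:
    print("this backend does not report a stop reason yet")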
@@ -166,6 +166,13 @@ class DoclingComponentType(str, Enum):
     USER_INPUT = "user_input"
 
 
+class VlmStopReason(str, Enum):
+    LENGTH = "length"  # max tokens reached
+    STOP_SEQUENCE = "stop_sequence"  # Custom stopping criteria met
+    END_OF_SEQUENCE = "end_of_sequence"  # Model generated end-of-text token
+    UNSPECIFIED = "unspecified"  # Default none value
+
+
 class ErrorItem(BaseModel):
     component_type: DoclingComponentType
     module_name: str
@@ -208,7 +215,7 @@ class VlmPrediction(BaseModel):
     generated_tokens: list[VlmPredictionToken] = []
     generation_time: float = -1
     num_tokens: Optional[int] = None
-    stop_reason: Optional[str] = None  # todo define an enum for possible stop reasons
+    stop_reason: VlmStopReason = VlmStopReason.UNSPECIFIED
 
 
 class ContainerElement(
@@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional, Type, Union
 
 from pydantic import BaseModel, Field
 
-from docling.datamodel.base_models import ConversionStatus, ErrorItem
+from docling.datamodel.base_models import ConversionStatus, ErrorItem, VlmStopReason
 from docling.datamodel.document import InputDocument
@@ -3,7 +3,7 @@ from concurrent.futures import ThreadPoolExecutor
 
 from transformers import StoppingCriteria
 
-from docling.datamodel.base_models import Page, VlmPrediction
+from docling.datamodel.base_models import Page, VlmPrediction, VlmStopReason
 from docling.datamodel.document import ConversionResult
 from docling.datamodel.pipeline_options_vlm_model import ApiVlmOptions
 from docling.exceptions import OperationNotAllowed
@@ -59,6 +59,7 @@ class ApiVlmModel(BasePageModel):
                 hi_res_image = hi_res_image.convert("RGB")
 
                 prompt = self.vlm_options.build_prompt(page.parsed_page)
+                stop_reason = VlmStopReason.UNSPECIFIED
 
                 if self.vlm_options.custom_stopping_criteria:
                     # Instantiate any GenerationStopper classes before passing to streaming
@@ -73,6 +74,7 @@ class ApiVlmModel(BasePageModel):
                             # Skip non-GenerationStopper criteria (should have been caught in validation)
 
                     # Streaming path with early abort support
+                    with TimeRecorder(conv_res, "vlm_inference"):
                         page_tags, num_tokens = api_image_request_streaming(
                             image=hi_res_image,
                             prompt=prompt,
@@ -82,9 +84,11 @@ class ApiVlmModel(BasePageModel):
                             generation_stoppers=instantiated_stoppers,
                             **self.params,
                         )
                         page_tags = self.vlm_options.decode_response(page_tags)
                 else:
                     # Non-streaming fallback (existing behavior)
-                    page_tags, num_tokens = api_image_request(
+                    with TimeRecorder(conv_res, "vlm_inference"):
+                        page_tags, num_tokens, stop_reason = api_image_request(
                             image=hi_res_image,
                             prompt=prompt,
                             url=self.vlm_options.url,
@@ -94,8 +98,9 @@ class ApiVlmModel(BasePageModel):
                         )
 
                     page_tags = self.vlm_options.decode_response(page_tags)
 
                 page.predictions.vlm_response = VlmPrediction(
-                    text=page_tags, num_tokens=num_tokens
+                    text=page_tags, num_tokens=num_tokens, stop_reason=stop_reason
                 )
                 return page
@@ -51,7 +51,7 @@ class PictureDescriptionApiModel(PictureDescriptionBaseModel):
         # Note: technically we could make a batch request here,
         # but not all APIs will allow for it. For example, vllm won't allow more than 1.
         def _api_request(image):
-            response, _ = api_image_request(
+            page_tags, _, _ = api_image_request(
                 image=image,
                 prompt=self.options.prompt,
                 url=self.options.url,
@@ -60,7 +60,7 @@ class PictureDescriptionApiModel(PictureDescriptionBaseModel):
                 **self.options.params,
             )
 
-            return response
+            return page_tags
 
         with ThreadPoolExecutor(max_workers=self.concurrency) as executor:
             yield from executor.map(_api_request, images)
@@ -13,7 +13,7 @@ from transformers import StoppingCriteria, StoppingCriteriaList, StopStringCriteria
 from docling.datamodel.accelerator_options import (
     AcceleratorOptions,
 )
-from docling.datamodel.base_models import Page, VlmPrediction
+from docling.datamodel.base_models import Page, VlmPrediction, VlmStopReason
 from docling.datamodel.document import ConversionResult
 from docling.datamodel.pipeline_options_vlm_model import (
     InlineVlmOptions,
@@ -382,4 +382,5 @@ class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
                 text=decoded_text,
                 generation_time=generation_time,
                 num_tokens=num_tokens,
+                stop_reason=VlmStopReason.UNSPECIFIED,
             )
@@ -13,7 +13,12 @@ from transformers import StoppingCriteria
 from docling.datamodel.accelerator_options import (
     AcceleratorOptions,
 )
-from docling.datamodel.base_models import Page, VlmPrediction, VlmPredictionToken
+from docling.datamodel.base_models import (
+    Page,
+    VlmPrediction,
+    VlmPredictionToken,
+    VlmStopReason,
+)
 from docling.datamodel.document import ConversionResult
 from docling.datamodel.pipeline_options_vlm_model import InlineVlmOptions
 from docling.models.base_model import BaseVlmPageModel
@@ -319,5 +324,6 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
                     generation_time=generation_time,
                     generated_tokens=tokens,
                     num_tokens=len(tokens),
+                    stop_reason=VlmStopReason.UNSPECIFIED,
                 )
             _log.debug("MLX model: Released global lock")
@@ -12,7 +12,7 @@ from transformers import AutoModelForImageTextToText, AutoProcessor, GenerationConfig
 from docling.datamodel.accelerator_options import (
     AcceleratorOptions,
 )
-from docling.datamodel.base_models import VlmPrediction
+from docling.datamodel.base_models import VlmPrediction, VlmStopReason
 from docling.datamodel.pipeline_options_vlm_model import InlineVlmOptions
 from docling.models.base_model import BaseVlmModel
 from docling.models.utils.hf_model_download import (
@@ -284,6 +284,7 @@ class NuExtractTransformersModel(BaseVlmModel, HuggingFaceModelDownloadMixin):
         # Optional logging
         num_tokens = None
         if generated_ids.shape[0] > 0:  # type: ignore
+            # Todo: confirm num tokens is actually from first item, code was already like this
             num_tokens = int(generated_ids[0].shape[0])
         _log.debug(
             f"Generated {num_tokens} tokens in {generation_time:.2f}s "
@@ -297,4 +298,5 @@ class NuExtractTransformersModel(BaseVlmModel, HuggingFaceModelDownloadMixin):
             text=decoded_text,
             generation_time=generation_time,
             num_tokens=num_tokens,
+            stop_reason=VlmStopReason.UNSPECIFIED,
         )
@@ -9,7 +9,12 @@ import numpy as np
 from PIL.Image import Image
 
 from docling.datamodel.accelerator_options import AcceleratorOptions
-from docling.datamodel.base_models import Page, VlmPrediction, VlmPredictionToken
+from docling.datamodel.base_models import (
+    Page,
+    VlmPrediction,
+    VlmPredictionToken,
+    VlmStopReason,
+)
 from docling.datamodel.document import ConversionResult
 from docling.datamodel.pipeline_options_vlm_model import (
     InlineVlmOptions,
@@ -311,13 +316,18 @@ class VllmVlmModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
         # Emit predictions
         for output in outputs:
             text = output.outputs[0].text if output.outputs else ""
-            stop_reason = output.outputs[0].stop_reason if output.outputs else ""
-            generated_tokens = [
-                VlmPredictionToken(token=int(p)) for p in output.outputs[0].token_ids
-            ]
+            stop_reason = (
+                VlmStopReason.END_OF_SEQUENCE
+                if output.outputs[0].stop_reason
+                else VlmStopReason.LENGTH
+            )
+            generated_tokens = (
+                [VlmPredictionToken(token=int(t)) for t in output.outputs[0].token_ids]
+                if self.vlm_options.track_generated_tokens
+                else []
+            )
             num_tokens = len(generated_tokens)
             decoded_text = self.vlm_options.decode_response(text)
-            if self.vlm_options.track_generated_tokens:
             yield VlmPrediction(
                 text=decoded_text,
                 generation_time=generation_time,
@@ -325,10 +335,3 @@ class VllmVlmModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
+                stop_reason=stop_reason,
                 generated_tokens=generated_tokens,
             )
-            else:
-                yield VlmPrediction(
-                    text=decoded_text,
-                    generation_time=generation_time,
-                    num_tokens=num_tokens,
-                    stop_reason=stop_reason,
-                )
@@ -8,7 +8,7 @@ from pydantic import BaseModel
 
 from docling.backend.abstract_backend import PaginatedDocumentBackend
 from docling.backend.pdf_backend import PdfDocumentBackend
-from docling.datamodel.base_models import ConversionStatus, ErrorItem
+from docling.datamodel.base_models import ConversionStatus, ErrorItem, VlmStopReason
 from docling.datamodel.document import InputDocument
 from docling.datamodel.extraction import (
     ExtractedPageData,
@@ -83,6 +83,12 @@ class ExtractionVlmPipeline(BaseExtractionPipeline):
             # Parse the extracted text as JSON if possible, otherwise use as-is
             extracted_text = predictions[0].text
             extracted_data = None
+            vlm_stop_reason: VlmStopReason = predictions[0].stop_reason
+            if (
+                vlm_stop_reason == VlmStopReason.LENGTH
+                or vlm_stop_reason == VlmStopReason.STOP_SEQUENCE
+            ):
+                ext_res.status = ConversionStatus.PARTIAL_SUCCESS
 
             try:
                 extracted_data = json.loads(extracted_text)
@@ -128,7 +134,11 @@ class ExtractionVlmPipeline(BaseExtractionPipeline):
     def _determine_status(self, ext_res: ExtractionResult) -> ConversionStatus:
         """Determine the status based on extraction results."""
         if ext_res.pages and not any(page.errors for page in ext_res.pages):
-            return ConversionStatus.SUCCESS
+            return (
+                ConversionStatus.PARTIAL_SUCCESS
+                if ext_res.status == ConversionStatus.PARTIAL_SUCCESS
+                else ConversionStatus.SUCCESS
+            )
         else:
             return ConversionStatus.FAILURE
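Taken together with the @@ -83,6 hunk above, the pipeline now downgrades a result to PARTIAL_SUCCESS whenever a page's generation stopped on LENGTH or STOP_SEQUENCE. A hedged sketch of how a caller could surface that; how you obtain ext_res (an ExtractionResult) depends on your extractor setup and is not part of this diff:

from docling.datamodel.base_models import ConversionStatus

def warn_if_truncated(ext_res) -> None:
    # PARTIAL_SUCCESS is what ExtractionVlmPipeline now sets when the VLM
    # stopped because of LENGTH or STOP_SEQUENCE instead of end-of-sequence.
    if ext_res.status == ConversionStatus.PARTIAL_SUCCESS:
        print("extraction finished, but at least one page hit a generation limit")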
@@ -8,7 +8,7 @@ import requests
 from PIL import Image
 from pydantic import AnyUrl
 
-from docling.datamodel.base_models import OpenAiApiResponse
+from docling.datamodel.base_models import OpenAiApiResponse, VlmStopReason
 from docling.models.utils.generation_utils import GenerationStopper
 
 _log = logging.getLogger(__name__)
@@ -21,7 +21,7 @@ def api_image_request(
     timeout: float = 20,
     headers: Optional[dict[str, str]] = None,
     **params,
-) -> Tuple[str, Optional[int]]:
+) -> Tuple[str, Optional[int], VlmStopReason]:
     img_io = BytesIO()
     image.save(img_io, "PNG")
     image_base64 = base64.b64encode(img_io.getvalue()).decode("utf-8")
@@ -61,7 +61,13 @@ def api_image_request(
     api_resp = OpenAiApiResponse.model_validate_json(r.text)
     generated_text = api_resp.choices[0].message.content.strip()
     num_tokens = api_resp.usage.total_tokens
-    return generated_text, num_tokens
+    stop_reason = (
+        VlmStopReason.LENGTH
+        if api_resp.choices[0].finish_reason == "length"
+        else VlmStopReason.END_OF_SEQUENCE
+    )
+
+    return generated_text, num_tokens, stop_reason
 
 
 def api_image_request_streaming(
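Because api_image_request now returns a three-tuple, every caller has to unpack the extra stop reason, as the PictureDescriptionApiModel hunk above already does. A hedged usage sketch; the import path of the helper, the endpoint URL and the prompt are assumptions, not taken from this diff:

from PIL import Image

from docling.datamodel.base_models import VlmStopReason
from docling.utils.api_image_request import api_image_request  # assumed module path

def caption(image: Image.Image) -> str:
    # Unpack the new 3-tuple: generated text, token count, and stop reason.
    text, num_tokens, stop_reason = api_image_request(
        image=image,
        prompt="Describe this picture.",  # placeholder prompt
        url="http://localhost:8000/v1/chat/completions",  # placeholder endpoint
        timeout=90,
    )
    if stop_reason == VlmStopReason.LENGTH:
        # The server cut the completion at its token limit; the text may be partial.
        print(f"caption truncated after {num_tokens} tokens")
    return text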