feat(vlm): track generated tokens and stop reasons for VLM models (#2543)

* feat: add enum StopReason and use it in VlmPrediction

Signed-off-by: ElHachem02 <peterelhachem02@gmail.com>

* add vlm_inference time for api calls and track stop reason

Signed-off-by: ElHachem02 <peterelhachem02@gmail.com>

* fix: rename enum to VlmStopReason

Signed-off-by: ElHachem02 <peterelhachem02@gmail.com>

* Propagate partial success status if page reaches max tokens

Signed-off-by: ElHachem02 <peterelhachem02@gmail.com>

* feat: pages whose generation is stopped by the loop detector get a partial success status

Signed-off-by: Peter El Hachem <peter.el.hachem@ibm.com>

* Add hint for future improvement

Signed-off-by: Peter El Hachem <peter.el.hachem@ibm.com>

* fix: remove vlm_stop_reason from extracted page data; add an UNSPECIFIED state to VlmStopReason to avoid a null value

Signed-off-by: Peter El Hachem <peter.el.hachem@ibm.com>

---------

Signed-off-by: ElHachem02 <peterelhachem02@gmail.com>
Signed-off-by: Peter El Hachem <peter.el.hachem@ibm.com>
Co-authored-by: Peter El Hachem <peter.el.hachem@ibm.com>
peets authored on 2025-11-04 19:39:09 +01:00, committed by GitHub
parent 1a5146abc9, commit 6a04e27352
10 changed files with 92 additions and 52 deletions
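For reviewers, a minimal sketch of how calling code might branch on the new stop reason. Only the VlmStopReason enum and the VlmPrediction.stop_reason field below come from this commit; the handle_prediction helper is a hypothetical example:

from docling.datamodel.base_models import VlmPrediction, VlmStopReason


def handle_prediction(prediction: VlmPrediction) -> None:
    # Truncated generations (token limit or a custom stopper) should be
    # treated as incomplete output rather than a clean result.
    if prediction.stop_reason in (VlmStopReason.LENGTH, VlmStopReason.STOP_SEQUENCE):
        print(f"generation truncated after {prediction.num_tokens} tokens")
    elif prediction.stop_reason == VlmStopReason.END_OF_SEQUENCE:
        print("model emitted its end-of-text token")
    else:
        # VlmStopReason.UNSPECIFIED: the backend did not report a reason
        print("no stop reason reported")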

View File

@@ -166,6 +166,13 @@ class DoclingComponentType(str, Enum):
    USER_INPUT = "user_input"


+class VlmStopReason(str, Enum):
+    LENGTH = "length"  # max tokens reached
+    STOP_SEQUENCE = "stop_sequence"  # Custom stopping criteria met
+    END_OF_SEQUENCE = "end_of_sequence"  # Model generated end-of-text token
+    UNSPECIFIED = "unspecified"  # Default none value


class ErrorItem(BaseModel):
    component_type: DoclingComponentType
    module_name: str
@@ -208,7 +215,7 @@ class VlmPrediction(BaseModel):
    generated_tokens: list[VlmPredictionToken] = []
    generation_time: float = -1
    num_tokens: Optional[int] = None
-   stop_reason: Optional[str] = None  # todo define an enum for possible stop reasons
+   stop_reason: VlmStopReason = VlmStopReason.UNSPECIFIED


class ContainerElement(

View File

@@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional, Type, Union
from pydantic import BaseModel, Field

-from docling.datamodel.base_models import ConversionStatus, ErrorItem
+from docling.datamodel.base_models import ConversionStatus, ErrorItem, VlmStopReason
from docling.datamodel.document import InputDocument

View File

@@ -3,7 +3,7 @@ from concurrent.futures import ThreadPoolExecutor
from transformers import StoppingCriteria

-from docling.datamodel.base_models import Page, VlmPrediction
+from docling.datamodel.base_models import Page, VlmPrediction, VlmStopReason
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options_vlm_model import ApiVlmOptions
from docling.exceptions import OperationNotAllowed
@@ -59,6 +59,7 @@ class ApiVlmModel(BasePageModel):
    hi_res_image = hi_res_image.convert("RGB")
    prompt = self.vlm_options.build_prompt(page.parsed_page)
+   stop_reason = VlmStopReason.UNSPECIFIED
    if self.vlm_options.custom_stopping_criteria:
        # Instantiate any GenerationStopper classes before passing to streaming
@@ -73,29 +74,33 @@ class ApiVlmModel(BasePageModel):
                # Skip non-GenerationStopper criteria (should have been caught in validation)

            # Streaming path with early abort support
-           page_tags, num_tokens = api_image_request_streaming(
-               image=hi_res_image,
-               prompt=prompt,
-               url=self.vlm_options.url,
-               timeout=self.timeout,
-               headers=self.vlm_options.headers,
-               generation_stoppers=instantiated_stoppers,
-               **self.params,
-           )
+           with TimeRecorder(conv_res, "vlm_inference"):
+               page_tags, num_tokens = api_image_request_streaming(
+                   image=hi_res_image,
+                   prompt=prompt,
+                   url=self.vlm_options.url,
+                   timeout=self.timeout,
+                   headers=self.vlm_options.headers,
+                   generation_stoppers=instantiated_stoppers,
+                   **self.params,
+               )
+           page_tags = self.vlm_options.decode_response(page_tags)
        else:
            # Non-streaming fallback (existing behavior)
-           page_tags, num_tokens = api_image_request(
-               image=hi_res_image,
-               prompt=prompt,
-               url=self.vlm_options.url,
-               timeout=self.timeout,
-               headers=self.vlm_options.headers,
-               **self.params,
-           )
+           with TimeRecorder(conv_res, "vlm_inference"):
+               page_tags, num_tokens, stop_reason = api_image_request(
+                   image=hi_res_image,
+                   prompt=prompt,
+                   url=self.vlm_options.url,
+                   timeout=self.timeout,
+                   headers=self.vlm_options.headers,
+                   **self.params,
+               )
+           page_tags = self.vlm_options.decode_response(page_tags)

        page_tags = self.vlm_options.decode_response(page_tags)
        page.predictions.vlm_response = VlmPrediction(
-           text=page_tags, num_tokens=num_tokens
+           text=page_tags, num_tokens=num_tokens, stop_reason=stop_reason
        )

    return page
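A short sketch of how downstream code could read the stop reason recorded on each page after conversion. It assumes the ConversionResult pages expose the predictions populated above (page.predictions.vlm_response) and that each Page carries a page_no; the truncated_pages helper is illustrative:

from docling.datamodel.base_models import VlmStopReason
from docling.datamodel.document import ConversionResult


def truncated_pages(conv_res: ConversionResult) -> list[int]:
    # Collect page numbers whose VLM generation stopped at the token limit.
    truncated = []
    for page in conv_res.pages:
        pred = page.predictions.vlm_response
        if pred is not None and pred.stop_reason == VlmStopReason.LENGTH:
            truncated.append(page.page_no)
    return truncated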

View File

@@ -51,7 +51,7 @@ class PictureDescriptionApiModel(PictureDescriptionBaseModel):
        # Note: technically we could make a batch request here,
        # but not all APIs will allow for it. For example, vllm won't allow more than 1.
        def _api_request(image):
-           response, _ = api_image_request(
+           page_tags, _, _ = api_image_request(
                image=image,
                prompt=self.options.prompt,
                url=self.options.url,
@@ -60,7 +60,7 @@ class PictureDescriptionApiModel(PictureDescriptionBaseModel):
                **self.options.params,
            )
-           return response
+           return page_tags

        with ThreadPoolExecutor(max_workers=self.concurrency) as executor:
            yield from executor.map(_api_request, images)

View File

@@ -13,7 +13,7 @@ from transformers import StoppingCriteria, StoppingCriteriaList, StopStringCrite
from docling.datamodel.accelerator_options import (
    AcceleratorOptions,
)
-from docling.datamodel.base_models import Page, VlmPrediction
+from docling.datamodel.base_models import Page, VlmPrediction, VlmStopReason
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options_vlm_model import (
    InlineVlmOptions,
@@ -382,4 +382,5 @@ class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
            text=decoded_text,
            generation_time=generation_time,
            num_tokens=num_tokens,
+           stop_reason=VlmStopReason.UNSPECIFIED,
        )

View File

@@ -13,7 +13,12 @@ from transformers import StoppingCriteria
from docling.datamodel.accelerator_options import (
    AcceleratorOptions,
)
-from docling.datamodel.base_models import Page, VlmPrediction, VlmPredictionToken
+from docling.datamodel.base_models import (
+    Page,
+    VlmPrediction,
+    VlmPredictionToken,
+    VlmStopReason,
+)
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options_vlm_model import InlineVlmOptions
from docling.models.base_model import BaseVlmPageModel
@@ -319,5 +324,6 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
                generation_time=generation_time,
                generated_tokens=tokens,
                num_tokens=len(tokens),
+               stop_reason=VlmStopReason.UNSPECIFIED,
            )
            _log.debug("MLX model: Released global lock")

View File

@@ -12,7 +12,7 @@ from transformers import AutoModelForImageTextToText, AutoProcessor, GenerationC
from docling.datamodel.accelerator_options import (
    AcceleratorOptions,
)
-from docling.datamodel.base_models import VlmPrediction
+from docling.datamodel.base_models import VlmPrediction, VlmStopReason
from docling.datamodel.pipeline_options_vlm_model import InlineVlmOptions
from docling.models.base_model import BaseVlmModel
from docling.models.utils.hf_model_download import (
@@ -284,6 +284,7 @@ class NuExtractTransformersModel(BaseVlmModel, HuggingFaceModelDownloadMixin):
        # Optional logging
        num_tokens = None
        if generated_ids.shape[0] > 0:  # type: ignore
+           # TODO: confirm num_tokens actually comes from the first item; the code was already like this
            num_tokens = int(generated_ids[0].shape[0])
            _log.debug(
                f"Generated {num_tokens} tokens in {generation_time:.2f}s "
@@ -297,4 +298,5 @@ class NuExtractTransformersModel(BaseVlmModel, HuggingFaceModelDownloadMixin):
            text=decoded_text,
            generation_time=generation_time,
            num_tokens=num_tokens,
+           stop_reason=VlmStopReason.UNSPECIFIED,
        )

View File

@@ -9,7 +9,12 @@ import numpy as np
from PIL.Image import Image

from docling.datamodel.accelerator_options import AcceleratorOptions
-from docling.datamodel.base_models import Page, VlmPrediction, VlmPredictionToken
+from docling.datamodel.base_models import (
+    Page,
+    VlmPrediction,
+    VlmPredictionToken,
+    VlmStopReason,
+)
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options_vlm_model import (
    InlineVlmOptions,
@@ -311,24 +316,22 @@ class VllmVlmModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
        # Emit predictions
        for output in outputs:
            text = output.outputs[0].text if output.outputs else ""
-           stop_reason = output.outputs[0].stop_reason if output.outputs else ""
-           generated_tokens = [
-               VlmPredictionToken(token=int(p)) for p in output.outputs[0].token_ids
-           ]
+           stop_reason = (
+               VlmStopReason.END_OF_SEQUENCE
+               if output.outputs[0].stop_reason
+               else VlmStopReason.LENGTH
+           )
+           generated_tokens = (
+               [VlmPredictionToken(token=int(t)) for t in output.outputs[0].token_ids]
+               if self.vlm_options.track_generated_tokens
+               else []
+           )
            num_tokens = len(generated_tokens)
            decoded_text = self.vlm_options.decode_response(text)
-           if self.vlm_options.track_generated_tokens:
-               yield VlmPrediction(
-                   text=decoded_text,
-                   generation_time=generation_time,
-                   num_tokens=num_tokens,
-                   stop_reason=stop_reason,
-                   generated_tokens=generated_tokens,
-               )
-           else:
-               yield VlmPrediction(
-                   text=decoded_text,
-                   generation_time=generation_time,
-                   num_tokens=num_tokens,
-                   stop_reason=stop_reason,
-               )
+           yield VlmPrediction(
+               text=decoded_text,
+               generation_time=generation_time,
+               num_tokens=num_tokens,
+               stop_reason=stop_reason,
+               generated_tokens=generated_tokens,
+           )

View File

@@ -8,7 +8,7 @@ from pydantic import BaseModel
from docling.backend.abstract_backend import PaginatedDocumentBackend
from docling.backend.pdf_backend import PdfDocumentBackend
-from docling.datamodel.base_models import ConversionStatus, ErrorItem
+from docling.datamodel.base_models import ConversionStatus, ErrorItem, VlmStopReason
from docling.datamodel.document import InputDocument
from docling.datamodel.extraction import (
    ExtractedPageData,
@@ -83,6 +83,12 @@ class ExtractionVlmPipeline(BaseExtractionPipeline):
            # Parse the extracted text as JSON if possible, otherwise use as-is
            extracted_text = predictions[0].text
            extracted_data = None
+           vlm_stop_reason: VlmStopReason = predictions[0].stop_reason
+           if (
+               vlm_stop_reason == VlmStopReason.LENGTH
+               or vlm_stop_reason == VlmStopReason.STOP_SEQUENCE
+           ):
+               ext_res.status = ConversionStatus.PARTIAL_SUCCESS

            try:
                extracted_data = json.loads(extracted_text)
@@ -128,7 +134,11 @@ class ExtractionVlmPipeline(BaseExtractionPipeline):
    def _determine_status(self, ext_res: ExtractionResult) -> ConversionStatus:
        """Determine the status based on extraction results."""
        if ext_res.pages and not any(page.errors for page in ext_res.pages):
-           return ConversionStatus.SUCCESS
+           return (
+               ConversionStatus.PARTIAL_SUCCESS
+               if ext_res.status == ConversionStatus.PARTIAL_SUCCESS
+               else ConversionStatus.SUCCESS
+           )
        else:
            return ConversionStatus.FAILURE
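A hedged sketch of how a caller could react to the propagated status. ConversionStatus and the PARTIAL_SUCCESS semantics come from the diff above; the check_extraction helper and its messages are illustrative:

from docling.datamodel.base_models import ConversionStatus


def check_extraction(ext_res) -> None:
    # PARTIAL_SUCCESS now signals that at least one page stopped on the token
    # limit or a custom stopping criterion, so its extracted JSON may be cut off.
    if ext_res.status == ConversionStatus.PARTIAL_SUCCESS:
        print("warning: extraction may be truncated; consider raising the generation token limit")
    elif ext_res.status == ConversionStatus.FAILURE:
        print("extraction failed")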

View File

@@ -8,7 +8,7 @@ import requests
from PIL import Image
from pydantic import AnyUrl

-from docling.datamodel.base_models import OpenAiApiResponse
+from docling.datamodel.base_models import OpenAiApiResponse, VlmStopReason
from docling.models.utils.generation_utils import GenerationStopper

_log = logging.getLogger(__name__)
@@ -21,7 +21,7 @@ def api_image_request(
    timeout: float = 20,
    headers: Optional[dict[str, str]] = None,
    **params,
-) -> Tuple[str, Optional[int]]:
+) -> Tuple[str, Optional[int], VlmStopReason]:
    img_io = BytesIO()
    image.save(img_io, "PNG")
    image_base64 = base64.b64encode(img_io.getvalue()).decode("utf-8")
@@ -61,7 +61,13 @@ def api_image_request(
        api_resp = OpenAiApiResponse.model_validate_json(r.text)
        generated_text = api_resp.choices[0].message.content.strip()
        num_tokens = api_resp.usage.total_tokens
-       return generated_text, num_tokens
+       stop_reason = (
+           VlmStopReason.LENGTH
+           if api_resp.choices[0].finish_reason == "length"
+           else VlmStopReason.END_OF_SEQUENCE
+       )
+       return generated_text, num_tokens, stop_reason


def api_image_request_streaming(
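A quick sketch of a call site updated for the three-element return value. The signature comes from the diff above, while the module path, URL, and prompt are assumptions for illustration:

from PIL import Image

from docling.datamodel.base_models import VlmStopReason
from docling.utils.api_image_request import api_image_request  # module path assumed


image = Image.open("page.png")
text, num_tokens, stop_reason = api_image_request(
    image=image,
    prompt="Convert this page to DocTags.",
    url="http://localhost:8000/v1/chat/completions",
    timeout=60,
)
if stop_reason == VlmStopReason.LENGTH:
    print(f"response truncated at {num_tokens} tokens")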