Mirror of https://github.com/DS4SD/docling.git, synced 2025-12-08 12:48:28 +00:00
feat(vlm): track generated tokens and stop reasons for VLM models (#2543)
* feat: add enum StopReason and use it in VlmPrediction
* add vlm_inference time for api calls and track stop reason
* fix: rename enum to VlmStopReason
* Propagate partial success status if page reaches max tokens
* feat: page with generation stopped by loop detector creates partial success status
* Add hint for future improvement
* fix: remove vlm_stop_reason from extracted page data, add UNSPECIFIED state as VlmStopReason to avoid null value

Signed-off-by: ElHachem02 <peterelhachem02@gmail.com>
Signed-off-by: Peter El Hachem <peter.el.hachem@ibm.com>
Co-authored-by: Peter El Hachem <peter.el.hachem@ibm.com>
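Taken together, these changes surface a typed stop reason on every VLM prediction and time the remote calls. A minimal consumer-side sketch, assuming a standard VlmPipeline conversion in which page predictions remain accessible on the result; the wiring below follows current docling conventions and is not part of this diff:

from docling.datamodel.base_models import InputFormat, VlmStopReason
from docling.datamodel.pipeline_options import VlmPipelineOptions
from docling.document_converter import DocumentConverter, PdfFormatOption
from docling.pipeline.vlm_pipeline import VlmPipeline

converter = DocumentConverter(
    format_options={
        InputFormat.PDF: PdfFormatOption(
            pipeline_cls=VlmPipeline,
            pipeline_options=VlmPipelineOptions(),
        )
    }
)

result = converter.convert("report.pdf")
for page in result.pages:
    pred = page.predictions.vlm_response
    if pred and pred.stop_reason == VlmStopReason.LENGTH:
        # Generation hit the max-token limit; the page output may be incomplete.
        print(f"Page {page.page_no}: truncated after {pred.num_tokens} tokens")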
@@ -166,6 +166,13 @@ class DoclingComponentType(str, Enum):
     USER_INPUT = "user_input"
 
 
+class VlmStopReason(str, Enum):
+    LENGTH = "length"  # max tokens reached
+    STOP_SEQUENCE = "stop_sequence"  # Custom stopping criteria met
+    END_OF_SEQUENCE = "end_of_sequence"  # Model generated end-of-text token
+    UNSPECIFIED = "unspecified"  # Default none value
+
+
 class ErrorItem(BaseModel):
     component_type: DoclingComponentType
     module_name: str
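Since VlmStopReason subclasses str, members compare equal to their raw string values and round-trip cleanly through JSON. A small illustration (not part of the diff):

from docling.datamodel.base_models import VlmStopReason

# str-enum members compare equal to their plain string values
assert VlmStopReason.LENGTH == "length"
# and raw strings coming back from an API can be lifted into the enum
assert VlmStopReason("stop_sequence") is VlmStopReason.STOP_SEQUENCE
print(VlmStopReason.UNSPECIFIED.value)  # -> "unspecified"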
@@ -208,7 +215,7 @@ class VlmPrediction(BaseModel):
     generated_tokens: list[VlmPredictionToken] = []
     generation_time: float = -1
     num_tokens: Optional[int] = None
-    stop_reason: Optional[str] = None  # todo define an enum for possible stop reasons
+    stop_reason: VlmStopReason = VlmStopReason.UNSPECIFIED
 
 
 class ContainerElement(
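With the new default, a VlmPrediction constructed without an explicit stop reason reports UNSPECIFIED instead of None, so consumers need no null checks. A quick sketch using only fields visible in this diff:

from docling.datamodel.base_models import VlmPrediction, VlmStopReason

pred = VlmPrediction(text="<doctag>...</doctag>")
assert pred.stop_reason is VlmStopReason.UNSPECIFIED  # default, never None

truncated = VlmPrediction(text="...", num_tokens=4096, stop_reason=VlmStopReason.LENGTH)
# In JSON mode the enum serializes as its plain string value
assert truncated.model_dump(mode="json")["stop_reason"] == "length"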
@@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional, Type, Union
 
 from pydantic import BaseModel, Field
 
-from docling.datamodel.base_models import ConversionStatus, ErrorItem
+from docling.datamodel.base_models import ConversionStatus, ErrorItem, VlmStopReason
 from docling.datamodel.document import InputDocument
 
 
@@ -3,7 +3,7 @@ from concurrent.futures import ThreadPoolExecutor
 
 from transformers import StoppingCriteria
 
-from docling.datamodel.base_models import Page, VlmPrediction
+from docling.datamodel.base_models import Page, VlmPrediction, VlmStopReason
 from docling.datamodel.document import ConversionResult
 from docling.datamodel.pipeline_options_vlm_model import ApiVlmOptions
 from docling.exceptions import OperationNotAllowed
@@ -59,6 +59,7 @@ class ApiVlmModel(BasePageModel):
                     hi_res_image = hi_res_image.convert("RGB")
 
                 prompt = self.vlm_options.build_prompt(page.parsed_page)
+                stop_reason = VlmStopReason.UNSPECIFIED
 
                 if self.vlm_options.custom_stopping_criteria:
                     # Instantiate any GenerationStopper classes before passing to streaming
@@ -73,6 +74,7 @@ class ApiVlmModel(BasePageModel):
                         # Skip non-GenerationStopper criteria (should have been caught in validation)
 
                     # Streaming path with early abort support
+                    with TimeRecorder(conv_res, "vlm_inference"):
                         page_tags, num_tokens = api_image_request_streaming(
                             image=hi_res_image,
                             prompt=prompt,
@@ -82,9 +84,11 @@ class ApiVlmModel(BasePageModel):
                             generation_stoppers=instantiated_stoppers,
                             **self.params,
                         )
+                        page_tags = self.vlm_options.decode_response(page_tags)
                 else:
                     # Non-streaming fallback (existing behavior)
-                    page_tags, num_tokens = api_image_request(
+                    with TimeRecorder(conv_res, "vlm_inference"):
+                        page_tags, num_tokens, stop_reason = api_image_request(
                             image=hi_res_image,
                             prompt=prompt,
                             url=self.vlm_options.url,
@@ -94,8 +98,9 @@ class ApiVlmModel(BasePageModel):
                         )
 
                 page_tags = self.vlm_options.decode_response(page_tags)
 
                 page.predictions.vlm_response = VlmPrediction(
-                    text=page_tags, num_tokens=num_tokens
+                    text=page_tags, num_tokens=num_tokens, stop_reason=stop_reason
                 )
                 return page
 
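Both the streaming and non-streaming branches are now wrapped in TimeRecorder(conv_res, "vlm_inference"), so API-call latency lands in the conversion result next to the other pipeline timings. A sketch of reading it back, assuming docling's profiling switch is turned on; the settings path is standard docling, not part of this diff:

from docling.datamodel.settings import settings

settings.debug.profile_pipeline_timings = True  # collect per-stage timings

result = converter.convert("report.pdf")  # converter from the sketch above
vlm_timing = result.timings.get("vlm_inference")
if vlm_timing is not None:
    # one entry per recorded VLM call
    print(f"{len(vlm_timing.times)} VLM calls, {sum(vlm_timing.times):.2f}s total")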
@@ -51,7 +51,7 @@ class PictureDescriptionApiModel(PictureDescriptionBaseModel):
         # Note: technically we could make a batch request here,
         # but not all APIs will allow for it. For example, vllm won't allow more than 1.
         def _api_request(image):
-            response, _ = api_image_request(
+            page_tags, _, _ = api_image_request(
                 image=image,
                 prompt=self.options.prompt,
                 url=self.options.url,
@@ -60,7 +60,7 @@ class PictureDescriptionApiModel(PictureDescriptionBaseModel):
                 **self.options.params,
             )
 
-            return response
+            return page_tags
 
         with ThreadPoolExecutor(max_workers=self.concurrency) as executor:
             yield from executor.map(_api_request, images)
@@ -13,7 +13,7 @@ from transformers import StoppingCriteria, StoppingCriteriaList, StopStringCrite
 from docling.datamodel.accelerator_options import (
     AcceleratorOptions,
 )
-from docling.datamodel.base_models import Page, VlmPrediction
+from docling.datamodel.base_models import Page, VlmPrediction, VlmStopReason
 from docling.datamodel.document import ConversionResult
 from docling.datamodel.pipeline_options_vlm_model import (
     InlineVlmOptions,
@@ -382,4 +382,5 @@ class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownload
                 text=decoded_text,
                 generation_time=generation_time,
                 num_tokens=num_tokens,
+                stop_reason=VlmStopReason.UNSPECIFIED,
             )
@@ -13,7 +13,12 @@ from transformers import StoppingCriteria
 from docling.datamodel.accelerator_options import (
     AcceleratorOptions,
 )
-from docling.datamodel.base_models import Page, VlmPrediction, VlmPredictionToken
+from docling.datamodel.base_models import (
+    Page,
+    VlmPrediction,
+    VlmPredictionToken,
+    VlmStopReason,
+)
 from docling.datamodel.document import ConversionResult
 from docling.datamodel.pipeline_options_vlm_model import InlineVlmOptions
 from docling.models.base_model import BaseVlmPageModel
@@ -319,5 +324,6 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
                     generation_time=generation_time,
                     generated_tokens=tokens,
                     num_tokens=len(tokens),
+                    stop_reason=VlmStopReason.UNSPECIFIED,
                 )
                 _log.debug("MLX model: Released global lock")
@@ -12,7 +12,7 @@ from transformers import AutoModelForImageTextToText, AutoProcessor, GenerationC
 from docling.datamodel.accelerator_options import (
     AcceleratorOptions,
 )
-from docling.datamodel.base_models import VlmPrediction
+from docling.datamodel.base_models import VlmPrediction, VlmStopReason
 from docling.datamodel.pipeline_options_vlm_model import InlineVlmOptions
 from docling.models.base_model import BaseVlmModel
 from docling.models.utils.hf_model_download import (
@@ -284,6 +284,7 @@ class NuExtractTransformersModel(BaseVlmModel, HuggingFaceModelDownloadMixin):
             # Optional logging
             num_tokens = None
             if generated_ids.shape[0] > 0:  # type: ignore
+                # Todo: confirm num tokens is actually from first item, code was already like this
                 num_tokens = int(generated_ids[0].shape[0])
                 _log.debug(
                     f"Generated {num_tokens} tokens in {generation_time:.2f}s "
@@ -297,4 +298,5 @@ class NuExtractTransformersModel(BaseVlmModel, HuggingFaceModelDownloadMixin):
             text=decoded_text,
             generation_time=generation_time,
             num_tokens=num_tokens,
+            stop_reason=VlmStopReason.UNSPECIFIED,
         )
@@ -9,7 +9,12 @@ import numpy as np
 from PIL.Image import Image
 
 from docling.datamodel.accelerator_options import AcceleratorOptions
-from docling.datamodel.base_models import Page, VlmPrediction, VlmPredictionToken
+from docling.datamodel.base_models import (
+    Page,
+    VlmPrediction,
+    VlmPredictionToken,
+    VlmStopReason,
+)
 from docling.datamodel.document import ConversionResult
 from docling.datamodel.pipeline_options_vlm_model import (
     InlineVlmOptions,
@@ -311,13 +316,18 @@ class VllmVlmModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
         # Emit predictions
         for output in outputs:
             text = output.outputs[0].text if output.outputs else ""
-            stop_reason = output.outputs[0].stop_reason if output.outputs else ""
-            generated_tokens = [
-                VlmPredictionToken(token=int(p)) for p in output.outputs[0].token_ids
-            ]
+            stop_reason = (
+                VlmStopReason.END_OF_SEQUENCE
+                if output.outputs[0].stop_reason
+                else VlmStopReason.LENGTH
+            )
+            generated_tokens = (
+                [VlmPredictionToken(token=int(t)) for t in output.outputs[0].token_ids]
+                if self.vlm_options.track_generated_tokens
+                else []
+            )
             num_tokens = len(generated_tokens)
             decoded_text = self.vlm_options.decode_response(text)
-            if self.vlm_options.track_generated_tokens:
             yield VlmPrediction(
                 text=decoded_text,
                 generation_time=generation_time,
@@ -325,10 +335,3 @@ class VllmVlmModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
                 stop_reason=stop_reason,
                 generated_tokens=generated_tokens,
             )
-            else:
-                yield VlmPrediction(
-                    text=decoded_text,
-                    generation_time=generation_time,
-                    num_tokens=num_tokens,
-                    stop_reason=stop_reason,
-                )
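For context on the stop_reason branch above: a vLLM CompletionOutput exposes finish_reason ("stop" or "length") alongside stop_reason, which holds the specific stop string or token id that fired and is None otherwise. The hunk keys off the truthiness of stop_reason, so a hit stop string maps to END_OF_SEQUENCE and everything else, including length cutoffs, maps to LENGTH. A condensed restatement with a hypothetical helper name:

from docling.datamodel.base_models import VlmStopReason

def map_vllm_stop_reason(vllm_stop_reason) -> VlmStopReason:
    # Hypothetical helper mirroring the hunk above: a truthy vLLM stop_reason
    # (stop string or token id) counts as a normal end of sequence; anything
    # else is attributed to the token limit.
    return VlmStopReason.END_OF_SEQUENCE if vllm_stop_reason else VlmStopReason.LENGTH

assert map_vllm_stop_reason("</doctag>") is VlmStopReason.END_OF_SEQUENCE
assert map_vllm_stop_reason(None) is VlmStopReason.LENGTH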
@@ -8,7 +8,7 @@ from pydantic import BaseModel
 
 from docling.backend.abstract_backend import PaginatedDocumentBackend
 from docling.backend.pdf_backend import PdfDocumentBackend
-from docling.datamodel.base_models import ConversionStatus, ErrorItem
+from docling.datamodel.base_models import ConversionStatus, ErrorItem, VlmStopReason
 from docling.datamodel.document import InputDocument
 from docling.datamodel.extraction import (
     ExtractedPageData,
@@ -83,6 +83,12 @@ class ExtractionVlmPipeline(BaseExtractionPipeline):
                 # Parse the extracted text as JSON if possible, otherwise use as-is
                 extracted_text = predictions[0].text
                 extracted_data = None
+                vlm_stop_reason: VlmStopReason = predictions[0].stop_reason
+                if (
+                    vlm_stop_reason == VlmStopReason.LENGTH
+                    or vlm_stop_reason == VlmStopReason.STOP_SEQUENCE
+                ):
+                    ext_res.status = ConversionStatus.PARTIAL_SUCCESS
 
                 try:
                     extracted_data = json.loads(extracted_text)
@@ -128,7 +134,11 @@ class ExtractionVlmPipeline(BaseExtractionPipeline):
     def _determine_status(self, ext_res: ExtractionResult) -> ConversionStatus:
         """Determine the status based on extraction results."""
         if ext_res.pages and not any(page.errors for page in ext_res.pages):
-            return ConversionStatus.SUCCESS
+            return (
+                ConversionStatus.PARTIAL_SUCCESS
+                if ext_res.status == ConversionStatus.PARTIAL_SUCCESS
+                else ConversionStatus.SUCCESS
+            )
         else:
             return ConversionStatus.FAILURE
 
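Net effect of the two pipeline hunks: a page whose generation stopped on LENGTH or STOP_SEQUENCE downgrades the extraction status to PARTIAL_SUCCESS, and _determine_status now preserves that downgrade instead of overwriting it with SUCCESS. A pure-function restatement of the combined rule (illustrative, not library code):

from docling.datamodel.base_models import ConversionStatus, VlmStopReason

def resolve_status(
    stop_reason: VlmStopReason, has_pages: bool, has_errors: bool
) -> ConversionStatus:
    # Illustrative restatement of the logic added in this PR.
    if not has_pages or has_errors:
        return ConversionStatus.FAILURE
    if stop_reason in (VlmStopReason.LENGTH, VlmStopReason.STOP_SEQUENCE):
        return ConversionStatus.PARTIAL_SUCCESS
    return ConversionStatus.SUCCESS

assert resolve_status(VlmStopReason.LENGTH, True, False) == ConversionStatus.PARTIAL_SUCCESS
assert resolve_status(VlmStopReason.END_OF_SEQUENCE, True, False) == ConversionStatus.SUCCESS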
@@ -8,7 +8,7 @@ import requests
 from PIL import Image
 from pydantic import AnyUrl
 
-from docling.datamodel.base_models import OpenAiApiResponse
+from docling.datamodel.base_models import OpenAiApiResponse, VlmStopReason
 from docling.models.utils.generation_utils import GenerationStopper
 
 _log = logging.getLogger(__name__)
@@ -21,7 +21,7 @@ def api_image_request(
     timeout: float = 20,
     headers: Optional[dict[str, str]] = None,
     **params,
-) -> Tuple[str, Optional[int]]:
+) -> Tuple[str, Optional[int], VlmStopReason]:
     img_io = BytesIO()
     image.save(img_io, "PNG")
     image_base64 = base64.b64encode(img_io.getvalue()).decode("utf-8")
@@ -61,7 +61,13 @@ def api_image_request(
         api_resp = OpenAiApiResponse.model_validate_json(r.text)
         generated_text = api_resp.choices[0].message.content.strip()
         num_tokens = api_resp.usage.total_tokens
-        return generated_text, num_tokens
+        stop_reason = (
+            VlmStopReason.LENGTH
+            if api_resp.choices[0].finish_reason == "length"
+            else VlmStopReason.END_OF_SEQUENCE
+        )
+
+        return generated_text, num_tokens, stop_reason
 
 
 def api_image_request_streaming(
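Callers of api_image_request now unpack three values, with the third always a concrete VlmStopReason: OpenAI-style responses reporting finish_reason "length" map to LENGTH, everything else to END_OF_SEQUENCE. A usage sketch; the module path is assumed from context, and passing a plain string URL is an assumption since the parameter is typed as AnyUrl:

from PIL import Image

from docling.datamodel.base_models import VlmStopReason
from docling.utils.api_image_request import api_image_request  # path assumed

page_image = Image.new("RGB", (1024, 1024), "white")  # stand-in page image

text, num_tokens, stop_reason = api_image_request(
    image=page_image,
    prompt="Convert this page to DocTags.",
    url="http://localhost:8000/v1/chat/completions",
    timeout=60,
)
if stop_reason is VlmStopReason.LENGTH:
    print(f"Response truncated at {num_tokens} tokens; consider raising max_tokens.")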