diff --git a/docling/datamodel/base_models.py b/docling/datamodel/base_models.py
index ae2f48cc..cd826f96 100644
--- a/docling/datamodel/base_models.py
+++ b/docling/datamodel/base_models.py
@@ -166,6 +166,13 @@ class DoclingComponentType(str, Enum):
     USER_INPUT = "user_input"
 
 
+class VlmStopReason(str, Enum):
+    LENGTH = "length"  # max tokens reached
+    STOP_SEQUENCE = "stop_sequence"  # custom stopping criteria met
+    END_OF_SEQUENCE = "end_of_sequence"  # model generated end-of-text token
+    UNSPECIFIED = "unspecified"  # default when no stop reason is reported
+
+
 class ErrorItem(BaseModel):
     component_type: DoclingComponentType
     module_name: str
@@ -208,7 +215,7 @@ class VlmPrediction(BaseModel):
     generated_tokens: list[VlmPredictionToken] = []
     generation_time: float = -1
     num_tokens: Optional[int] = None
-    stop_reason: Optional[str] = None  # todo define an enum for possible stop reasons
+    stop_reason: VlmStopReason = VlmStopReason.UNSPECIFIED
 
 
 class ContainerElement(
diff --git a/docling/datamodel/extraction.py b/docling/datamodel/extraction.py
index 8b5b2bb6..185547e7 100644
--- a/docling/datamodel/extraction.py
+++ b/docling/datamodel/extraction.py
@@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional, Type, Union
 
 from pydantic import BaseModel, Field
 
-from docling.datamodel.base_models import ConversionStatus, ErrorItem
+from docling.datamodel.base_models import ConversionStatus, ErrorItem, VlmStopReason
 from docling.datamodel.document import InputDocument
diff --git a/docling/models/api_vlm_model.py b/docling/models/api_vlm_model.py
index baa7fe44..2c9a1f9a 100644
--- a/docling/models/api_vlm_model.py
+++ b/docling/models/api_vlm_model.py
@@ -3,7 +3,7 @@ from concurrent.futures import ThreadPoolExecutor
 
 from transformers import StoppingCriteria
 
-from docling.datamodel.base_models import Page, VlmPrediction
+from docling.datamodel.base_models import Page, VlmPrediction, VlmStopReason
 from docling.datamodel.document import ConversionResult
 from docling.datamodel.pipeline_options_vlm_model import ApiVlmOptions
 from docling.exceptions import OperationNotAllowed
@@ -59,6 +59,7 @@ class ApiVlmModel(BasePageModel):
                     hi_res_image = hi_res_image.convert("RGB")
 
                 prompt = self.vlm_options.build_prompt(page.parsed_page)
+                stop_reason = VlmStopReason.UNSPECIFIED
 
                 if self.vlm_options.custom_stopping_criteria:
                     # Instantiate any GenerationStopper classes before passing to streaming
@@ -73,29 +74,33 @@ class ApiVlmModel(BasePageModel):
                         # Skip non-GenerationStopper criteria (should have been caught in validation)
 
                     # Streaming path with early abort support
-                    page_tags, num_tokens = api_image_request_streaming(
-                        image=hi_res_image,
-                        prompt=prompt,
-                        url=self.vlm_options.url,
-                        timeout=self.timeout,
-                        headers=self.vlm_options.headers,
-                        generation_stoppers=instantiated_stoppers,
-                        **self.params,
-                    )
+                    with TimeRecorder(conv_res, "vlm_inference"):
+                        page_tags, num_tokens = api_image_request_streaming(
+                            image=hi_res_image,
+                            prompt=prompt,
+                            url=self.vlm_options.url,
+                            timeout=self.timeout,
+                            headers=self.vlm_options.headers,
+                            generation_stoppers=instantiated_stoppers,
+                            **self.params,
+                        )
+                    page_tags = self.vlm_options.decode_response(page_tags)
                 else:
                     # Non-streaming fallback (existing behavior)
-                    page_tags, num_tokens = api_image_request(
-                        image=hi_res_image,
-                        prompt=prompt,
-                        url=self.vlm_options.url,
-                        timeout=self.timeout,
-                        headers=self.vlm_options.headers,
-                        **self.params,
-                    )
+                    with TimeRecorder(conv_res, "vlm_inference"):
+                        page_tags, num_tokens, stop_reason = api_image_request(
+                            image=hi_res_image,
+                            prompt=prompt,
+                            url=self.vlm_options.url,
+                            timeout=self.timeout,
+                            headers=self.vlm_options.headers,
+                            **self.params,
+                        )
+
+                    page_tags = self.vlm_options.decode_response(page_tags)
 
-                page_tags = self.vlm_options.decode_response(page_tags)
                 page.predictions.vlm_response = VlmPrediction(
-                    text=page_tags, num_tokens=num_tokens
+                    text=page_tags, num_tokens=num_tokens, stop_reason=stop_reason
                 )
 
             return page
diff --git a/docling/models/picture_description_api_model.py b/docling/models/picture_description_api_model.py
index 771608e2..4a8272b4 100644
--- a/docling/models/picture_description_api_model.py
+++ b/docling/models/picture_description_api_model.py
@@ -51,7 +51,7 @@ class PictureDescriptionApiModel(PictureDescriptionBaseModel):
         # Note: technically we could make a batch request here,
         # but not all APIs will allow for it. For example, vllm won't allow more than 1.
         def _api_request(image):
-            response, _ = api_image_request(
+            page_tags, _, _ = api_image_request(
                 image=image,
                 prompt=self.options.prompt,
                 url=self.options.url,
@@ -60,7 +60,7 @@ class PictureDescriptionApiModel(PictureDescriptionBaseModel):
                 **self.options.params,
             )
 
-            return response
+            return page_tags
 
         with ThreadPoolExecutor(max_workers=self.concurrency) as executor:
             yield from executor.map(_api_request, images)
diff --git a/docling/models/vlm_models_inline/hf_transformers_model.py b/docling/models/vlm_models_inline/hf_transformers_model.py
index 6a19be23..f9aefcb8 100644
--- a/docling/models/vlm_models_inline/hf_transformers_model.py
+++ b/docling/models/vlm_models_inline/hf_transformers_model.py
@@ -13,7 +13,7 @@ from transformers import StoppingCriteria, StoppingCriteriaList, StopStringCrite
 from docling.datamodel.accelerator_options import (
     AcceleratorOptions,
 )
-from docling.datamodel.base_models import Page, VlmPrediction
+from docling.datamodel.base_models import Page, VlmPrediction, VlmStopReason
 from docling.datamodel.document import ConversionResult
 from docling.datamodel.pipeline_options_vlm_model import (
     InlineVlmOptions,
@@ -382,4 +382,5 @@ class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownload
                 text=decoded_text,
                 generation_time=generation_time,
                 num_tokens=num_tokens,
+                stop_reason=VlmStopReason.UNSPECIFIED,
             )
diff --git a/docling/models/vlm_models_inline/mlx_model.py b/docling/models/vlm_models_inline/mlx_model.py
index b2516355..871c19ba 100644
--- a/docling/models/vlm_models_inline/mlx_model.py
+++ b/docling/models/vlm_models_inline/mlx_model.py
@@ -13,7 +13,12 @@ from transformers import StoppingCriteria
 from docling.datamodel.accelerator_options import (
     AcceleratorOptions,
 )
-from docling.datamodel.base_models import Page, VlmPrediction, VlmPredictionToken
+from docling.datamodel.base_models import (
+    Page,
+    VlmPrediction,
+    VlmPredictionToken,
+    VlmStopReason,
+)
 from docling.datamodel.document import ConversionResult
 from docling.datamodel.pipeline_options_vlm_model import InlineVlmOptions
 from docling.models.base_model import BaseVlmPageModel
@@ -319,5 +324,6 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
                     generation_time=generation_time,
                     generated_tokens=tokens,
                     num_tokens=len(tokens),
+                    stop_reason=VlmStopReason.UNSPECIFIED,
                 )
                 _log.debug("MLX model: Released global lock")
diff --git a/docling/models/vlm_models_inline/nuextract_transformers_model.py b/docling/models/vlm_models_inline/nuextract_transformers_model.py
index 5559093d..194a1d9d 100644
--- a/docling/models/vlm_models_inline/nuextract_transformers_model.py
+++ b/docling/models/vlm_models_inline/nuextract_transformers_model.py
@@ -12,7 +12,7 @@ from transformers import AutoModelForImageTextToText, AutoProcessor, GenerationC
 from docling.datamodel.accelerator_options import (
     AcceleratorOptions,
 )
-from docling.datamodel.base_models import VlmPrediction
+from docling.datamodel.base_models import VlmPrediction, VlmStopReason
 from docling.datamodel.pipeline_options_vlm_model import InlineVlmOptions
 from docling.models.base_model import BaseVlmModel
 from docling.models.utils.hf_model_download import (
@@ -284,6 +284,7 @@ class NuExtractTransformersModel(BaseVlmModel, HuggingFaceModelDownloadMixin):
         # Optional logging
         num_tokens = None
         if generated_ids.shape[0] > 0:  # type: ignore
+            # TODO: confirm num_tokens should come from the first item (pre-existing behavior)
             num_tokens = int(generated_ids[0].shape[0])
         _log.debug(
             f"Generated {num_tokens} tokens in {generation_time:.2f}s "
@@ -297,4 +298,5 @@ class NuExtractTransformersModel(BaseVlmModel, HuggingFaceModelDownloadMixin):
             text=decoded_text,
             generation_time=generation_time,
             num_tokens=num_tokens,
+            stop_reason=VlmStopReason.UNSPECIFIED,
         )
diff --git a/docling/models/vlm_models_inline/vllm_model.py b/docling/models/vlm_models_inline/vllm_model.py
index e373f0c5..9254f4d5 100644
--- a/docling/models/vlm_models_inline/vllm_model.py
+++ b/docling/models/vlm_models_inline/vllm_model.py
@@ -9,7 +9,12 @@ import numpy as np
 from PIL.Image import Image
 
 from docling.datamodel.accelerator_options import AcceleratorOptions
-from docling.datamodel.base_models import Page, VlmPrediction, VlmPredictionToken
+from docling.datamodel.base_models import (
+    Page,
+    VlmPrediction,
+    VlmPredictionToken,
+    VlmStopReason,
+)
 from docling.datamodel.document import ConversionResult
 from docling.datamodel.pipeline_options_vlm_model import (
     InlineVlmOptions,
@@ -311,24 +316,22 @@ class VllmVlmModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
         # Emit predictions
         for output in outputs:
             text = output.outputs[0].text if output.outputs else ""
-            stop_reason = output.outputs[0].stop_reason if output.outputs else ""
-            generated_tokens = [
-                VlmPredictionToken(token=int(p)) for p in output.outputs[0].token_ids
-            ]
+            stop_reason = (
+                VlmStopReason.END_OF_SEQUENCE
+                if output.outputs[0].stop_reason
+                else VlmStopReason.LENGTH
+            )
+            generated_tokens = (
+                [VlmPredictionToken(token=int(t)) for t in output.outputs[0].token_ids]
+                if self.vlm_options.track_generated_tokens
+                else []
+            )
             num_tokens = len(generated_tokens)
             decoded_text = self.vlm_options.decode_response(text)
-            if self.vlm_options.track_generated_tokens:
-                yield VlmPrediction(
-                    text=decoded_text,
-                    generation_time=generation_time,
-                    num_tokens=num_tokens,
-                    stop_reason=stop_reason,
-                    generated_tokens=generated_tokens,
-                )
-            else:
-                yield VlmPrediction(
-                    text=decoded_text,
-                    generation_time=generation_time,
-                    num_tokens=num_tokens,
-                    stop_reason=stop_reason,
-                )
+            yield VlmPrediction(
+                text=decoded_text,
+                generation_time=generation_time,
+                num_tokens=num_tokens,
+                stop_reason=stop_reason,
+                generated_tokens=generated_tokens,
+            )
diff --git a/docling/pipeline/extraction_vlm_pipeline.py b/docling/pipeline/extraction_vlm_pipeline.py
index 47aba8cd..b995b013 100644
--- a/docling/pipeline/extraction_vlm_pipeline.py
+++ b/docling/pipeline/extraction_vlm_pipeline.py
@@ -8,7 +8,7 @@ from pydantic import BaseModel
 
 from docling.backend.abstract_backend import PaginatedDocumentBackend
 from docling.backend.pdf_backend import PdfDocumentBackend
-from docling.datamodel.base_models import ConversionStatus, ErrorItem
+from docling.datamodel.base_models import ConversionStatus, ErrorItem, VlmStopReason
 from docling.datamodel.document import InputDocument
 from docling.datamodel.extraction import (
     ExtractedPageData,
@@ -83,6 +83,12 @@ class ExtractionVlmPipeline(BaseExtractionPipeline):
             # Parse the extracted text as JSON if possible, otherwise use as-is
             extracted_text = predictions[0].text
             extracted_data = None
+            vlm_stop_reason: VlmStopReason = predictions[0].stop_reason
+            if (
+                vlm_stop_reason == VlmStopReason.LENGTH
+                or vlm_stop_reason == VlmStopReason.STOP_SEQUENCE
+            ):
+                ext_res.status = ConversionStatus.PARTIAL_SUCCESS
 
             try:
                 extracted_data = json.loads(extracted_text)
@@ -128,7 +134,11 @@ class ExtractionVlmPipeline(BaseExtractionPipeline):
     def _determine_status(self, ext_res: ExtractionResult) -> ConversionStatus:
         """Determine the status based on extraction results."""
         if ext_res.pages and not any(page.errors for page in ext_res.pages):
-            return ConversionStatus.SUCCESS
+            return (
+                ConversionStatus.PARTIAL_SUCCESS
+                if ext_res.status == ConversionStatus.PARTIAL_SUCCESS
+                else ConversionStatus.SUCCESS
+            )
         else:
             return ConversionStatus.FAILURE
 
diff --git a/docling/utils/api_image_request.py b/docling/utils/api_image_request.py
index a1e55f7a..d638eec8 100644
--- a/docling/utils/api_image_request.py
+++ b/docling/utils/api_image_request.py
@@ -8,7 +8,7 @@ import requests
 from PIL import Image
 from pydantic import AnyUrl
 
-from docling.datamodel.base_models import OpenAiApiResponse
+from docling.datamodel.base_models import OpenAiApiResponse, VlmStopReason
 from docling.models.utils.generation_utils import GenerationStopper
 
 _log = logging.getLogger(__name__)
@@ -21,7 +21,7 @@ def api_image_request(
     timeout: float = 20,
     headers: Optional[dict[str, str]] = None,
     **params,
-) -> Tuple[str, Optional[int]]:
+) -> Tuple[str, Optional[int], VlmStopReason]:
     img_io = BytesIO()
     image.save(img_io, "PNG")
     image_base64 = base64.b64encode(img_io.getvalue()).decode("utf-8")
@@ -61,7 +61,13 @@ def api_image_request(
     api_resp = OpenAiApiResponse.model_validate_json(r.text)
     generated_text = api_resp.choices[0].message.content.strip()
     num_tokens = api_resp.usage.total_tokens
-    return generated_text, num_tokens
+    stop_reason = (
+        VlmStopReason.LENGTH
+        if api_resp.choices[0].finish_reason == "length"
+        else VlmStopReason.END_OF_SEQUENCE
+    )
+
+    return generated_text, num_tokens, stop_reason
 
 
 def api_image_request_streaming(
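
Illustrative usage, not part of the patch above: a minimal sketch of how downstream code might branch on the new VlmPrediction.stop_reason field. The warn_if_truncated helper is hypothetical; the import paths and field names follow the diff.

    # Hypothetical helper, not included in the diff.
    from docling.datamodel.base_models import VlmPrediction, VlmStopReason

    def warn_if_truncated(prediction: VlmPrediction) -> None:
        # LENGTH means generation hit the token limit, so the output may be cut off.
        if prediction.stop_reason == VlmStopReason.LENGTH:
            print(f"VLM output may be truncated after {prediction.num_tokens} tokens")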