From 5e1e82ab3b6f0983aa94a05405b4722af8252eb9 Mon Sep 17 00:00:00 2001
From: Shkarupa Alex
Date: Tue, 8 Jul 2025 09:29:22 +0300
Subject: [PATCH] Add ability to preprocess VLM response

Signed-off-by: Shkarupa Alex
---
 docling/datamodel/pipeline_options_vlm_model.py       |  1 +
 docling/models/api_vlm_model.py                       |  2 ++
 .../models/vlm_models_inline/hf_transformers_model.py |  4 ++++
 docling/models/vlm_models_inline/mlx_model.py         |  2 ++
 docs/examples/vlm_pipeline_api_model.py               | 10 ++++++++++
 5 files changed, 19 insertions(+)

diff --git a/docling/datamodel/pipeline_options_vlm_model.py b/docling/datamodel/pipeline_options_vlm_model.py
index bcea2493..5e6df3d7 100644
--- a/docling/datamodel/pipeline_options_vlm_model.py
+++ b/docling/datamodel/pipeline_options_vlm_model.py
@@ -14,6 +14,7 @@ class BaseVlmOptions(BaseModel):
     scale: float = 2.0
     max_size: Optional[int] = None
     temperature: float = 0.0
+    decode_response: Optional[Callable[[str], str]] = None
 
 
 class ResponseFormat(str, Enum):
diff --git a/docling/models/api_vlm_model.py b/docling/models/api_vlm_model.py
index 164ac285..64c8a034 100644
--- a/docling/models/api_vlm_model.py
+++ b/docling/models/api_vlm_model.py
@@ -67,6 +67,8 @@ class ApiVlmModel(BasePageModel):
                         **self.params,
                     )
 
+                    if self.vlm_options.decode_response:
+                        page_tags = self.vlm_options.decode_response(page_tags)
                     page.predictions.vlm_response = VlmPrediction(text=page_tags)
 
                 return page
diff --git a/docling/models/vlm_models_inline/hf_transformers_model.py b/docling/models/vlm_models_inline/hf_transformers_model.py
index d84925dd..498fc860 100644
--- a/docling/models/vlm_models_inline/hf_transformers_model.py
+++ b/docling/models/vlm_models_inline/hf_transformers_model.py
@@ -166,6 +166,10 @@ class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMix
                     _log.debug(
                         f"Generated {num_tokens} tokens in time {generation_time:.2f} seconds."
                     )
+                    if self.vlm_options.decode_response:
+                        generated_texts = self.vlm_options.decode_response(
+                            generated_texts
+                        )
                     page.predictions.vlm_response = VlmPrediction(
                         text=generated_texts,
                         generation_time=generation_time,
diff --git a/docling/models/vlm_models_inline/mlx_model.py b/docling/models/vlm_models_inline/mlx_model.py
index 647ce531..d8fd03cc 100644
--- a/docling/models/vlm_models_inline/mlx_model.py
+++ b/docling/models/vlm_models_inline/mlx_model.py
@@ -142,6 +142,8 @@ class HuggingFaceMlxModel(BasePageModel, HuggingFaceModelDownloadMixin):
                     _log.debug(
                         f"{generation_time:.2f} seconds for {len(tokens)} tokens ({len(tokens) / generation_time} tokens/sec)."
                     )
+                    if self.vlm_options.decode_response:
+                        page_tags = self.vlm_options.decode_response(page_tags)
                     page.predictions.vlm_response = VlmPrediction(
                         text=page_tags,
                         generation_time=generation_time,
diff --git a/docs/examples/vlm_pipeline_api_model.py b/docs/examples/vlm_pipeline_api_model.py
index a809b926..56c47388 100644
--- a/docs/examples/vlm_pipeline_api_model.py
+++ b/docs/examples/vlm_pipeline_api_model.py
@@ -1,3 +1,4 @@
+import json
 import logging
 import os
 from pathlib import Path
@@ -83,6 +84,14 @@ def lms_olmocr_vlm_options(model: str):
             f"RAW_TEXT_START\n{base_text}\nRAW_TEXT_END"
         )
 
+    def _decode_olmocr_response(generated_text: str) -> str:
+        try:
+            generated_json = json.loads(generated_text)
+        except json.decoder.JSONDecodeError:
+            return ""
+
+        return generated_json["natural_text"]
+
     options = ApiVlmOptions(
         url="http://localhost:1234/v1/chat/completions",
         params=dict(
@@ -92,6 +101,7 @@ def lms_olmocr_vlm_options(model: str):
         timeout=90,
         scale=1.0,
         max_size=1024,  # from OlmOcr pipeline
+        decode_response=_decode_olmocr_response,
         response_format=ResponseFormat.MARKDOWN,
     )
     return options
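---

Usage note (not part of the diff): because `decode_response` is declared on `BaseVlmOptions`, the same plain `str -> str` callable is honored by `ApiVlmModel`, `HuggingFaceTransformersVlmModel`, and `HuggingFaceMlxModel`. Below is a minimal sketch of wiring a custom decoder through `ApiVlmOptions`. The fence-stripping helper, endpoint, and model name are illustrative assumptions rather than part of this patch, and the import assumes `ApiVlmOptions` lives alongside `ResponseFormat` in `pipeline_options_vlm_model.py`, the file touched by the first hunk.

    from docling.datamodel.pipeline_options_vlm_model import (
        ApiVlmOptions,
        ResponseFormat,
    )


    def strip_markdown_fences(generated_text: str) -> str:
        # Hypothetical decoder, not from this patch: some models wrap their
        # answer in a ```markdown fence; unwrap it, else pass the text through.
        text = generated_text.strip()
        if text.startswith("```"):
            text = text.split("\n", 1)[-1]   # drop the opening ``` line
            text = text.rsplit("```", 1)[0]  # drop the closing fence
        return text.strip()


    options = ApiVlmOptions(
        url="http://localhost:1234/v1/chat/completions",  # endpoint as in the docs example
        params=dict(model="my-vision-model"),  # hypothetical model name
        prompt="Convert this page to markdown.",
        timeout=90,
        decode_response=strip_markdown_fences,  # applied to the raw response before VlmPrediction is built
        response_format=ResponseFormat.MARKDOWN,
    )

This mirrors what the docs example does for OlmOcr, where `_decode_olmocr_response` extracts `natural_text` from the model's JSON envelope.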