feat(vlm): Ability to preprocess VLM response (#1907)

* Add ability to preprocess VLM response

Signed-off-by: Shkarupa Alex <shkarupa.alex@gmail.com>

* Move response decoding to the VLM options class (requires inheritance to override). Per-page prompt formulation is also moved to the VLM options class to keep the API consistent.

Signed-off-by: Shkarupa Alex <shkarupa.alex@gmail.com>

---------

Signed-off-by: Shkarupa Alex <shkarupa.alex@gmail.com>
Authored by Shkarupa Alex on 2025-08-12 16:20:24 +03:00, committed by GitHub
commit 5f050f94e1, parent ccfee05847
5 changed files with 60 additions and 54 deletions
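
In practice this means response post-processing is now a small subclass away: override decode_response on the options object instead of patching the model classes. A minimal sketch, assuming the import location shown in the diffs below; the subclass name and the fence-stripping logic are illustrative, not part of this PR:

    # Assumed import location; ApiVlmOptions sits next to the BaseVlmOptions shown below.
    from docling.datamodel.pipeline_options_vlm_model import ApiVlmOptions


    class FenceStrippingVlmOptions(ApiVlmOptions):
        """Illustrative subclass: strip a Markdown code fence some models wrap around output."""

        def decode_response(self, text: str) -> str:
            text = text.strip()
            if text.startswith("```"):
                # Drop the opening fence line (e.g. ```markdown).
                text = text.split("\n", 1)[-1]
            if text.endswith("```"):
                text = text[: -len("```")]
            return text.strip()

Every VLM backend touched below (API, Transformers, MLX) routes the raw generation through this hook before storing the VlmPrediction, so the override applies regardless of how the model is run.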

View File

@@ -1,5 +1,5 @@
 from enum import Enum
-from typing import Any, Callable, Dict, List, Literal, Optional, Union
+from typing import Any, Dict, List, Literal, Optional

 from docling_core.types.doc.page import SegmentedPage
 from pydantic import AnyUrl, BaseModel
@@ -10,11 +10,17 @@ from docling.datamodel.accelerator_options import AcceleratorDevice
 class BaseVlmOptions(BaseModel):
     kind: str
-    prompt: Union[str, Callable[[Optional[SegmentedPage]], str]]
+    prompt: str
     scale: float = 2.0
     max_size: Optional[int] = None
     temperature: float = 0.0

+    def build_prompt(self, page: Optional[SegmentedPage]) -> str:
+        return self.prompt
+
+    def decode_response(self, text: str) -> str:
+        return text
+

 class ResponseFormat(str, Enum):
     DOCTAGS = "doctags"
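
Since prompt is now a plain string, per-page prompts are expressed by overriding build_prompt rather than passing a callable. A minimal sketch under the same import assumption as above; the subclass name and the page-dimension hint are illustrative:

    from typing import Optional

    from docling_core.types.doc.page import SegmentedPage

    # Assumed import location; ApiVlmOptions extends the BaseVlmOptions defined above.
    from docling.datamodel.pipeline_options_vlm_model import ApiVlmOptions


    class PageAwarePromptOptions(ApiVlmOptions):
        """Illustrative subclass: append the page dimensions to the static prompt."""

        def build_prompt(self, page: Optional[SegmentedPage]) -> str:
            if page is None:
                # No parsed page available: fall back to the static prompt.
                return self.prompt
            width = int(page.dimension.width)
            height = int(page.dimension.height)
            return f"{self.prompt}\nPage dimensions: {width}x{height}"

Callers that passed a plain string prompt before are unaffected: the default build_prompt simply returns self.prompt, and the default decode_response is a pass-through.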

View File

@@ -53,11 +53,7 @@ class ApiVlmModel(BasePageModel):
                 if hi_res_image.mode != "RGB":
                     hi_res_image = hi_res_image.convert("RGB")

-                if callable(self.vlm_options.prompt):
-                    prompt = self.vlm_options.prompt(page.parsed_page)
-                else:
-                    prompt = self.vlm_options.prompt
+                prompt = self.vlm_options.build_prompt(page.parsed_page)

                 page_tags = api_image_request(
                     image=hi_res_image,
                     prompt=prompt,
@@ -67,6 +63,7 @@ class ApiVlmModel(BasePageModel):
                     **self.params,
                 )

+                page_tags = self.vlm_options.decode_response(page_tags)
                 page.predictions.vlm_response = VlmPrediction(text=page_tags)

                 return page

View File

@@ -135,10 +135,7 @@ class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMix
                     )

                     # Define prompt structure
-                    if callable(self.vlm_options.prompt):
-                        user_prompt = self.vlm_options.prompt(page.parsed_page)
-                    else:
-                        user_prompt = self.vlm_options.prompt
+                    user_prompt = self.vlm_options.build_prompt(page.parsed_page)
                     prompt = self.formulate_prompt(user_prompt)

                     inputs = self.processor(
@@ -166,6 +163,7 @@ class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMix
                     _log.debug(
                         f"Generated {num_tokens} tokens in time {generation_time:.2f} seconds."
                     )
+                    generated_texts = self.vlm_options.decode_response(generated_texts)
                     page.predictions.vlm_response = VlmPrediction(
                         text=generated_texts,
                         generation_time=generation_time,

View File

@@ -84,10 +84,7 @@ class HuggingFaceMlxModel(BasePageModel, HuggingFaceModelDownloadMixin):
                     if hi_res_image.mode != "RGB":
                         hi_res_image = hi_res_image.convert("RGB")

-                    if callable(self.vlm_options.prompt):
-                        user_prompt = self.vlm_options.prompt(page.parsed_page)
-                    else:
-                        user_prompt = self.vlm_options.prompt
+                    user_prompt = self.vlm_options.build_prompt(page.parsed_page)
                     prompt = self.apply_chat_template(
                         self.processor, self.config, user_prompt, num_images=1
                     )
@@ -142,6 +139,7 @@ class HuggingFaceMlxModel(BasePageModel, HuggingFaceModelDownloadMixin):
                     _log.debug(
                         f"{generation_time:.2f} seconds for {len(tokens)} tokens ({len(tokens) / generation_time} tokens/sec)."
                     )
+                    page_tags = self.vlm_options.decode_response(page_tags)
                     page.predictions.vlm_response = VlmPrediction(
                         text=page_tags,
                         generation_time=generation_time,

View File

@@ -1,3 +1,4 @@
+import json
 import logging
 import os
 from pathlib import Path
@@ -38,57 +39,63 @@ def lms_vlm_options(model: str, prompt: str, format: ResponseFormat):
 def lms_olmocr_vlm_options(model: str):
-    def _dynamic_olmocr_prompt(page: Optional[SegmentedPage]):
-        if page is None:
-            return (
-                "Below is the image of one page of a document. Just return the plain text"
-                " representation of this document as if you were reading it naturally.\n"
-                "Do not hallucinate.\n"
-            )
+    class OlmocrVlmOptions(ApiVlmOptions):
+        def build_prompt(self, page: Optional[SegmentedPage]) -> str:
+            if page is None:
+                return self.prompt.replace("#RAW_TEXT#", "")

             anchor = [
                 f"Page dimensions: {int(page.dimension.width)}x{int(page.dimension.height)}"
             ]

             for text_cell in page.textline_cells:
                 if not text_cell.text.strip():
                     continue
                 bbox = text_cell.rect.to_bounding_box().to_bottom_left_origin(
                     page.dimension.height
                 )
                 anchor.append(f"[{int(bbox.l)}x{int(bbox.b)}] {text_cell.text}")

             for image_cell in page.bitmap_resources:
                 bbox = image_cell.rect.to_bounding_box().to_bottom_left_origin(
                     page.dimension.height
                 )
                 anchor.append(
                     f"[Image {int(bbox.l)}x{int(bbox.b)} to {int(bbox.r)}x{int(bbox.t)}]"
                 )

             if len(anchor) == 1:
                 anchor.append(
                     f"[Image 0x0 to {int(page.dimension.width)}x{int(page.dimension.height)}]"
                 )

-        # Original prompt uses cells sorting. We are skipping it in this demo.
-        base_text = "\n".join(anchor)
-        return (
-            f"Below is the image of one page of a document, as well as some raw textual"
-            f" content that was previously extracted for it. Just return the plain text"
-            f" representation of this document as if you were reading it naturally.\n"
-            f"Do not hallucinate.\n"
-            f"RAW_TEXT_START\n{base_text}\nRAW_TEXT_END"
-        )
-
-    options = ApiVlmOptions(
+            # Original prompt uses cells sorting. We are skipping it for simplicity.
+            raw_text = "\n".join(anchor)
+            return self.prompt.replace("#RAW_TEXT#", raw_text)
+
+        def decode_response(self, text: str) -> str:
+            # OlmOcr trained to generate json response with language, rotation and other info
+            try:
+                generated_json = json.loads(text)
+            except json.decoder.JSONDecodeError:
+                return ""
+            return generated_json["natural_text"]
+
+    options = OlmocrVlmOptions(
         url="http://localhost:1234/v1/chat/completions",
         params=dict(
             model=model,
         ),
-        prompt=_dynamic_olmocr_prompt,
+        prompt=(
+            "Below is the image of one page of a document, as well as some raw textual"
+            " content that was previously extracted for it. Just return the plain text"
+            " representation of this document as if you were reading it naturally.\n"
+            "Do not hallucinate.\n"
+            "RAW_TEXT_START\n#RAW_TEXT#\nRAW_TEXT_END"
+        ),
         timeout=90,
         scale=1.0,
         max_size=1024,  # from OlmOcr pipeline