feat(vlm): Ability to preprocess VLM response (#1907)

* Add ability to preprocess VLM response

Signed-off-by: Shkarupa Alex <shkarupa.alex@gmail.com>

* Move response decoding into the VLM options (overriding requires subclassing). Per-page prompt formulation is also moved into the VLM options to keep the API consistent.

Signed-off-by: Shkarupa Alex <shkarupa.alex@gmail.com>
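
Concretely, a consumer can now subclass the options class and override the two hooks. A minimal sketch (the subclass name and the fence-stripping logic are illustrative only, not part of this change):

    from typing import Optional

    from docling_core.types.doc.page import SegmentedPage

    from docling.datamodel.pipeline_options_vlm_model import ApiVlmOptions


    class MarkdownApiVlmOptions(ApiVlmOptions):
        """Hypothetical options subclass using the two new hooks."""

        def build_prompt(self, page: Optional[SegmentedPage]) -> str:
            # Per-page prompt formulation: the parsed page (when available)
            # can inform the wording; otherwise fall back to the static prompt.
            if page is None:
                return self.prompt
            return self.prompt + " Preserve the reading order of the page."

        def decode_response(self, text: str) -> str:
            # Strip a ``` fence the model may wrap around its output.
            text = text.strip()
            if text.startswith("```"):
                text = text.split("\n", 1)[-1]
            if text.endswith("```"):
                text = text[:-3].rstrip()
            return text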

---------

Signed-off-by: Shkarupa Alex <shkarupa.alex@gmail.com>
Author: Shkarupa Alex
Date:   2025-08-12 16:20:24 +03:00 (committed by GitHub)
Parent: ccfee05847
Commit: 5f050f94e1
5 changed files with 60 additions and 54 deletions


@@ -1,5 +1,5 @@
 from enum import Enum
-from typing import Any, Callable, Dict, List, Literal, Optional, Union
+from typing import Any, Dict, List, Literal, Optional
 
 from docling_core.types.doc.page import SegmentedPage
 from pydantic import AnyUrl, BaseModel
@@ -10,11 +10,17 @@ from docling.datamodel.accelerator_options import AcceleratorDevice
 class BaseVlmOptions(BaseModel):
     kind: str
-    prompt: Union[str, Callable[[Optional[SegmentedPage]], str]]
+    prompt: str
     scale: float = 2.0
     max_size: Optional[int] = None
     temperature: float = 0.0
 
+    def build_prompt(self, page: Optional[SegmentedPage]) -> str:
+        return self.prompt
+
+    def decode_response(self, text: str) -> str:
+        return text
+
 
 class ResponseFormat(str, Enum):
     DOCTAGS = "doctags"


@@ -53,11 +53,7 @@ class ApiVlmModel(BasePageModel):
                 if hi_res_image.mode != "RGB":
                     hi_res_image = hi_res_image.convert("RGB")
 
-                if callable(self.vlm_options.prompt):
-                    prompt = self.vlm_options.prompt(page.parsed_page)
-                else:
-                    prompt = self.vlm_options.prompt
-
+                prompt = self.vlm_options.build_prompt(page.parsed_page)
                 page_tags = api_image_request(
                     image=hi_res_image,
                     prompt=prompt,
@@ -67,6 +63,7 @@ class ApiVlmModel(BasePageModel):
                     **self.params,
                 )
 
+                page_tags = self.vlm_options.decode_response(page_tags)
                 page.predictions.vlm_response = VlmPrediction(text=page_tags)
 
         return page


@@ -135,10 +135,7 @@ class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMixin):
                     )
 
                     # Define prompt structure
-                    if callable(self.vlm_options.prompt):
-                        user_prompt = self.vlm_options.prompt(page.parsed_page)
-                    else:
-                        user_prompt = self.vlm_options.prompt
+                    user_prompt = self.vlm_options.build_prompt(page.parsed_page)
                     prompt = self.formulate_prompt(user_prompt)
 
                     inputs = self.processor(
@@ -166,6 +163,7 @@ class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMixin):
                     _log.debug(
                         f"Generated {num_tokens} tokens in time {generation_time:.2f} seconds."
                     )
+                    generated_texts = self.vlm_options.decode_response(generated_texts)
                     page.predictions.vlm_response = VlmPrediction(
                         text=generated_texts,
                         generation_time=generation_time,


@@ -84,10 +84,7 @@ class HuggingFaceMlxModel(BasePageModel, HuggingFaceModelDownloadMixin):
                 if hi_res_image.mode != "RGB":
                     hi_res_image = hi_res_image.convert("RGB")
 
-                if callable(self.vlm_options.prompt):
-                    user_prompt = self.vlm_options.prompt(page.parsed_page)
-                else:
-                    user_prompt = self.vlm_options.prompt
+                user_prompt = self.vlm_options.build_prompt(page.parsed_page)
                 prompt = self.apply_chat_template(
                     self.processor, self.config, user_prompt, num_images=1
                 )
@@ -142,6 +139,7 @@ class HuggingFaceMlxModel(BasePageModel, HuggingFaceModelDownloadMixin):
                 _log.debug(
                     f"{generation_time:.2f} seconds for {len(tokens)} tokens ({len(tokens) / generation_time} tokens/sec)."
                 )
+                page_tags = self.vlm_options.decode_response(page_tags)
                 page.predictions.vlm_response = VlmPrediction(
                     text=page_tags,
                     generation_time=generation_time,
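
For reference, a usage sketch wiring such a subclass into the VLM pipeline (the endpoint URL, model name, and input file are placeholders, and the wiring follows docling's API-model example; it is an assumption, not part of this diff):

    from docling.datamodel.base_models import InputFormat
    from docling.datamodel.pipeline_options import VlmPipelineOptions
    from docling.datamodel.pipeline_options_vlm_model import ResponseFormat
    from docling.document_converter import DocumentConverter, PdfFormatOption
    from docling.pipeline.vlm_pipeline import VlmPipeline

    pipeline_options = VlmPipelineOptions(enable_remote_services=True)
    pipeline_options.vlm_options = MarkdownApiVlmOptions(  # subclass from the sketch above
        url="http://localhost:11434/v1/chat/completions",  # placeholder endpoint
        params=dict(model="granite3.2-vision:2b"),  # placeholder model
        prompt="OCR the full page to markdown.",
        timeout=90,
        response_format=ResponseFormat.MARKDOWN,
    )

    converter = DocumentConverter(
        format_options={
            InputFormat.PDF: PdfFormatOption(
                pipeline_cls=VlmPipeline,
                pipeline_options=pipeline_options,
            )
        }
    )
    # build_prompt() and decode_response() are now applied once per page.
    doc = converter.convert("report.pdf").document  # placeholder input file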