feat(vlm): Dynamic prompts (#1808)

* Unify temperature options for Vlm models

* Dynamic prompt support with example

* DCO Remediation Commit for Shkarupa Alex <shkarupa.alex@gmail.com>

I, Shkarupa Alex <shkarupa.alex@gmail.com>, hereby add my Signed-off-by to this commit: 34d446cb98
I, Shkarupa Alex <shkarupa.alex@gmail.com>, hereby add my Signed-off-by to this commit: 9c595d574f

Signed-off-by: Shkarupa Alex <shkarupa.alex@gmail.com>

* Replace Page with SegmentedPage

* Fix example HF repo link 

Signed-off-by: Christoph Auer <60343111+cau-git@users.noreply.github.com>

* Sign-off

Signed-off-by: Shkarupa Alex <shkarupa.alex@gmail.com>

* DCO Remediation Commit for Shkarupa Alex <shkarupa.alex@gmail.com>

I, Shkarupa Alex <shkarupa.alex@gmail.com>, hereby add my Signed-off-by to this commit: 1a162066dd

Signed-off-by: Shkarupa Alex <shkarupa.alex@gmail.com>

Signed-off-by: Shkarupa Alex <shkarupa.alex@gmail.com>

* Use lmstudio-community model

Signed-off-by: Christoph Auer <60343111+cau-git@users.noreply.github.com>

* Swap inference engine to LM Studio

Signed-off-by: Shkarupa Alex <shkarupa.alex@gmail.com>

---------

Signed-off-by: Shkarupa Alex <shkarupa.alex@gmail.com>
Signed-off-by: Christoph Auer <60343111+cau-git@users.noreply.github.com>
Co-authored-by: Christoph Auer <60343111+cau-git@users.noreply.github.com>
This commit is contained in:
Shkarupa Alex 2025-07-07 17:58:42 +03:00 committed by GitHub
parent edd4356aac
commit b8813eea80
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 96 additions and 15 deletions

View File

@ -1,6 +1,7 @@
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Union
from typing import Any, Callable, Dict, List, Literal, Optional, Union
from docling_core.types.doc.page import SegmentedPage
from pydantic import AnyUrl, BaseModel
from typing_extensions import deprecated
@ -9,9 +10,10 @@ from docling.datamodel.accelerator_options import AcceleratorDevice
class BaseVlmOptions(BaseModel):
kind: str
prompt: str
prompt: Union[str, Callable[[Optional[SegmentedPage]], str]]
scale: float = 2.0
max_size: Optional[int] = None
temperature: float = 0.0
class ResponseFormat(str, Enum):
@ -51,7 +53,6 @@ class InlineVlmOptions(BaseVlmOptions):
AcceleratorDevice.MPS,
]
temperature: float = 0.0
stop_strings: List[str] = []
extra_generation_config: Dict[str, Any] = {}

View File

@ -29,12 +29,9 @@ class ApiVlmModel(BasePageModel):
self.timeout = self.vlm_options.timeout
self.concurrency = self.vlm_options.concurrency
self.prompt_content = (
f"This is a page from a document.\n{self.vlm_options.prompt}"
)
self.params = {
**self.vlm_options.params,
"temperature": 0,
"temperature": self.vlm_options.temperature,
}
def __call__(
@ -56,9 +53,14 @@ class ApiVlmModel(BasePageModel):
if hi_res_image.mode != "RGB":
hi_res_image = hi_res_image.convert("RGB")
if callable(self.vlm_options.prompt):
prompt = self.vlm_options.prompt(page.parsed_page)
else:
prompt = self.vlm_options.prompt
page_tags = api_image_request(
image=hi_res_image,
prompt=self.prompt_content,
prompt=prompt,
url=self.vlm_options.url,
timeout=self.timeout,
headers=self.vlm_options.headers,

View File

@ -128,7 +128,11 @@ class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMix
)
# Define prompt structure
prompt = self.formulate_prompt()
if callable(self.vlm_options.prompt):
user_prompt = self.vlm_options.prompt(page.parsed_page)
else:
user_prompt = self.vlm_options.prompt
prompt = self.formulate_prompt(user_prompt)
inputs = self.processor(
text=prompt, images=[hi_res_image], return_tensors="pt"
@ -162,7 +166,7 @@ class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMix
yield page
def formulate_prompt(self) -> str:
def formulate_prompt(self, user_prompt: str) -> str:
"""Formulate a prompt for the VLM."""
if self.vlm_options.repo_id == "microsoft/Phi-4-multimodal-instruct":
@ -173,7 +177,7 @@ class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMix
assistant_prompt = "<|assistant|>"
prompt_suffix = "<|end|>"
prompt = f"{user_prompt}<|image_1|>{self.vlm_options.prompt}{prompt_suffix}{assistant_prompt}"
prompt = f"{user_prompt}<|image_1|>{user_prompt}{prompt_suffix}{assistant_prompt}"
_log.debug(f"prompt for {self.vlm_options.repo_id}: {prompt}")
return prompt
@ -187,7 +191,7 @@ class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMix
"text": "This is a page from a document.",
},
{"type": "image"},
{"type": "text", "text": self.vlm_options.prompt},
{"type": "text", "text": user_prompt},
],
}
]

View File

@ -56,8 +56,6 @@ class HuggingFaceMlxModel(BasePageModel, HuggingFaceModelDownloadMixin):
elif (artifacts_path / repo_cache_folder).exists():
artifacts_path = artifacts_path / repo_cache_folder
self.param_question = vlm_options.prompt
## Load the model
self.vlm_model, self.processor = load(artifacts_path)
self.config = load_config(artifacts_path)
@ -86,8 +84,12 @@ class HuggingFaceMlxModel(BasePageModel, HuggingFaceModelDownloadMixin):
if hi_res_image.mode != "RGB":
hi_res_image = hi_res_image.convert("RGB")
if callable(self.vlm_options.prompt):
user_prompt = self.vlm_options.prompt(page.parsed_page)
else:
user_prompt = self.vlm_options.prompt
prompt = self.apply_chat_template(
self.processor, self.config, self.param_question, num_images=1
self.processor, self.config, user_prompt, num_images=1
)
start_time = time.time()

View File

@ -117,6 +117,7 @@ class VlmPipeline(PaginatedPipeline):
page._backend = conv_res.input._backend.load_page(page.page_no) # type: ignore
if page._backend is not None and page._backend.is_valid():
page.size = page._backend.get_size()
page.parsed_page = page._backend.get_segmented_page()
return page

View File

@ -1,8 +1,10 @@
import logging
import os
from pathlib import Path
from typing import Optional
import requests
from docling_core.types.doc.page import SegmentedPage
from dotenv import load_dotenv
from docling.datamodel.base_models import InputFormat
@ -32,6 +34,69 @@ def lms_vlm_options(model: str, prompt: str, format: ResponseFormat):
return options
#### Using LM Studio with OlmOcr model
def lms_olmocr_vlm_options(model: str):
def _dynamic_olmocr_prompt(page: Optional[SegmentedPage]):
if page is None:
return (
"Below is the image of one page of a document. Just return the plain text"
" representation of this document as if you were reading it naturally.\n"
"Do not hallucinate.\n"
)
anchor = [
f"Page dimensions: {int(page.dimension.width)}x{int(page.dimension.height)}"
]
for text_cell in page.textline_cells:
if not text_cell.text.strip():
continue
bbox = text_cell.rect.to_bounding_box().to_bottom_left_origin(
page.dimension.height
)
anchor.append(f"[{int(bbox.l)}x{int(bbox.b)}] {text_cell.text}")
for image_cell in page.bitmap_resources:
bbox = image_cell.rect.to_bounding_box().to_bottom_left_origin(
page.dimension.height
)
anchor.append(
f"[Image {int(bbox.l)}x{int(bbox.b)} to {int(bbox.r)}x{int(bbox.t)}]"
)
if len(anchor) == 1:
anchor.append(
f"[Image 0x0 to {int(page.dimension.width)}x{int(page.dimension.height)}]"
)
# Original prompt uses cells sorting. We are skipping it in this demo.
base_text = "\n".join(anchor)
return (
f"Below is the image of one page of a document, as well as some raw textual"
f" content that was previously extracted for it. Just return the plain text"
f" representation of this document as if you were reading it naturally.\n"
f"Do not hallucinate.\n"
f"RAW_TEXT_START\n{base_text}\nRAW_TEXT_END"
)
options = ApiVlmOptions(
url="http://localhost:1234/v1/chat/completions",
params=dict(
model=model,
),
prompt=_dynamic_olmocr_prompt,
timeout=90,
scale=1.0,
max_size=1024, # from OlmOcr pipeline
response_format=ResponseFormat.MARKDOWN,
)
return options
#### Using Ollama
@ -123,6 +188,12 @@ def main():
# format=ResponseFormat.MARKDOWN,
# )
# Example using the OlmOcr (dynamic prompt) model with LM Studio:
# (uncomment the following lines)
# pipeline_options.vlm_options = lms_olmocr_vlm_options(
# model="hf.co/lmstudio-community/olmOCR-7B-0225-preview-GGUF",
# )
# Example using the Granite Vision model with Ollama:
# (uncomment the following lines)
# pipeline_options.vlm_options = ollama_vlm_options(