Mirror of https://github.com/DS4SD/docling.git (synced 2025-07-24)
feat(vlm): Dynamic prompts (#1808)
* Unify temperature options for Vlm models
* Dynamic prompt support with example
* DCO Remediation Commit for Shkarupa Alex <shkarupa.alex@gmail.com>
  I, Shkarupa Alex <shkarupa.alex@gmail.com>, hereby add my Signed-off-by to this commit: 34d446cb98
  I, Shkarupa Alex <shkarupa.alex@gmail.com>, hereby add my Signed-off-by to this commit: 9c595d574f
  Signed-off-by: Shkarupa Alex <shkarupa.alex@gmail.com>
* Replace Page with SegmentedPage
* Fix example HF repo link
  Signed-off-by: Christoph Auer <60343111+cau-git@users.noreply.github.com>
* Sign-off
  Signed-off-by: Shkarupa Alex <shkarupa.alex@gmail.com>
* DCO Remediation Commit for Shkarupa Alex <shkarupa.alex@gmail.com>
  I, Shkarupa Alex <shkarupa.alex@gmail.com>, hereby add my Signed-off-by to this commit: 1a162066dd
  Signed-off-by: Shkarupa Alex <shkarupa.alex@gmail.com>
* Use lmstudio-community model
  Signed-off-by: Christoph Auer <60343111+cau-git@users.noreply.github.com>
* Swap inference engine to LM Studio
  Signed-off-by: Shkarupa Alex <shkarupa.alex@gmail.com>

---------

Signed-off-by: Shkarupa Alex <shkarupa.alex@gmail.com>
Signed-off-by: Christoph Auer <60343111+cau-git@users.noreply.github.com>
Co-authored-by: Christoph Auer <60343111+cau-git@users.noreply.github.com>
Commit b8813eea80 (parent edd4356aac)
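The core of the change is that BaseVlmOptions.prompt now accepts either a plain string or a callable that receives the page's SegmentedPage (or None) and returns the prompt text; the VLM models resolve it per page via page.parsed_page. Below is a minimal sketch of how a caller might supply such a dynamic prompt. The import path for ApiVlmOptions/ResponseFormat, the helper name, and the model name are assumptions for illustration, not taken from this diff; only the option fields shown in the diff (url, params, prompt, timeout, response_format) and the page.dimension attributes are used.

from typing import Optional

from docling_core.types.doc.page import SegmentedPage

# Import path assumed; adjust to wherever ApiVlmOptions/ResponseFormat live in your version.
from docling.datamodel.pipeline_options_vlm_model import ApiVlmOptions, ResponseFormat


def page_aware_prompt(page: Optional[SegmentedPage]) -> str:
    # The pipeline stores the parsed page on page.parsed_page and, when the
    # prompt option is callable, calls it with that SegmentedPage (or None).
    if page is None:
        return "Convert this page to markdown."
    width = int(page.dimension.width)
    height = int(page.dimension.height)
    return f"Convert this {width}x{height} page to markdown."


# prompt may still be a plain string; passing a callable makes it dynamic per page.
options = ApiVlmOptions(
    url="http://localhost:1234/v1/chat/completions",  # any OpenAI-compatible endpoint
    params=dict(model="my-local-vlm"),  # hypothetical model name
    prompt=page_aware_prompt,
    timeout=90,
    response_format=ResponseFormat.MARKDOWN,
)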
@@ -1,6 +1,7 @@
 from enum import Enum
-from typing import Any, Dict, List, Literal, Optional, Union
+from typing import Any, Callable, Dict, List, Literal, Optional, Union

+from docling_core.types.doc.page import SegmentedPage
 from pydantic import AnyUrl, BaseModel
 from typing_extensions import deprecated
@@ -9,9 +10,10 @@ from docling.datamodel.accelerator_options import AcceleratorDevice

 class BaseVlmOptions(BaseModel):
     kind: str
-    prompt: str
+    prompt: Union[str, Callable[[Optional[SegmentedPage]], str]]
     scale: float = 2.0
     max_size: Optional[int] = None
+    temperature: float = 0.0


 class ResponseFormat(str, Enum):
@@ -51,7 +53,6 @@ class InlineVlmOptions(BaseVlmOptions):
         AcceleratorDevice.MPS,
     ]

-    temperature: float = 0.0
     stop_strings: List[str] = []
     extra_generation_config: Dict[str, Any] = {}

@@ -29,12 +29,9 @@ class ApiVlmModel(BasePageModel):

         self.timeout = self.vlm_options.timeout
         self.concurrency = self.vlm_options.concurrency
-        self.prompt_content = (
-            f"This is a page from a document.\n{self.vlm_options.prompt}"
-        )
         self.params = {
             **self.vlm_options.params,
-            "temperature": 0,
+            "temperature": self.vlm_options.temperature,
         }

     def __call__(
@@ -56,9 +53,14 @@ class ApiVlmModel(BasePageModel):
             if hi_res_image.mode != "RGB":
                 hi_res_image = hi_res_image.convert("RGB")

+            if callable(self.vlm_options.prompt):
+                prompt = self.vlm_options.prompt(page.parsed_page)
+            else:
+                prompt = self.vlm_options.prompt
+
             page_tags = api_image_request(
                 image=hi_res_image,
-                prompt=self.prompt_content,
+                prompt=prompt,
                 url=self.vlm_options.url,
                 timeout=self.timeout,
                 headers=self.vlm_options.headers,
@@ -128,7 +128,11 @@ class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMixin):
             )

             # Define prompt structure
-            prompt = self.formulate_prompt()
+            if callable(self.vlm_options.prompt):
+                user_prompt = self.vlm_options.prompt(page.parsed_page)
+            else:
+                user_prompt = self.vlm_options.prompt
+            prompt = self.formulate_prompt(user_prompt)

             inputs = self.processor(
                 text=prompt, images=[hi_res_image], return_tensors="pt"
@@ -162,7 +166,7 @@ class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMixin):

             yield page

-    def formulate_prompt(self) -> str:
+    def formulate_prompt(self, user_prompt: str) -> str:
         """Formulate a prompt for the VLM."""

         if self.vlm_options.repo_id == "microsoft/Phi-4-multimodal-instruct":
@@ -173,7 +177,7 @@ class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMixin):
             assistant_prompt = "<|assistant|>"
             prompt_suffix = "<|end|>"

-            prompt = f"{user_prompt}<|image_1|>{self.vlm_options.prompt}{prompt_suffix}{assistant_prompt}"
+            prompt = f"{user_prompt}<|image_1|>{user_prompt}{prompt_suffix}{assistant_prompt}"
             _log.debug(f"prompt for {self.vlm_options.repo_id}: {prompt}")

             return prompt
@@ -187,7 +191,7 @@ class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMixin):
                         "text": "This is a page from a document.",
                     },
                     {"type": "image"},
-                    {"type": "text", "text": self.vlm_options.prompt},
+                    {"type": "text", "text": user_prompt},
                 ],
             }
         ]
@@ -56,8 +56,6 @@ class HuggingFaceMlxModel(BasePageModel, HuggingFaceModelDownloadMixin):
             elif (artifacts_path / repo_cache_folder).exists():
                 artifacts_path = artifacts_path / repo_cache_folder

-            self.param_question = vlm_options.prompt
-
             ## Load the model
             self.vlm_model, self.processor = load(artifacts_path)
             self.config = load_config(artifacts_path)
@@ -86,8 +84,12 @@ class HuggingFaceMlxModel(BasePageModel, HuggingFaceModelDownloadMixin):
             if hi_res_image.mode != "RGB":
                 hi_res_image = hi_res_image.convert("RGB")

+            if callable(self.vlm_options.prompt):
+                user_prompt = self.vlm_options.prompt(page.parsed_page)
+            else:
+                user_prompt = self.vlm_options.prompt
             prompt = self.apply_chat_template(
-                self.processor, self.config, self.param_question, num_images=1
+                self.processor, self.config, user_prompt, num_images=1
             )

             start_time = time.time()
@@ -117,6 +117,7 @@ class VlmPipeline(PaginatedPipeline):
         page._backend = conv_res.input._backend.load_page(page.page_no)  # type: ignore
         if page._backend is not None and page._backend.is_valid():
             page.size = page._backend.get_size()
+            page.parsed_page = page._backend.get_segmented_page()

         return page

docs/examples/vlm_pipeline_api_model.py (vendored, 71 lines changed)
@@ -1,8 +1,10 @@
 import logging
 import os
 from pathlib import Path
+from typing import Optional

 import requests
+from docling_core.types.doc.page import SegmentedPage
 from dotenv import load_dotenv

 from docling.datamodel.base_models import InputFormat
@@ -32,6 +34,69 @@ def lms_vlm_options(model: str, prompt: str, format: ResponseFormat):
     return options


+#### Using LM Studio with OlmOcr model
+
+
+def lms_olmocr_vlm_options(model: str):
+    def _dynamic_olmocr_prompt(page: Optional[SegmentedPage]):
+        if page is None:
+            return (
+                "Below is the image of one page of a document. Just return the plain text"
+                " representation of this document as if you were reading it naturally.\n"
+                "Do not hallucinate.\n"
+            )
+
+        anchor = [
+            f"Page dimensions: {int(page.dimension.width)}x{int(page.dimension.height)}"
+        ]
+
+        for text_cell in page.textline_cells:
+            if not text_cell.text.strip():
+                continue
+            bbox = text_cell.rect.to_bounding_box().to_bottom_left_origin(
+                page.dimension.height
+            )
+            anchor.append(f"[{int(bbox.l)}x{int(bbox.b)}] {text_cell.text}")
+
+        for image_cell in page.bitmap_resources:
+            bbox = image_cell.rect.to_bounding_box().to_bottom_left_origin(
+                page.dimension.height
+            )
+            anchor.append(
+                f"[Image {int(bbox.l)}x{int(bbox.b)} to {int(bbox.r)}x{int(bbox.t)}]"
+            )
+
+        if len(anchor) == 1:
+            anchor.append(
+                f"[Image 0x0 to {int(page.dimension.width)}x{int(page.dimension.height)}]"
+            )
+
+        # Original prompt uses cells sorting. We are skipping it in this demo.
+
+        base_text = "\n".join(anchor)
+
+        return (
+            f"Below is the image of one page of a document, as well as some raw textual"
+            f" content that was previously extracted for it. Just return the plain text"
+            f" representation of this document as if you were reading it naturally.\n"
+            f"Do not hallucinate.\n"
+            f"RAW_TEXT_START\n{base_text}\nRAW_TEXT_END"
+        )
+
+    options = ApiVlmOptions(
+        url="http://localhost:1234/v1/chat/completions",
+        params=dict(
+            model=model,
+        ),
+        prompt=_dynamic_olmocr_prompt,
+        timeout=90,
+        scale=1.0,
+        max_size=1024,  # from OlmOcr pipeline
+        response_format=ResponseFormat.MARKDOWN,
+    )
+    return options
+
+
 #### Using Ollama


@@ -123,6 +188,12 @@ def main():
     # format=ResponseFormat.MARKDOWN,
     # )

+    # Example using the OlmOcr (dynamic prompt) model with LM Studio:
+    # (uncomment the following lines)
+    # pipeline_options.vlm_options = lms_olmocr_vlm_options(
+    #     model="hf.co/lmstudio-community/olmOCR-7B-0225-preview-GGUF",
+    # )
+
     # Example using the Granite Vision model with Ollama:
     # (uncomment the following lines)
     # pipeline_options.vlm_options = ollama_vlm_options(