Mirror of https://github.com/DS4SD/docling.git
feat(vlm): Dynamic prompts (#1808)
* Unify temperature options for VLM models
* Add dynamic prompt support, with an example
* Replace Page with SegmentedPage
* Fix example HF repo link
* Use lmstudio-community model
* Swap inference engine to LM Studio

Signed-off-by: Shkarupa Alex <shkarupa.alex@gmail.com>
Signed-off-by: Christoph Auer <60343111+cau-git@users.noreply.github.com>
Co-authored-by: Christoph Auer <60343111+cau-git@users.noreply.github.com>
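With this change, `prompt` on any VLM options object accepts either a plain string or a callable that receives the page's `Optional[SegmentedPage]` and returns the prompt to use for that page. A minimal sketch of configuring a dynamic prompt against an OpenAI-compatible endpoint such as LM Studio; the `ApiVlmOptions` import path and field names are assumed to match the installed docling version, and the endpoint URL, model id, and `page.dimension` attributes are illustrative assumptions, not taken from this diff:

    from typing import Optional

    from docling_core.types.doc.page import SegmentedPage

    from docling.datamodel.pipeline_options_vlm_model import ApiVlmOptions, ResponseFormat


    def dynamic_prompt(page: Optional[SegmentedPage]) -> str:
        # The pipeline passes page.parsed_page, which may be None.
        if page is None:
            return "Convert this page to markdown."
        # Tailor the instruction to the parsed page, here just by mentioning its size
        # (attribute names assumed from docling-core's SegmentedPage).
        return (
            f"Convert this {page.dimension.width:.0f}x{page.dimension.height:.0f} "
            "page to markdown."
        )


    options = ApiVlmOptions(
        url="http://localhost:1234/v1/chat/completions",  # placeholder LM Studio endpoint
        params=dict(model="lmstudio-community/..."),  # placeholder model id
        prompt=dynamic_prompt,  # a callable now works anywhere a string did
        timeout=90,
        temperature=0.0,  # defined once on BaseVlmOptions after this change
        response_format=ResponseFormat.MARKDOWN,
    )

Because `ApiVlmModel` now forwards `self.vlm_options.temperature` into the request params (see the diff below), the same `temperature` knob applies to both inline and API models.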
@@ -1,6 +1,7 @@
 from enum import Enum
-from typing import Any, Dict, List, Literal, Optional, Union
+from typing import Any, Callable, Dict, List, Literal, Optional, Union
 
+from docling_core.types.doc.page import SegmentedPage
 from pydantic import AnyUrl, BaseModel
 from typing_extensions import deprecated
 
@@ -9,9 +10,10 @@ from docling.datamodel.accelerator_options import AcceleratorDevice
 
 class BaseVlmOptions(BaseModel):
     kind: str
-    prompt: str
+    prompt: Union[str, Callable[[Optional[SegmentedPage]], str]]
     scale: float = 2.0
     max_size: Optional[int] = None
+    temperature: float = 0.0
 
 
 class ResponseFormat(str, Enum):
@@ -51,7 +53,6 @@ class InlineVlmOptions(BaseVlmOptions):
         AcceleratorDevice.MPS,
     ]
 
-    temperature: float = 0.0
     stop_strings: List[str] = []
     extra_generation_config: Dict[str, Any] = {}
 

@@ -29,12 +29,9 @@ class ApiVlmModel(BasePageModel):
 
         self.timeout = self.vlm_options.timeout
         self.concurrency = self.vlm_options.concurrency
-        self.prompt_content = (
-            f"This is a page from a document.\n{self.vlm_options.prompt}"
-        )
         self.params = {
             **self.vlm_options.params,
-            "temperature": 0,
+            "temperature": self.vlm_options.temperature,
         }
 
     def __call__(
@@ -56,9 +53,14 @@ class ApiVlmModel(BasePageModel):
                 if hi_res_image.mode != "RGB":
                     hi_res_image = hi_res_image.convert("RGB")
 
+                if callable(self.vlm_options.prompt):
+                    prompt = self.vlm_options.prompt(page.parsed_page)
+                else:
+                    prompt = self.vlm_options.prompt
+
                 page_tags = api_image_request(
                     image=hi_res_image,
-                    prompt=self.prompt_content,
+                    prompt=prompt,
                     url=self.vlm_options.url,
                     timeout=self.timeout,
                     headers=self.vlm_options.headers,

@@ -128,7 +128,11 @@ class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMix
                     )
 
                     # Define prompt structure
-                    prompt = self.formulate_prompt()
+                    if callable(self.vlm_options.prompt):
+                        user_prompt = self.vlm_options.prompt(page.parsed_page)
+                    else:
+                        user_prompt = self.vlm_options.prompt
+                    prompt = self.formulate_prompt(user_prompt)
 
                     inputs = self.processor(
                         text=prompt, images=[hi_res_image], return_tensors="pt"
@@ -162,7 +166,7 @@ class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMix
 
                 yield page
 
-    def formulate_prompt(self) -> str:
+    def formulate_prompt(self, user_prompt: str) -> str:
         """Formulate a prompt for the VLM."""
 
         if self.vlm_options.repo_id == "microsoft/Phi-4-multimodal-instruct":
@@ -173,7 +177,7 @@ class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMix
             assistant_prompt = "<|assistant|>"
             prompt_suffix = "<|end|>"
 
-            prompt = f"{user_prompt}<|image_1|>{self.vlm_options.prompt}{prompt_suffix}{assistant_prompt}"
+            prompt = f"{user_prompt}<|image_1|>{user_prompt}{prompt_suffix}{assistant_prompt}"
            _log.debug(f"prompt for {self.vlm_options.repo_id}: {prompt}")
 
             return prompt
@@ -187,7 +191,7 @@ class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMix
                         "text": "This is a page from a document.",
                     },
                     {"type": "image"},
-                    {"type": "text", "text": self.vlm_options.prompt},
+                    {"type": "text", "text": user_prompt},
                 ],
             }
         ]

@@ -56,8 +56,6 @@ class HuggingFaceMlxModel(BasePageModel, HuggingFaceModelDownloadMixin):
             elif (artifacts_path / repo_cache_folder).exists():
                 artifacts_path = artifacts_path / repo_cache_folder
 
-            self.param_question = vlm_options.prompt
-
             ## Load the model
             self.vlm_model, self.processor = load(artifacts_path)
             self.config = load_config(artifacts_path)
@@ -86,8 +84,12 @@ class HuggingFaceMlxModel(BasePageModel, HuggingFaceModelDownloadMixin):
                     if hi_res_image.mode != "RGB":
                         hi_res_image = hi_res_image.convert("RGB")
 
+                    if callable(self.vlm_options.prompt):
+                        user_prompt = self.vlm_options.prompt(page.parsed_page)
+                    else:
+                        user_prompt = self.vlm_options.prompt
                     prompt = self.apply_chat_template(
-                        self.processor, self.config, self.param_question, num_images=1
+                        self.processor, self.config, user_prompt, num_images=1
                     )
 
                     start_time = time.time()

@@ -117,6 +117,7 @@ class VlmPipeline(PaginatedPipeline):
             page._backend = conv_res.input._backend.load_page(page.page_no)  # type: ignore
             if page._backend is not None and page._backend.is_valid():
                 page.size = page._backend.get_size()
+                page.parsed_page = page._backend.get_segmented_page()
 
         return page
||||