diff --git a/docling/datamodel/pipeline_options_vlm_model.py b/docling/datamodel/pipeline_options_vlm_model.py
index 96c61e86..5e60bcc2 100644
--- a/docling/datamodel/pipeline_options_vlm_model.py
+++ b/docling/datamodel/pipeline_options_vlm_model.py
@@ -1,15 +1,16 @@
 from enum import Enum
-from typing import Any, Dict, List, Literal, Optional, Union
+from typing import Any, Callable, Dict, List, Literal, Optional, Union

 from pydantic import AnyUrl, BaseModel
 from typing_extensions import deprecated

 from docling.datamodel.accelerator_options import AcceleratorDevice
+from docling.datamodel.base_models import Page


 class BaseVlmOptions(BaseModel):
     kind: str
-    prompt: str
+    prompt: Union[str, Callable[[Page], str]]
     scale: float = 2.0
     max_size: Optional[int] = None
     temperature: float = 0.0
diff --git a/docling/models/api_vlm_model.py b/docling/models/api_vlm_model.py
index 63d64a25..646c2cef 100644
--- a/docling/models/api_vlm_model.py
+++ b/docling/models/api_vlm_model.py
@@ -29,9 +29,6 @@ class ApiVlmModel(BasePageModel):

             self.timeout = self.vlm_options.timeout
             self.concurrency = self.vlm_options.concurrency
-            self.prompt_content = (
-                f"This is a page from a document.\n{self.vlm_options.prompt}"
-            )
             self.params = {
                 **self.vlm_options.params,
                 "temperature": self.vlm_options.temperature,
@@ -56,9 +53,14 @@
                     if hi_res_image.mode != "RGB":
                         hi_res_image = hi_res_image.convert("RGB")

+                    if callable(self.vlm_options.prompt):
+                        prompt = self.vlm_options.prompt(page)
+                    else:
+                        prompt = self.vlm_options.prompt
+
                     page_tags = api_image_request(
                         image=hi_res_image,
-                        prompt=self.prompt_content,
+                        prompt=prompt,
                         url=self.vlm_options.url,
                         timeout=self.timeout,
                         headers=self.vlm_options.headers,
diff --git a/docling/models/vlm_models_inline/hf_transformers_model.py b/docling/models/vlm_models_inline/hf_transformers_model.py
index bd35888d..ac58ba87 100644
--- a/docling/models/vlm_models_inline/hf_transformers_model.py
+++ b/docling/models/vlm_models_inline/hf_transformers_model.py
@@ -128,7 +128,11 @@ class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMix
                     )

                     # Define prompt structure
-                    prompt = self.formulate_prompt()
+                    if callable(self.vlm_options.prompt):
+                        user_prompt = self.vlm_options.prompt(page)
+                    else:
+                        user_prompt = self.vlm_options.prompt
+                    prompt = self.formulate_prompt(user_prompt)

                     inputs = self.processor(
                         text=prompt, images=[hi_res_image], return_tensors="pt"
@@ -162,7 +166,7 @@ class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMix

                 yield page

-    def formulate_prompt(self) -> str:
+    def formulate_prompt(self, user_prompt: str) -> str:
         """Formulate a prompt for the VLM."""

         if self.vlm_options.repo_id == "microsoft/Phi-4-multimodal-instruct":
@@ -173,7 +177,7 @@ class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMix
             assistant_prompt = "<|assistant|>"
             prompt_suffix = "<|end|>"

-            prompt = f"{user_prompt}<|image_1|>{self.vlm_options.prompt}{prompt_suffix}{assistant_prompt}"
+            prompt = f"{user_prompt}<|image_1|>{user_prompt}{prompt_suffix}{assistant_prompt}"
             _log.debug(f"prompt for {self.vlm_options.repo_id}: {prompt}")

             return prompt
@@ -187,7 +191,7 @@
                         "text": "This is a page from a document.",
                     },
                     {"type": "image"},
-                    {"type": "text", "text": self.vlm_options.prompt},
+                    {"type": "text", "text": user_prompt},
                ],
            }
        ]
diff --git a/docling/models/vlm_models_inline/mlx_model.py b/docling/models/vlm_models_inline/mlx_model.py
index 58f037fc..cf403069 100644
--- a/docling/models/vlm_models_inline/mlx_model.py
+++ b/docling/models/vlm_models_inline/mlx_model.py
@@ -56,8 +56,6 @@ class HuggingFaceMlxModel(BasePageModel, HuggingFaceModelDownloadMixin):
             elif (artifacts_path / repo_cache_folder).exists():
                 artifacts_path = artifacts_path / repo_cache_folder

-            self.param_question = vlm_options.prompt
-
             ## Load the model
             self.vlm_model, self.processor = load(artifacts_path)
             self.config = load_config(artifacts_path)
@@ -86,8 +84,12 @@
                     if hi_res_image.mode != "RGB":
                         hi_res_image = hi_res_image.convert("RGB")

+                if callable(self.vlm_options.prompt):
+                    user_prompt = self.vlm_options.prompt(page)
+                else:
+                    user_prompt = self.vlm_options.prompt
                 prompt = self.apply_chat_template(
-                    self.processor, self.config, self.param_question, num_images=1
+                    self.processor, self.config, user_prompt, num_images=1
                 )

                 start_time = time.time()
diff --git a/docs/examples/vlm_pipeline_api_model.py b/docs/examples/vlm_pipeline_api_model.py
index 20ca2591..ff377808 100644
--- a/docs/examples/vlm_pipeline_api_model.py
+++ b/docs/examples/vlm_pipeline_api_model.py
@@ -5,7 +5,7 @@ from pathlib import Path
 import requests
 from dotenv import load_dotenv

-from docling.datamodel.base_models import InputFormat
+from docling.datamodel.base_models import InputFormat, Page
 from docling.datamodel.pipeline_options import (
     VlmPipelineOptions,
 )
@@ -49,6 +49,54 @@ def ollama_vlm_options(model: str, prompt: str):
     return options


+#### Using Ollama with OlmOcr
+
+
+def ollama_olmocr_vlm_options(model: str):
+    def _dynamic_olmocr_prompt(page: Page):
+        anchor = [f"Page dimensions: {int(page.size.width)}x{int(page.size.height)}"]
+
+        for cell in page._backend.get_text_cells():
+            if not cell.text.strip():
+                continue
+            bbox = cell.to_bounding_box().to_bottom_left_origin(page.size.height)
+            anchor.append(f"[{int(bbox.l)}x{int(bbox.b)}] {cell.text}")
+
+        for rect in page._backend.get_bitmap_rects():
+            bbox = rect.to_bottom_left_origin(page.size.height)
+            anchor.append(
+                f"[Image {int(bbox.l)}x{int(bbox.b)} to {int(bbox.r)}x{int(bbox.t)}]"
+            )
+
+        if len(anchor) == 1:
+            anchor.append(
+                f"[Image 0x0 to {int(page.size.width)}x{int(page.size.height)}]"
+            )
+
+        base_text = "\n".join(anchor)
+
+        return (
+            f"Below is the image of one page of a document, as well as some raw textual"
+            f" content that was previously extracted for it. Just return the plain text"
+            f" representation of this document as if you were reading it naturally.\n"
+            f"Do not hallucinate.\n"
+            f"RAW_TEXT_START\n{base_text}\nRAW_TEXT_END"
+        )
+
+    options = ApiVlmOptions(
+        url="http://localhost:11434/v1/chat/completions",  # the default Ollama endpoint
+        params=dict(
+            model=model,
+        ),
+        prompt=_dynamic_olmocr_prompt,
+        timeout=90,
+        scale=1.0,
+        max_size=1024,  # from OlmOcr pipeline
+        response_format=ResponseFormat.MARKDOWN,
+    )
+    return options
+
+
 #### Using a cloud service like IBM watsonx.ai


@@ -130,6 +178,12 @@ def main():
     #     prompt="OCR the full page to markdown.",
     # )

+    # Example using the OlmOcr (dynamic prompt) model with Ollama:
+    # (uncomment the following lines)
+    # pipeline_options.vlm_options = ollama_olmocr_vlm_options(
+    #     model="hf.co/mradermacher/olmOCR-7B-0225-preview-GGUF:Q8_0",
+    # )
+
     # Another possibility is using online services, e.g. watsonx.ai.
     # Using requires setting the env variables WX_API_KEY and WX_PROJECT_ID.
     # (uncomment the following lines)