diff --git a/docling/datamodel/pipeline_options_vlm_model.py b/docling/datamodel/pipeline_options_vlm_model.py
index bcea2493..f6869475 100644
--- a/docling/datamodel/pipeline_options_vlm_model.py
+++ b/docling/datamodel/pipeline_options_vlm_model.py
@@ -1,5 +1,5 @@
 from enum import Enum
-from typing import Any, Callable, Dict, List, Literal, Optional, Union
+from typing import Any, Dict, List, Literal, Optional
 
 from docling_core.types.doc.page import SegmentedPage
 from pydantic import AnyUrl, BaseModel
@@ -10,11 +10,17 @@ from docling.datamodel.accelerator_options import AcceleratorDevice
 
 class BaseVlmOptions(BaseModel):
     kind: str
-    prompt: Union[str, Callable[[Optional[SegmentedPage]], str]]
+    prompt: str
     scale: float = 2.0
     max_size: Optional[int] = None
     temperature: float = 0.0
 
+    def build_prompt(self, page: Optional[SegmentedPage]) -> str:
+        return self.prompt
+
+    def decode_response(self, text: str) -> str:
+        return text
+
 
 class ResponseFormat(str, Enum):
     DOCTAGS = "doctags"
diff --git a/docling/models/api_vlm_model.py b/docling/models/api_vlm_model.py
index 164ac285..c48aa0bc 100644
--- a/docling/models/api_vlm_model.py
+++ b/docling/models/api_vlm_model.py
@@ -53,11 +53,7 @@ class ApiVlmModel(BasePageModel):
                 if hi_res_image.mode != "RGB":
                     hi_res_image = hi_res_image.convert("RGB")
 
-                if callable(self.vlm_options.prompt):
-                    prompt = self.vlm_options.prompt(page.parsed_page)
-                else:
-                    prompt = self.vlm_options.prompt
-
+                prompt = self.vlm_options.build_prompt(page.parsed_page)
                 page_tags = api_image_request(
                     image=hi_res_image,
                     prompt=prompt,
@@ -67,6 +63,7 @@ class ApiVlmModel(BasePageModel):
                     **self.params,
                 )
 
+                page_tags = self.vlm_options.decode_response(page_tags)
                 page.predictions.vlm_response = VlmPrediction(text=page_tags)
 
                 return page
diff --git a/docling/models/vlm_models_inline/hf_transformers_model.py b/docling/models/vlm_models_inline/hf_transformers_model.py
index d84925dd..5c1f3ea0 100644
--- a/docling/models/vlm_models_inline/hf_transformers_model.py
+++ b/docling/models/vlm_models_inline/hf_transformers_model.py
@@ -135,10 +135,7 @@ class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMix
                 )
 
                 # Define prompt structure
-                if callable(self.vlm_options.prompt):
-                    user_prompt = self.vlm_options.prompt(page.parsed_page)
-                else:
-                    user_prompt = self.vlm_options.prompt
+                user_prompt = self.vlm_options.build_prompt(page.parsed_page)
                 prompt = self.formulate_prompt(user_prompt)
 
                 inputs = self.processor(
@@ -166,6 +163,7 @@ class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMix
                 _log.debug(
                     f"Generated {num_tokens} tokens in time {generation_time:.2f} seconds."
                 )
+                generated_texts = self.vlm_options.decode_response(generated_texts)
                 page.predictions.vlm_response = VlmPrediction(
                     text=generated_texts,
                     generation_time=generation_time,
diff --git a/docling/models/vlm_models_inline/mlx_model.py b/docling/models/vlm_models_inline/mlx_model.py
index 4b37fb48..dfb27869 100644
--- a/docling/models/vlm_models_inline/mlx_model.py
+++ b/docling/models/vlm_models_inline/mlx_model.py
@@ -84,10 +84,7 @@ class HuggingFaceMlxModel(BasePageModel, HuggingFaceModelDownloadMixin):
                 if hi_res_image.mode != "RGB":
                     hi_res_image = hi_res_image.convert("RGB")
 
-                if callable(self.vlm_options.prompt):
-                    user_prompt = self.vlm_options.prompt(page.parsed_page)
-                else:
-                    user_prompt = self.vlm_options.prompt
+                user_prompt = self.vlm_options.build_prompt(page.parsed_page)
                 prompt = self.apply_chat_template(
                     self.processor, self.config, user_prompt, num_images=1
                 )
@@ -142,6 +139,7 @@ class HuggingFaceMlxModel(BasePageModel, HuggingFaceModelDownloadMixin):
                 _log.debug(
                     f"{generation_time:.2f} seconds for {len(tokens)} tokens ({len(tokens) / generation_time} tokens/sec)."
                 )
+                page_tags = self.vlm_options.decode_response(page_tags)
                 page.predictions.vlm_response = VlmPrediction(
                     text=page_tags,
                     generation_time=generation_time,
diff --git a/docs/examples/vlm_pipeline_api_model.py b/docs/examples/vlm_pipeline_api_model.py
index a809b926..7dc4c8b9 100644
--- a/docs/examples/vlm_pipeline_api_model.py
+++ b/docs/examples/vlm_pipeline_api_model.py
@@ -1,3 +1,4 @@
+import json
 import logging
 import os
 from pathlib import Path
@@ -38,57 +39,63 @@ def lms_vlm_options(model: str, prompt: str, format: ResponseFormat):
 
 
 def lms_olmocr_vlm_options(model: str):
-    def _dynamic_olmocr_prompt(page: Optional[SegmentedPage]):
-        if page is None:
-            return (
-                "Below is the image of one page of a document. Just return the plain text"
-                " representation of this document as if you were reading it naturally.\n"
-                "Do not hallucinate.\n"
-            )
+    class OlmocrVlmOptions(ApiVlmOptions):
+        def build_prompt(self, page: Optional[SegmentedPage]) -> str:
+            if page is None:
+                return self.prompt.replace("#RAW_TEXT#", "")
 
-        anchor = [
-            f"Page dimensions: {int(page.dimension.width)}x{int(page.dimension.height)}"
-        ]
+            anchor = [
+                f"Page dimensions: {int(page.dimension.width)}x{int(page.dimension.height)}"
+            ]
 
-        for text_cell in page.textline_cells:
-            if not text_cell.text.strip():
-                continue
-            bbox = text_cell.rect.to_bounding_box().to_bottom_left_origin(
-                page.dimension.height
-            )
-            anchor.append(f"[{int(bbox.l)}x{int(bbox.b)}] {text_cell.text}")
+            for text_cell in page.textline_cells:
+                if not text_cell.text.strip():
+                    continue
+                bbox = text_cell.rect.to_bounding_box().to_bottom_left_origin(
+                    page.dimension.height
+                )
+                anchor.append(f"[{int(bbox.l)}x{int(bbox.b)}] {text_cell.text}")
 
-        for image_cell in page.bitmap_resources:
-            bbox = image_cell.rect.to_bounding_box().to_bottom_left_origin(
-                page.dimension.height
-            )
-            anchor.append(
-                f"[Image {int(bbox.l)}x{int(bbox.b)} to {int(bbox.r)}x{int(bbox.t)}]"
-            )
+            for image_cell in page.bitmap_resources:
+                bbox = image_cell.rect.to_bounding_box().to_bottom_left_origin(
+                    page.dimension.height
+                )
+                anchor.append(
+                    f"[Image {int(bbox.l)}x{int(bbox.b)} to {int(bbox.r)}x{int(bbox.t)}]"
+                )
 
-        if len(anchor) == 1:
-            anchor.append(
-                f"[Image 0x0 to {int(page.dimension.width)}x{int(page.dimension.height)}]"
-            )
+            if len(anchor) == 1:
+                anchor.append(
+                    f"[Image 0x0 to {int(page.dimension.width)}x{int(page.dimension.height)}]"
+                )
 
-        # Original prompt uses cells sorting. We are skipping it in this demo.
+            # Original prompt uses cells sorting. We are skipping it for simplicity.
 
-        base_text = "\n".join(anchor)
+            raw_text = "\n".join(anchor)
 
-        return (
-            f"Below is the image of one page of a document, as well as some raw textual"
-            f" content that was previously extracted for it. Just return the plain text"
-            f" representation of this document as if you were reading it naturally.\n"
-            f"Do not hallucinate.\n"
-            f"RAW_TEXT_START\n{base_text}\nRAW_TEXT_END"
-        )
+            return self.prompt.replace("#RAW_TEXT#", raw_text)
 
-    options = ApiVlmOptions(
+        def decode_response(self, text: str) -> str:
+            # OlmOcr trained to generate json response with language, rotation and other info
+            try:
+                generated_json = json.loads(text)
+            except json.decoder.JSONDecodeError:
+                return ""
+
+            return generated_json["natural_text"]
+
+    options = OlmocrVlmOptions(
         url="http://localhost:1234/v1/chat/completions",
         params=dict(
             model=model,
         ),
-        prompt=_dynamic_olmocr_prompt,
+        prompt=(
+            "Below is the image of one page of a document, as well as some raw textual"
+            " content that was previously extracted for it. Just return the plain text"
+            " representation of this document as if you were reading it naturally.\n"
+            "Do not hallucinate.\n"
+            "RAW_TEXT_START\n#RAW_TEXT#\nRAW_TEXT_END"
+        ),
         timeout=90,
         scale=1.0,
         max_size=1024,  # from OlmOcr pipeline