From 7b7a3a200404c6cfd48c26a888cee2c5525dedc8 Mon Sep 17 00:00:00 2001
From: Gabe Goodhart
Date: Wed, 9 Apr 2025 09:23:28 -0600
Subject: [PATCH] refactor: Refactor from Ollama SDK to generic OpenAI API

Branch: OllamaVlmModel

Signed-off-by: Gabe Goodhart
---
 docling/datamodel/pipeline_options.py | 14 ++--
 docling/models/ollama_vlm_model.py    | 94 ---------------------------
 docling/models/openai_vlm_model.py    | 55 ++++++++++++++++
 docling/pipeline/vlm_pipeline.py      |  8 +--
 4 files changed, 67 insertions(+), 104 deletions(-)
 delete mode 100644 docling/models/ollama_vlm_model.py
 create mode 100644 docling/models/openai_vlm_model.py

diff --git a/docling/datamodel/pipeline_options.py b/docling/datamodel/pipeline_options.py
index d6a6c5bd..aeb8f7b2 100644
--- a/docling/datamodel/pipeline_options.py
+++ b/docling/datamodel/pipeline_options.py
@@ -266,7 +266,7 @@ class ResponseFormat(str, Enum):
 class InferenceFramework(str, Enum):
     MLX = "mlx"
     TRANSFORMERS = "transformers"
-    OLLAMA = "ollama"
+    OPENAI = "openai"
 
 
 class HuggingFaceVlmOptions(BaseVlmOptions):
@@ -285,13 +285,14 @@ class HuggingFaceVlmOptions(BaseVlmOptions):
         return self.repo_id.replace("/", "--")
 
 
-class OllamaVlmOptions(BaseVlmOptions):
-    kind: Literal["ollama_model_options"] = "ollama_model_options"
+class OpenAiVlmOptions(BaseVlmOptions):
+    kind: Literal["openai_model_options"] = "openai_model_options"
 
     model_id: str
-    base_url: str = "http://localhost:11434"
-    num_ctx: int | None = None
+    base_url: str = "http://localhost:11434/v1"  # Default to ollama
+    apikey: str | None = None
     scale: float = 2.0
+    timeout: float = 60
 
     response_format: ResponseFormat
 
@@ -318,10 +319,11 @@ granite_vision_vlm_conversion_options = HuggingFaceVlmOptions(
     inference_framework=InferenceFramework.TRANSFORMERS,
 )
 
-granite_vision_vlm_ollama_conversion_options = OllamaVlmOptions(
+granite_vision_vlm_ollama_conversion_options = OpenAiVlmOptions(
     model_id="granite3.2-vision:2b",
     prompt="OCR the full page to markdown.",
     scale = 1.0,
+    timeout = 120,
     response_format=ResponseFormat.MARKDOWN,
 )
 
diff --git a/docling/models/ollama_vlm_model.py b/docling/models/ollama_vlm_model.py
deleted file mode 100644
index 48fc3521..00000000
--- a/docling/models/ollama_vlm_model.py
+++ /dev/null
@@ -1,94 +0,0 @@
-import base64
-import io
-import logging
-import time
-from pathlib import Path
-from typing import Iterable, Optional
-
-from PIL import Image
-import ollama
-
-from docling.datamodel.base_models import Page, VlmPrediction
-from docling.datamodel.document import ConversionResult
-from docling.datamodel.pipeline_options import (
-    AcceleratorDevice,
-    AcceleratorOptions,
-    OllamaVlmOptions,
-)
-from docling.datamodel.settings import settings
-from docling.models.base_model import BasePageModel
-from docling.utils.accelerator_utils import decide_device
-from docling.utils.profiling import TimeRecorder
-
-_log = logging.getLogger(__name__)
-
-
-class OllamaVlmModel(BasePageModel):
-
-    def __init__(
-        self,
-        enabled: bool,
-        vlm_options: OllamaVlmOptions,
-    ):
-        self.enabled = enabled
-        self.vlm_options = vlm_options
-        if self.enabled:
-            self.client = ollama.Client(self.vlm_options.base_url)
-            self.model_id = self.vlm_options.model_id
-            self.client.pull(self.model_id)
-            self.options = {}
-            self.prompt_content = f"This is a page from a document.\n{self.vlm_options.prompt}"
-            if self.vlm_options.num_ctx:
-                self.options["num_ctx"] = self.vlm_options.num_ctx
-
-    @staticmethod
-    def _encode_image(image: Image) -> str:
-        img_byte_arr = io.BytesIO()
-        image.save(img_byte_arr, format="png")
-        return base64.b64encode(img_byte_arr.getvalue()).decode("utf-8")
-
-    def __call__(
-        self, conv_res: ConversionResult, page_batch: Iterable[Page]
-    ) -> Iterable[Page]:
-        for page in page_batch:
-            assert page._backend is not None
-            if not page._backend.is_valid():
-                yield page
-            else:
-                with TimeRecorder(conv_res, "vlm"):
-                    assert page.size is not None
-
-                    hi_res_image = page.get_image(scale=self.vlm_options.scale)
-
-                    # populate page_tags with predicted doc tags
-                    page_tags = ""
-
-                    if hi_res_image:
-                        if hi_res_image.mode != "RGB":
-                            hi_res_image = hi_res_image.convert("RGB")
-
-                    res = self.client.chat(
-                        model=self.model_id,
-                        messages=[
-                            {
-                                "role": "user",
-                                "content": self.prompt_content,
-                                "images": [self._encode_image(hi_res_image)],
-                            },
-                        ],
-                        options={
-                            "temperature": 0,
-                        }
-                    )
-                    page_tags = res.message.content
-
-                    # inference_time = time.time() - start_time
-                    # tokens_per_second = num_tokens / generation_time
-                    # print("")
-                    # print(f"Page Inference Time: {inference_time:.2f} seconds")
-                    # print(f"Total tokens on page: {num_tokens:.2f}")
-                    # print(f"Tokens/sec: {tokens_per_second:.2f}")
-                    # print("")
-                    page.predictions.vlm_response = VlmPrediction(text=page_tags)
-
-                    yield page
diff --git a/docling/models/openai_vlm_model.py b/docling/models/openai_vlm_model.py
new file mode 100644
index 00000000..4a1baa9d
--- /dev/null
+++ b/docling/models/openai_vlm_model.py
@@ -0,0 +1,55 @@
+from typing import Iterable
+
+from docling.datamodel.base_models import Page, VlmPrediction
+from docling.datamodel.document import ConversionResult
+from docling.datamodel.pipeline_options import OpenAiVlmOptions
+from docling.models.base_model import BasePageModel
+from docling.utils.profiling import TimeRecorder
+from docling.utils.utils import openai_image_request
+
+
+class OpenAiVlmModel(BasePageModel):
+
+    def __init__(
+        self,
+        enabled: bool,
+        vlm_options: OpenAiVlmOptions,
+    ):
+        self.enabled = enabled
+        self.vlm_options = vlm_options
+        if self.enabled:
+            self.url = "/".join([self.vlm_options.base_url.rstrip("/"), "chat/completions"])
+            self.apikey = self.vlm_options.apikey
+            self.model_id = self.vlm_options.model_id
+            self.timeout = self.vlm_options.timeout
+            self.prompt_content = f"This is a page from a document.\n{self.vlm_options.prompt}"
+
+    def __call__(
+        self, conv_res: ConversionResult, page_batch: Iterable[Page]
+    ) -> Iterable[Page]:
+        for page in page_batch:
+            assert page._backend is not None
+            if not page._backend.is_valid():
+                yield page
+            else:
+                with TimeRecorder(conv_res, "vlm"):
+                    assert page.size is not None
+
+                    hi_res_image = page.get_image(scale=self.vlm_options.scale)
+                    if hi_res_image:
+                        if hi_res_image.mode != "RGB":
+                            hi_res_image = hi_res_image.convert("RGB")
+
+                    page_tags = openai_image_request(
+                        image=hi_res_image,
+                        prompt=self.prompt_content,
+                        url=self.url,
+                        apikey=self.apikey,
+                        timeout=self.timeout,
+                        model=self.model_id,
+                        temperature=0,
+                    )
+
+                    page.predictions.vlm_response = VlmPrediction(text=page_tags)
+
+                    yield page
diff --git a/docling/pipeline/vlm_pipeline.py b/docling/pipeline/vlm_pipeline.py
index 3590d068..46d936dc 100644
--- a/docling/pipeline/vlm_pipeline.py
+++ b/docling/pipeline/vlm_pipeline.py
@@ -17,14 +17,14 @@ from docling.datamodel.document import ConversionResult, InputDocument
 from docling.datamodel.pipeline_options import (
     HuggingFaceVlmOptions,
     InferenceFramework,
-    OllamaVlmOptions,
+    OpenAiVlmOptions,
     ResponseFormat,
     VlmPipelineOptions,
 )
 from docling.datamodel.settings import settings
 from docling.models.hf_mlx_model import HuggingFaceMlxModel
 from docling.models.hf_vlm_model import HuggingFaceVlmModel
-from docling.models.ollama_vlm_model import OllamaVlmModel
+from docling.models.openai_vlm_model import OpenAiVlmModel
 from docling.pipeline.base_pipeline import PaginatedPipeline
 from docling.utils.profiling import ProfilingScope, TimeRecorder
 
@@ -60,9 +60,9 @@ class VlmPipeline(PaginatedPipeline):
 
         self.keep_images = self.pipeline_options.generate_page_images
 
-        if isinstance(pipeline_options.vlm_options, OllamaVlmOptions):
+        if isinstance(pipeline_options.vlm_options, OpenAiVlmOptions):
             self.build_pipe = [
-                OllamaVlmModel(
+                OpenAiVlmModel(
                     enabled=True,  # must be always enabled for this pipeline to make sense.
                     vlm_options=self.pipeline_options.vlm_options,
                 ),
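
The new model delegates the HTTP call to docling.utils.utils.openai_image_request, which is not touched by this patch. As a rough sketch only (the real helper in docling.utils.utils may differ), a function with that signature could build a standard OpenAI-style chat/completions payload carrying the page image as a base64 data URL:

import base64
import io
from typing import Optional

import requests
from PIL import Image


def openai_image_request(
    image: Image.Image,
    prompt: str,
    url: str,
    apikey: Optional[str] = None,
    timeout: float = 60,
    model: str = "",
    temperature: float = 0,
) -> str:
    # Encode the page image as a base64 data URL, the format accepted by
    # OpenAI-compatible /v1/chat/completions endpoints.
    buf = io.BytesIO()
    image.save(buf, format="PNG")
    image_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")

    headers = {}
    if apikey:
        headers["Authorization"] = f"Bearer {apikey}"

    payload = {
        "model": model,
        "temperature": temperature,
        "messages": [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/png;base64,{image_b64}"},
                    },
                    {"type": "text", "text": prompt},
                ],
            }
        ],
    }

    resp = requests.post(url, headers=headers, json=payload, timeout=timeout)
    resp.raise_for_status()
    # Return the assistant message text for the single generated choice.
    return resp.json()["choices"][0]["message"]["content"]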
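
With the defaults above, the granite_vision preset still points at a local Ollama server, now through its OpenAI-compatible /v1 endpoint, and any other OpenAI-style server can be targeted by overriding base_url and apikey. A minimal usage sketch follows, assuming the existing DocumentConverter / PdfFormatOption entry points and export_to_markdown() are unchanged by this patch; "report.pdf" is a placeholder input:

from docling.datamodel.base_models import InputFormat
from docling.datamodel.pipeline_options import (
    VlmPipelineOptions,
    granite_vision_vlm_ollama_conversion_options,
)
from docling.document_converter import DocumentConverter, PdfFormatOption
from docling.pipeline.vlm_pipeline import VlmPipeline

# Reuse the preset that targets a local Ollama server through its
# OpenAI-compatible API (http://localhost:11434/v1 by default).
pipeline_options = VlmPipelineOptions(
    vlm_options=granite_vision_vlm_ollama_conversion_options,
)

converter = DocumentConverter(
    format_options={
        InputFormat.PDF: PdfFormatOption(
            pipeline_cls=VlmPipeline,
            pipeline_options=pipeline_options,
        )
    }
)

result = converter.convert("report.pdf")  # hypothetical input file
print(result.document.export_to_markdown())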