refactor: Refactor from Ollama SDK to generic OpenAI API

Branch: OllamaVlmModel

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
Author: Gabe Goodhart
Date: 2025-04-09 09:23:28 -06:00
Parent: ad1541e8cf
Commit: 7b7a3a2004
4 changed files with 67 additions and 104 deletions

docling/datamodel/pipeline_options.py

@@ -266,7 +266,7 @@ class ResponseFormat(str, Enum):
 class InferenceFramework(str, Enum):
     MLX = "mlx"
     TRANSFORMERS = "transformers"
-    OLLAMA = "ollama"
+    OPENAI = "openai"
 class HuggingFaceVlmOptions(BaseVlmOptions):
@@ -285,13 +285,14 @@ class HuggingFaceVlmOptions(BaseVlmOptions):
         return self.repo_id.replace("/", "--")
-class OllamaVlmOptions(BaseVlmOptions):
-    kind: Literal["ollama_model_options"] = "ollama_model_options"
+class OpenAiVlmOptions(BaseVlmOptions):
+    kind: Literal["openai_model_options"] = "openai_model_options"
     model_id: str
-    base_url: str = "http://localhost:11434"
-    num_ctx: int | None = None
+    base_url: str = "http://localhost:11434/v1"  # default to Ollama's OpenAI-compatible endpoint
+    apikey: str | None = None
     scale: float = 2.0
+    timeout: float = 60
     response_format: ResponseFormat
@@ -318,10 +319,11 @@ granite_vision_vlm_conversion_options = HuggingFaceVlmOptions(
     inference_framework=InferenceFramework.TRANSFORMERS,
 )
-granite_vision_vlm_ollama_conversion_options = OllamaVlmOptions(
+granite_vision_vlm_ollama_conversion_options = OpenAiVlmOptions(
     model_id="granite3.2-vision:2b",
     prompt="OCR the full page to markdown.",
+    scale=1.0,
+    timeout=120,
     response_format=ResponseFormat.MARKDOWN,
 )
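
As a rough illustration of what this rename enables, the same OpenAiVlmOptions class can be pointed at any OpenAI-compatible /v1 endpoint, not just Ollama. The server URL and model name below are illustrative assumptions, not part of this commit:

from docling.datamodel.pipeline_options import OpenAiVlmOptions, ResponseFormat

# Hypothetical configuration against a vLLM server exposing an
# OpenAI-compatible /v1/chat/completions endpoint (values are assumptions).
vllm_vlm_options = OpenAiVlmOptions(
    model_id="ibm-granite/granite-vision-3.2-2b",  # model name as served (assumption)
    base_url="http://localhost:8000/v1",           # vLLM's default OpenAI-compatible base URL
    apikey=None,                                   # set if the server enforces auth
    prompt="OCR the full page to markdown.",
    scale=1.0,
    timeout=120,
    response_format=ResponseFormat.MARKDOWN,
)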

docling/models/ollama_vlm_model.py (deleted)

@@ -1,94 +0,0 @@
import base64
import io
import logging
import time
from pathlib import Path
from typing import Iterable, Optional

from PIL import Image

import ollama

from docling.datamodel.base_models import Page, VlmPrediction
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import (
    AcceleratorDevice,
    AcceleratorOptions,
    OllamaVlmOptions,
)
from docling.datamodel.settings import settings
from docling.models.base_model import BasePageModel
from docling.utils.accelerator_utils import decide_device
from docling.utils.profiling import TimeRecorder

_log = logging.getLogger(__name__)


class OllamaVlmModel(BasePageModel):

    def __init__(
        self,
        enabled: bool,
        vlm_options: OllamaVlmOptions,
    ):
        self.enabled = enabled
        self.vlm_options = vlm_options
        if self.enabled:
            self.client = ollama.Client(self.vlm_options.base_url)
            self.model_id = self.vlm_options.model_id
            self.client.pull(self.model_id)
            self.options = {}
            self.prompt_content = f"This is a page from a document.\n{self.vlm_options.prompt}"
            if self.vlm_options.num_ctx:
                self.options["num_ctx"] = self.vlm_options.num_ctx

    @staticmethod
    def _encode_image(image: Image) -> str:
        img_byte_arr = io.BytesIO()
        image.save(img_byte_arr, format="png")
        return base64.b64encode(img_byte_arr.getvalue()).decode("utf-8")

    def __call__(
        self, conv_res: ConversionResult, page_batch: Iterable[Page]
    ) -> Iterable[Page]:
        for page in page_batch:
            assert page._backend is not None
            if not page._backend.is_valid():
                yield page
            else:
                with TimeRecorder(conv_res, "vlm"):
                    assert page.size is not None

                    hi_res_image = page.get_image(scale=self.vlm_options.scale)

                    # populate page_tags with predicted doc tags
                    page_tags = ""

                    if hi_res_image:
                        if hi_res_image.mode != "RGB":
                            hi_res_image = hi_res_image.convert("RGB")

                    res = self.client.chat(
                        model=self.model_id,
                        messages=[
                            {
                                "role": "user",
                                "content": self.prompt_content,
                                "images": [self._encode_image(hi_res_image)],
                            },
                        ],
                        options={
                            "temperature": 0,
                        }
                    )
                    page_tags = res.message.content

                    # inference_time = time.time() - start_time
                    # tokens_per_second = num_tokens / generation_time
                    # print("")
                    # print(f"Page Inference Time: {inference_time:.2f} seconds")
                    # print(f"Total tokens on page: {num_tokens:.2f}")
                    # print(f"Tokens/sec: {tokens_per_second:.2f}")
                    # print("")

                    page.predictions.vlm_response = VlmPrediction(text=page_tags)

                yield page

docling/models/openai_vlm_model.py (new)

@@ -0,0 +1,55 @@
from typing import Iterable

from docling.datamodel.base_models import Page, VlmPrediction
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import OpenAiVlmOptions
from docling.models.base_model import BasePageModel
from docling.utils.profiling import TimeRecorder
from docling.utils.utils import openai_image_request


class OpenAiVlmModel(BasePageModel):

    def __init__(
        self,
        enabled: bool,
        vlm_options: OpenAiVlmOptions,
    ):
        self.enabled = enabled
        self.vlm_options = vlm_options
        if self.enabled:
            self.url = "/".join([self.vlm_options.base_url.rstrip("/"), "chat/completions"])
            self.apikey = self.vlm_options.apikey
            self.model_id = self.vlm_options.model_id
            self.timeout = self.vlm_options.timeout
            self.prompt_content = f"This is a page from a document.\n{self.vlm_options.prompt}"

    def __call__(
        self, conv_res: ConversionResult, page_batch: Iterable[Page]
    ) -> Iterable[Page]:
        for page in page_batch:
            assert page._backend is not None
            if not page._backend.is_valid():
                yield page
            else:
                with TimeRecorder(conv_res, "vlm"):
                    assert page.size is not None

                    hi_res_image = page.get_image(scale=self.vlm_options.scale)

                    if hi_res_image:
                        if hi_res_image.mode != "RGB":
                            hi_res_image = hi_res_image.convert("RGB")

                    page_tags = openai_image_request(
                        image=hi_res_image,
                        prompt=self.prompt_content,
                        url=self.url,
                        apikey=self.apikey,
                        timeout=self.timeout,
                        model=self.model_id,
                        temperature=0,
                    )

                    page.predictions.vlm_response = VlmPrediction(text=page_tags)

                yield page
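
The openai_image_request helper is imported from docling.utils.utils, but its body is not part of this diff. A minimal sketch of what such a helper could look like, inferred only from the call site above (the payload shape, use of requests, and exact signature are assumptions):

import base64
import io

import requests
from PIL import Image


def openai_image_request(
    image: Image.Image,
    prompt: str,
    url: str,
    apikey: str | None = None,
    timeout: float = 60,
    model: str = "",
    temperature: float = 0,
) -> str:
    # Encode the page image as a base64 data URL, as expected by
    # OpenAI-style multimodal chat completion requests.
    buf = io.BytesIO()
    image.save(buf, format="PNG")
    image_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")

    payload = {
        "model": model,
        "temperature": temperature,
        "messages": [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/png;base64,{image_b64}"},
                    },
                    {"type": "text", "text": prompt},
                ],
            }
        ],
    }

    headers = {}
    if apikey:
        headers["Authorization"] = f"Bearer {apikey}"

    # POST to <base_url>/chat/completions and return the first choice's text.
    resp = requests.post(url, headers=headers, json=payload, timeout=timeout)
    resp.raise_for_status()
    return resp.json()["choices"][0]["message"]["content"]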

docling/pipeline/vlm_pipeline.py

@@ -17,14 +17,14 @@ from docling.datamodel.document import ConversionResult, InputDocument
 from docling.datamodel.pipeline_options import (
     HuggingFaceVlmOptions,
     InferenceFramework,
-    OllamaVlmOptions,
+    OpenAiVlmOptions,
     ResponseFormat,
     VlmPipelineOptions,
 )
 from docling.datamodel.settings import settings
 from docling.models.hf_mlx_model import HuggingFaceMlxModel
 from docling.models.hf_vlm_model import HuggingFaceVlmModel
-from docling.models.ollama_vlm_model import OllamaVlmModel
+from docling.models.openai_vlm_model import OpenAiVlmModel
 from docling.pipeline.base_pipeline import PaginatedPipeline
 from docling.utils.profiling import ProfilingScope, TimeRecorder
@@ -60,9 +60,9 @@ class VlmPipeline(PaginatedPipeline):
         self.keep_images = self.pipeline_options.generate_page_images
-        if isinstance(pipeline_options.vlm_options, OllamaVlmOptions):
+        if isinstance(pipeline_options.vlm_options, OpenAiVlmOptions):
             self.build_pipe = [
-                OllamaVlmModel(
+                OpenAiVlmModel(
                     enabled=True,  # must be always enabled for this pipeline to make sense.
                     vlm_options=self.pipeline_options.vlm_options,
                 ),
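
For context, a rough end-to-end usage sketch in the spirit of docling's existing VLM pipeline examples; the converter wiring below is an assumption based on the public docling API rather than code from this commit:

from docling.datamodel.base_models import InputFormat
from docling.datamodel.pipeline_options import (
    VlmPipelineOptions,
    granite_vision_vlm_ollama_conversion_options,
)
from docling.document_converter import DocumentConverter, PdfFormatOption
from docling.pipeline.vlm_pipeline import VlmPipeline

# Route PDFs through the VLM pipeline, using the Ollama-backed
# (OpenAI-compatible) options defined in pipeline_options.py above.
pipeline_options = VlmPipelineOptions(
    vlm_options=granite_vision_vlm_ollama_conversion_options,
)

converter = DocumentConverter(
    format_options={
        InputFormat.PDF: PdfFormatOption(
            pipeline_cls=VlmPipeline,
            pipeline_options=pipeline_options,
        )
    }
)

result = converter.convert("example.pdf")  # hypothetical input file
print(result.document.export_to_markdown())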