Mirror of https://github.com/DS4SD/docling.git (synced 2025-07-27 04:24:45 +00:00)
refactor: Refactor from Ollama SDK to generic OpenAI API

Branch: OllamaVlmModel
Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

Commit: 7b7a3a2004 (parent: ad1541e8cf)
docling/datamodel/pipeline_options.py

@@ -266,7 +266,7 @@ class ResponseFormat(str, Enum):
 class InferenceFramework(str, Enum):
     MLX = "mlx"
     TRANSFORMERS = "transformers"
-    OLLAMA = "ollama"
+    OPENAI = "openai"


 class HuggingFaceVlmOptions(BaseVlmOptions):

@@ -285,13 +285,14 @@ class HuggingFaceVlmOptions(BaseVlmOptions):
         return self.repo_id.replace("/", "--")


-class OllamaVlmOptions(BaseVlmOptions):
-    kind: Literal["ollama_model_options"] = "ollama_model_options"
+class OpenAiVlmOptions(BaseVlmOptions):
+    kind: Literal["openai_model_options"] = "openai_model_options"

     model_id: str
-    base_url: str = "http://localhost:11434"
-    num_ctx: int | None = None
+    base_url: str = "http://localhost:11434/v1"  # Default to ollama
+    apikey: str | None = None
     scale: float = 2.0
+    timeout: float = 60
     response_format: ResponseFormat

@@ -318,10 +319,11 @@ granite_vision_vlm_conversion_options = HuggingFaceVlmOptions(
     inference_framework=InferenceFramework.TRANSFORMERS,
 )

-granite_vision_vlm_ollama_conversion_options = OllamaVlmOptions(
+granite_vision_vlm_ollama_conversion_options = OpenAiVlmOptions(
     model_id="granite3.2-vision:2b",
     prompt="OCR the full page to markdown.",
     scale=1.0,
+    timeout=120,
     response_format=ResponseFormat.MARKDOWN,
 )
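With the enum value and options class generalized, the same preset mechanism can target any OpenAI-compatible server, not only Ollama. A minimal configuration sketch follows; the endpoint URL and model id are illustrative assumptions and not part of this commit, while the field names all come from the class defined above:

    from docling.datamodel.pipeline_options import OpenAiVlmOptions, ResponseFormat

    # Point the generic options at a hypothetical local vLLM server
    # instead of the Ollama default (URL and model id are assumptions).
    vlm_options = OpenAiVlmOptions(
        model_id="ibm-granite/granite-vision-3.2-2b",  # assumed model id
        base_url="http://localhost:8000/v1",           # assumed endpoint
        apikey=None,                                   # set if the server requires auth
        prompt="OCR the full page to markdown.",
        scale=1.0,
        timeout=120,
        response_format=ResponseFormat.MARKDOWN,
    )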
docling/models/ollama_vlm_model.py (deleted file)

@@ -1,94 +0,0 @@
-import base64
-import io
-import logging
-import time
-from pathlib import Path
-from typing import Iterable, Optional
-
-from PIL import Image
-import ollama
-
-from docling.datamodel.base_models import Page, VlmPrediction
-from docling.datamodel.document import ConversionResult
-from docling.datamodel.pipeline_options import (
-    AcceleratorDevice,
-    AcceleratorOptions,
-    OllamaVlmOptions,
-)
-from docling.datamodel.settings import settings
-from docling.models.base_model import BasePageModel
-from docling.utils.accelerator_utils import decide_device
-from docling.utils.profiling import TimeRecorder
-
-_log = logging.getLogger(__name__)
-
-
-class OllamaVlmModel(BasePageModel):
-
-    def __init__(
-        self,
-        enabled: bool,
-        vlm_options: OllamaVlmOptions,
-    ):
-        self.enabled = enabled
-        self.vlm_options = vlm_options
-        if self.enabled:
-            self.client = ollama.Client(self.vlm_options.base_url)
-            self.model_id = self.vlm_options.model_id
-            self.client.pull(self.model_id)
-            self.options = {}
-            self.prompt_content = f"This is a page from a document.\n{self.vlm_options.prompt}"
-            if self.vlm_options.num_ctx:
-                self.options["num_ctx"] = self.vlm_options.num_ctx
-
-    @staticmethod
-    def _encode_image(image: Image) -> str:
-        img_byte_arr = io.BytesIO()
-        image.save(img_byte_arr, format="png")
-        return base64.b64encode(img_byte_arr.getvalue()).decode("utf-8")
-
-    def __call__(
-        self, conv_res: ConversionResult, page_batch: Iterable[Page]
-    ) -> Iterable[Page]:
-        for page in page_batch:
-            assert page._backend is not None
-            if not page._backend.is_valid():
-                yield page
-            else:
-                with TimeRecorder(conv_res, "vlm"):
-                    assert page.size is not None
-
-                    hi_res_image = page.get_image(scale=self.vlm_options.scale)
-
-                    # populate page_tags with predicted doc tags
-                    page_tags = ""
-
-                    if hi_res_image:
-                        if hi_res_image.mode != "RGB":
-                            hi_res_image = hi_res_image.convert("RGB")
-
-                    res = self.client.chat(
-                        model=self.model_id,
-                        messages=[
-                            {
-                                "role": "user",
-                                "content": self.prompt_content,
-                                "images": [self._encode_image(hi_res_image)],
-                            },
-                        ],
-                        options={
-                            "temperature": 0,
-                        }
-                    )
-                    page_tags = res.message.content
-
-                    # inference_time = time.time() - start_time
-                    # tokens_per_second = num_tokens / generation_time
-                    # print("")
-                    # print(f"Page Inference Time: {inference_time:.2f} seconds")
-                    # print(f"Total tokens on page: {num_tokens:.2f}")
-                    # print(f"Tokens/sec: {tokens_per_second:.2f}")
-                    # print("")
-                    page.predictions.vlm_response = VlmPrediction(text=page_tags)
-
-                yield page
docling/models/openai_vlm_model.py (new file, 55 lines)

@@ -0,0 +1,55 @@
+from typing import Iterable
+
+from docling.datamodel.base_models import Page, VlmPrediction
+from docling.datamodel.document import ConversionResult
+from docling.datamodel.pipeline_options import OpenAiVlmOptions
+from docling.models.base_model import BasePageModel
+from docling.utils.profiling import TimeRecorder
+from docling.utils.utils import openai_image_request
+
+
+class OpenAiVlmModel(BasePageModel):
+
+    def __init__(
+        self,
+        enabled: bool,
+        vlm_options: OpenAiVlmOptions,
+    ):
+        self.enabled = enabled
+        self.vlm_options = vlm_options
+        if self.enabled:
+            self.url = "/".join([self.vlm_options.base_url.rstrip("/"), "chat/completions"])
+            self.apikey = self.vlm_options.apikey
+            self.model_id = self.vlm_options.model_id
+            self.timeout = self.vlm_options.timeout
+            self.prompt_content = f"This is a page from a document.\n{self.vlm_options.prompt}"
+
+    def __call__(
+        self, conv_res: ConversionResult, page_batch: Iterable[Page]
+    ) -> Iterable[Page]:
+        for page in page_batch:
+            assert page._backend is not None
+            if not page._backend.is_valid():
+                yield page
+            else:
+                with TimeRecorder(conv_res, "vlm"):
+                    assert page.size is not None
+
+                    hi_res_image = page.get_image(scale=self.vlm_options.scale)
+                    if hi_res_image:
+                        if hi_res_image.mode != "RGB":
+                            hi_res_image = hi_res_image.convert("RGB")
+
+                    page_tags = openai_image_request(
+                        image=hi_res_image,
+                        prompt=self.prompt_content,
+                        url=self.url,
+                        apikey=self.apikey,
+                        timeout=self.timeout,
+                        model=self.model_id,
+                        temperature=0,
+                    )
+
+                    page.predictions.vlm_response = VlmPrediction(text=page_tags)
+
+                yield page
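The new model delegates the HTTP call to openai_image_request in docling.utils.utils, which this commit imports but whose implementation is not part of this diff. A plausible sketch of such a helper is shown below; the signature mirrors the call site above, but the body is an assumption based on the standard OpenAI chat/completions schema with an inline base64 data URL for the image:

    # Sketch only: the real helper lives in docling.utils.utils and is not
    # shown in this diff. This version assumes the standard OpenAI
    # chat/completions request and response shapes.
    import base64
    import io
    from typing import Optional

    import requests
    from PIL import Image


    def openai_image_request(
        image: Image.Image,
        prompt: str,
        url: str,
        apikey: Optional[str] = None,
        timeout: float = 60,
        model: str = "",
        temperature: float = 0,
    ) -> str:
        # Encode the page image as a PNG data URL.
        img_io = io.BytesIO()
        image.save(img_io, "PNG")
        image_base64 = base64.b64encode(img_io.getvalue()).decode("utf-8")

        # Only send an Authorization header when an API key is configured.
        headers = {}
        if apikey is not None:
            headers["Authorization"] = f"Bearer {apikey}"

        payload = {
            "model": model,
            "temperature": temperature,
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "image_url",
                            "image_url": {"url": f"data:image/png;base64,{image_base64}"},
                        },
                        {"type": "text", "text": prompt},
                    ],
                }
            ],
        }

        r = requests.post(url, headers=headers, json=payload, timeout=timeout)
        r.raise_for_status()
        return r.json()["choices"][0]["message"]["content"]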
docling/pipeline/vlm_pipeline.py

@@ -17,14 +17,14 @@ from docling.datamodel.document import ConversionResult, InputDocument
 from docling.datamodel.pipeline_options import (
     HuggingFaceVlmOptions,
     InferenceFramework,
-    OllamaVlmOptions,
+    OpenAiVlmOptions,
     ResponseFormat,
     VlmPipelineOptions,
 )
 from docling.datamodel.settings import settings
 from docling.models.hf_mlx_model import HuggingFaceMlxModel
 from docling.models.hf_vlm_model import HuggingFaceVlmModel
-from docling.models.ollama_vlm_model import OllamaVlmModel
+from docling.models.openai_vlm_model import OpenAiVlmModel
 from docling.pipeline.base_pipeline import PaginatedPipeline
 from docling.utils.profiling import ProfilingScope, TimeRecorder

@@ -60,9 +60,9 @@ class VlmPipeline(PaginatedPipeline):

         self.keep_images = self.pipeline_options.generate_page_images

-        if isinstance(pipeline_options.vlm_options, OllamaVlmOptions):
+        if isinstance(pipeline_options.vlm_options, OpenAiVlmOptions):
             self.build_pipe = [
-                OllamaVlmModel(
+                OpenAiVlmModel(
                     enabled=True,  # must be always enabled for this pipeline to make sense.
                     vlm_options=self.pipeline_options.vlm_options,
                 ),
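The diff does not show how a caller reaches this dispatch. A usage sketch, assuming docling's standard DocumentConverter / PdfFormatOption wiring is unchanged by this commit (the input filename is a placeholder):

    from docling.datamodel.base_models import InputFormat
    from docling.datamodel.pipeline_options import (
        VlmPipelineOptions,
        granite_vision_vlm_ollama_conversion_options,
    )
    from docling.document_converter import DocumentConverter, PdfFormatOption
    from docling.pipeline.vlm_pipeline import VlmPipeline

    # Selecting the OpenAiVlmOptions preset makes the isinstance check above
    # route page inference through OpenAiVlmModel.
    pipeline_options = VlmPipelineOptions(
        vlm_options=granite_vision_vlm_ollama_conversion_options,
    )
    converter = DocumentConverter(
        format_options={
            InputFormat.PDF: PdfFormatOption(
                pipeline_cls=VlmPipeline,
                pipeline_options=pipeline_options,
            )
        }
    )
    result = converter.convert("page.pdf")  # placeholder input
    print(result.document.export_to_markdown())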