diff --git a/docling/datamodel/base_models.py b/docling/datamodel/base_models.py
index 76827a1b..ca879131 100644
--- a/docling/datamodel/base_models.py
+++ b/docling/datamodel/base_models.py
@@ -262,3 +262,34 @@ class Page(BaseModel):
     @property
     def image(self) -> Optional[Image]:
         return self.get_image(scale=self._default_image_scale)
+
+
+## OpenAI API Request / Response Models ##
+
+class OpenAiChatMessage(BaseModel):
+    role: str
+    content: str
+
+
+class OpenAiResponseChoice(BaseModel):
+    index: int
+    message: OpenAiChatMessage
+    finish_reason: str
+
+
+class OpenAiResponseUsage(BaseModel):
+    prompt_tokens: int
+    completion_tokens: int
+    total_tokens: int
+
+
+class OpenAiApiResponse(BaseModel):
+    model_config = ConfigDict(
+        protected_namespaces=(),
+    )
+
+    id: str
+    model: Optional[str] = None  # returned by openai
+    choices: List[OpenAiResponseChoice]
+    created: int
+    usage: OpenAiResponseUsage
diff --git a/docling/models/picture_description_api_model.py b/docling/models/picture_description_api_model.py
index 6ef8a7fc..610a4f5b 100644
--- a/docling/models/picture_description_api_model.py
+++ b/docling/models/picture_description_api_model.py
@@ -1,12 +1,7 @@
-import base64
-import io
-import logging
 from pathlib import Path
-from typing import Iterable, List, Optional, Type, Union
+from typing import Iterable, Optional, Type, Union
 
-import requests
 from PIL import Image
-from pydantic import BaseModel, ConfigDict
 
 from docling.datamodel.pipeline_options import (
     AcceleratorOptions,
@@ -15,37 +10,7 @@ from docling.datamodel.pipeline_options import (
 )
 from docling.exceptions import OperationNotAllowed
 from docling.models.picture_description_base_model import PictureDescriptionBaseModel
-
-_log = logging.getLogger(__name__)
-
-
-class ChatMessage(BaseModel):
-    role: str
-    content: str
-
-
-class ResponseChoice(BaseModel):
-    index: int
-    message: ChatMessage
-    finish_reason: str
-
-
-class ResponseUsage(BaseModel):
-    prompt_tokens: int
-    completion_tokens: int
-    total_tokens: int
-
-
-class ApiResponse(BaseModel):
-    model_config = ConfigDict(
-        protected_namespaces=(),
-    )
-
-    id: str
-    model: Optional[str] = None  # returned by openai
-    choices: List[ResponseChoice]
-    created: int
-    usage: ResponseUsage
+from docling.utils.utils import openai_image_request
 
 
 class PictureDescriptionApiModel(PictureDescriptionBaseModel):
@@ -83,43 +48,11 @@ class PictureDescriptionApiModel(PictureDescriptionBaseModel):
         # Note: technically we could make a batch request here,
         # but not all APIs will allow for it. For example, vllm won't allow more than 1.
         for image in images:
-            img_io = io.BytesIO()
-            image.save(img_io, "PNG")
-            image_base64 = base64.b64encode(img_io.getvalue()).decode("utf-8")
-
-            messages = [
-                {
-                    "role": "user",
-                    "content": [
-                        {
-                            "type": "text",
-                            "text": self.options.prompt,
-                        },
-                        {
-                            "type": "image_url",
-                            "image_url": {
-                                "url": f"data:image/png;base64,{image_base64}"
-                            },
-                        },
-                    ],
-                }
-            ]
-
-            payload = {
-                "messages": messages,
-                **self.options.params,
-            }
-
-            r = requests.post(
-                str(self.options.url),
+            yield openai_image_request(
+                image=image,
+                prompt=self.options.prompt,
+                url=self.options.url,
+                timeout=self.options.timeout,
                 headers=self.options.headers,
-                json=payload,
-                timeout=self.options.timeout,
+                **self.options.params,
             )
-            if not r.ok:
-                _log.error(f"Error calling the API. Reponse was {r.text}")
-            r.raise_for_status()
-
-            api_resp = ApiResponse.model_validate_json(r.text)
-            generated_text = api_resp.choices[0].message.content.strip()
-            yield generated_text
diff --git a/docling/utils/utils.py b/docling/utils/utils.py
index 1261f860..98c3e692 100644
--- a/docling/utils/utils.py
+++ b/docling/utils/utils.py
@@ -1,12 +1,19 @@
+import base64
 import hashlib
+import logging
 from io import BytesIO
 from itertools import islice
 from pathlib import Path
 from typing import List, Union
 
 import requests
+from PIL import Image
 from tqdm import tqdm
 
+from docling.datamodel.base_models import OpenAiApiResponse
+
+_log = logging.getLogger(__name__)
+
 
 def chunkify(iterator, chunk_size):
     """Yield successive chunks of chunk_size from the iterable."""
@@ -63,3 +70,57 @@ def download_url_with_progress(url: str, progress: bool = False) -> BytesIO:
 
     buf.seek(0)
     return buf
+
+
+def openai_image_request(
+    image: Image.Image,
+    prompt: str,
+    url: str = "http://localhost:11434/v1/chat/completions",  # Default to Ollama
+    apikey: str | None = None,
+    timeout: float = 20,
+    headers: dict[str, str] | None = None,
+    **params,
+) -> str:
+    img_io = BytesIO()
+    image.save(img_io, "PNG")
+    image_base64 = base64.b64encode(img_io.getvalue()).decode("utf-8")
+    messages = [
+        {
+            "role": "user",
+            "content": [
+                {
+                    "type": "image_url",
+                    "image_url": {
+                        "url": f"data:image/png;base64,{image_base64}"
+                    },
+                },
+                {
+                    "type": "text",
+                    "text": prompt,
+                },
+            ],
+        }
+    ]
+
+    payload = {
+        "messages": messages,
+        **params,
+    }
+
+    headers = headers or {}
+    if apikey is not None:
+        headers["Authorization"] = f"Bearer {apikey}"
+
+    r = requests.post(
+        str(url),
+        headers=headers,
+        json=payload,
+        timeout=timeout,
+    )
+    if not r.ok:
+        _log.error(f"Error calling the API. Response was {r.text}")
+    r.raise_for_status()
+
+    api_resp = OpenAiApiResponse.model_validate_json(r.text)
+    generated_text = api_resp.choices[0].message.content.strip()
+    return generated_text
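
A minimal usage sketch of the new helper, assuming only the signature introduced above; the image path, endpoint URL, and model name are placeholders, not part of this change. Extra keyword arguments are forwarded into the chat-completions payload via **params.

    from PIL import Image

    from docling.utils.utils import openai_image_request

    img = Image.open("figure.png")  # any PIL image to describe (placeholder path)
    caption = openai_image_request(
        image=img,
        prompt="Describe this picture in three sentences.",
        url="http://localhost:11434/v1/chat/completions",  # the helper's default, an Ollama-style endpoint
        apikey=None,  # set a real key for hosted OpenAI-compatible APIs; sent as a Bearer token
        timeout=30,
        model="granite3.2-vision:2b",  # placeholder model name, forwarded into the request payload
    )
    print(caption)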