Mirror of https://github.com/DS4SD/docling.git (synced 2025-07-27 04:24:45 +00:00)
refactor: Move OpenAI API call logic into utils.utils
This will allow reuse of this logic in a generic VLM model.

NOTE: There is a subtle change here in the ordering of the text prompt and the image in the call to the OpenAI API. When run against Ollama, this ordering makes a big difference: if the prompt comes before the image, the result is terse and not usable, whereas the prompt coming after the image works as expected and matches the non-OpenAI chat API.

Branch: OllamaVlmModel

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
This commit is contained in:
parent 8ef0b897c8
commit ad1541e8cf
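To illustrate the ordering note above (a minimal sketch, not part of the diff; the prompt text and base64 value are placeholders), this is the shape of the user message the refactored helper sends, with the image part listed before the text part:

    # Illustrative OpenAI-style chat payload; values are placeholders.
    image_base64 = "<base64-encoded PNG>"
    prompt = "Describe this image."

    # Image first, prompt after: the ordering that works against Ollama.
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/png;base64,{image_base64}"},
                },
                {"type": "text", "text": prompt},
            ],
        }
    ]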
docling/datamodel/base_models.py

@@ -262,3 +262,34 @@ class Page(BaseModel):
     @property
     def image(self) -> Optional[Image]:
         return self.get_image(scale=self._default_image_scale)
+
+
+## OpenAI API Request / Response Models ##
+
+
+class OpenAiChatMessage(BaseModel):
+    role: str
+    content: str
+
+
+class OpenAiResponseChoice(BaseModel):
+    index: int
+    message: OpenAiChatMessage
+    finish_reason: str
+
+
+class OpenAiResponseUsage(BaseModel):
+    prompt_tokens: int
+    completion_tokens: int
+    total_tokens: int
+
+
+class OpenAiApiResponse(BaseModel):
+    model_config = ConfigDict(
+        protected_namespaces=(),
+    )
+
+    id: str
+    model: Optional[str] = None  # returned by openai
+    choices: List[OpenAiResponseChoice]
+    created: int
+    usage: OpenAiResponseUsage
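As a quick illustration of the new models (a sketch, not part of the diff; the JSON is a hand-written sample and the model name is made up), an OpenAI-style chat completion body validates directly with Pydantic:

    from docling.datamodel.base_models import OpenAiApiResponse

    sample = """{
        "id": "chatcmpl-123",
        "model": "llava",
        "created": 1700000000,
        "choices": [{
            "index": 0,
            "message": {"role": "assistant", "content": "A bar chart of quarterly revenue."},
            "finish_reason": "stop"
        }],
        "usage": {"prompt_tokens": 10, "completion_tokens": 9, "total_tokens": 19}
    }"""

    resp = OpenAiApiResponse.model_validate_json(sample)
    print(resp.choices[0].message.content)  # A bar chart of quarterly revenue.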
docling/models/picture_description_api_model.py

@@ -1,12 +1,7 @@
-import base64
-import io
-import logging
 from pathlib import Path
-from typing import Iterable, List, Optional, Type, Union
+from typing import Iterable, Optional, Type, Union
 
-import requests
 from PIL import Image
-from pydantic import BaseModel, ConfigDict
 
 from docling.datamodel.pipeline_options import (
     AcceleratorOptions,
@@ -15,37 +10,7 @@ from docling.datamodel.pipeline_options import (
 )
 from docling.exceptions import OperationNotAllowed
 from docling.models.picture_description_base_model import PictureDescriptionBaseModel
-
-_log = logging.getLogger(__name__)
-
-
-class ChatMessage(BaseModel):
-    role: str
-    content: str
-
-
-class ResponseChoice(BaseModel):
-    index: int
-    message: ChatMessage
-    finish_reason: str
-
-
-class ResponseUsage(BaseModel):
-    prompt_tokens: int
-    completion_tokens: int
-    total_tokens: int
-
-
-class ApiResponse(BaseModel):
-    model_config = ConfigDict(
-        protected_namespaces=(),
-    )
-
-    id: str
-    model: Optional[str] = None  # returned by openai
-    choices: List[ResponseChoice]
-    created: int
-    usage: ResponseUsage
+from docling.utils.utils import openai_image_request
 
 
 class PictureDescriptionApiModel(PictureDescriptionBaseModel):
@@ -83,43 +48,11 @@ class PictureDescriptionApiModel(PictureDescriptionBaseModel):
         # Note: technically we could make a batch request here,
         # but not all APIs will allow for it. For example, vllm won't allow more than 1.
         for image in images:
-            img_io = io.BytesIO()
-            image.save(img_io, "PNG")
-            image_base64 = base64.b64encode(img_io.getvalue()).decode("utf-8")
-
-            messages = [
-                {
-                    "role": "user",
-                    "content": [
-                        {
-                            "type": "text",
-                            "text": self.options.prompt,
-                        },
-                        {
-                            "type": "image_url",
-                            "image_url": {
-                                "url": f"data:image/png;base64,{image_base64}"
-                            },
-                        },
-                    ],
-                }
-            ]
-
-            payload = {
-                "messages": messages,
-                **self.options.params,
-            }
-
-            r = requests.post(
-                str(self.options.url),
+            yield openai_image_request(
+                image=image,
+                prompt=self.options.prompt,
+                url=self.options.url,
+                timeout=self.options.timeout,
                 headers=self.options.headers,
-                json=payload,
-                timeout=self.options.timeout,
+                **self.options.params,
             )
-            if not r.ok:
-                _log.error(f"Error calling the API. Reponse was {r.text}")
-                r.raise_for_status()
-
-            api_resp = ApiResponse.model_validate_json(r.text)
-            generated_text = api_resp.choices[0].message.content.strip()
-            yield generated_text
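For context on the call-site change (an illustrative sketch, not from the commit; the URL and model name are example values), the keyword arguments forwarded via **self.options.params land in the JSON payload next to "messages", which is how the model name reaches an OpenAI-compatible server:

    from docling.datamodel.pipeline_options import PictureDescriptionApiOptions

    # Example values only; any OpenAI-compatible endpoint works.
    options = PictureDescriptionApiOptions(
        url="http://localhost:11434/v1/chat/completions",
        prompt="Describe this picture in three sentences.",
        params={"model": "llava"},  # forwarded as **params into the payload
        timeout=60,
    )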
docling/utils/utils.py

@@ -1,12 +1,19 @@
+import base64
 import hashlib
+import logging
 from io import BytesIO
 from itertools import islice
 from pathlib import Path
 from typing import List, Union
 
 import requests
+from PIL import Image
 from tqdm import tqdm
 
+from docling.datamodel.base_models import OpenAiApiResponse
+
+_log = logging.getLogger(__name__)
+
 
 def chunkify(iterator, chunk_size):
     """Yield successive chunks of chunk_size from the iterable."""
@@ -63,3 +70,57 @@ def download_url_with_progress(url: str, progress: bool = False) -> BytesIO:
 
     buf.seek(0)
     return buf
+
+
+def openai_image_request(
+    image: Image.Image,
+    prompt: str,
+    url: str = "http://localhost:11434/v1/chat/completions",  # Default to ollama
+    apikey: str | None = None,
+    timeout: float = 20,
+    headers: dict[str, str] | None = None,
+    **params,
+) -> str:
+    img_io = BytesIO()
+    image.save(img_io, "PNG")
+    image_base64 = base64.b64encode(img_io.getvalue()).decode("utf-8")
+    messages = [
+        {
+            "role": "user",
+            "content": [
+                {
+                    "type": "image_url",
+                    "image_url": {
+                        "url": f"data:image/png;base64,{image_base64}"
+                    },
+                },
+                {
+                    "type": "text",
+                    "text": prompt,
+                },
+            ],
+        }
+    ]
+
+    payload = {
+        "messages": messages,
+        **params,
+    }
+
+    headers = headers or {}
+    if apikey is not None:
+        headers["Authorization"] = f"Bearer {apikey}"
+
+    r = requests.post(
+        str(url),
+        headers=headers,
+        json=payload,
+        timeout=timeout,
+    )
+    if not r.ok:
+        _log.error(f"Error calling the API. Response was {r.text}")
+        r.raise_for_status()
+
+    api_resp = OpenAiApiResponse.model_validate_json(r.text)
+    generated_text = api_resp.choices[0].message.content.strip()
+    return generated_text
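A usage sketch for the new helper (not part of the commit; assumes a local Ollama server is already serving a vision-capable model, and the model name, image path, and prompt are placeholders):

    from PIL import Image

    from docling.utils.utils import openai_image_request

    img = Image.open("page_0.png")  # placeholder image path
    description = openai_image_request(
        image=img,
        prompt="Describe this image in one sentence.",
        url="http://localhost:11434/v1/chat/completions",  # the Ollama default above
        timeout=60,
        model="llava",  # collected into **params and sent in the payload
    )
    print(description)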