Mirror of https://github.com/DS4SD/docling.git (synced 2025-07-27 04:24:45 +00:00)
refactor: Move OpenAI API call logic into utils.utils

This will allow reuse of this logic in a generic VLM model.

NOTE: There is a subtle change here in the ordering of the text prompt and the image in the call to the OpenAI API. When run against Ollama, this ordering makes a big difference: if the prompt comes before the image, the result is terse and not usable, whereas the prompt coming after the image works as expected and matches the non-OpenAI chat API.

Branch: OllamaVlmModel

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

This commit is contained in:
parent 8ef0b897c8
commit ad1541e8cf
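To illustrate the ordering note above, here is a minimal sketch of the two message layouts against an OpenAI-compatible chat endpoint; the placeholder values are illustrative, and the working layout mirrors what the new openai_image_request helper builds.

prompt = "Describe this image."
image_base64 = "..."  # base64-encoded PNG bytes (placeholder)

# Ordering that works against Ollama: image part first, text prompt second.
content_ok = [
    {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_base64}"}},
    {"type": "text", "text": prompt},
]

# Ordering that yields terse, unusable output from Ollama: prompt before image.
content_terse = [
    {"type": "text", "text": prompt},
    {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_base64}"}},
]

messages = [{"role": "user", "content": content_ok}]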
docling/datamodel/base_models.py
@@ -262,3 +262,34 @@ class Page(BaseModel):
     @property
     def image(self) -> Optional[Image]:
         return self.get_image(scale=self._default_image_scale)
+
+
+## OpenAI API Request / Response Models ##
+
+
+class OpenAiChatMessage(BaseModel):
+    role: str
+    content: str
+
+
+class OpenAiResponseChoice(BaseModel):
+    index: int
+    message: OpenAiChatMessage
+    finish_reason: str
+
+
+class OpenAiResponseUsage(BaseModel):
+    prompt_tokens: int
+    completion_tokens: int
+    total_tokens: int
+
+
+class OpenAiApiResponse(BaseModel):
+    model_config = ConfigDict(
+        protected_namespaces=(),
+    )
+
+    id: str
+    model: Optional[str] = None  # returned by openai
+    choices: List[OpenAiResponseChoice]
+    created: int
+    usage: OpenAiResponseUsage
docling/models/picture_description_api_model.py
@@ -1,12 +1,7 @@
-import base64
-import io
-import logging
 from pathlib import Path
-from typing import Iterable, List, Optional, Type, Union
+from typing import Iterable, Optional, Type, Union
 
-import requests
 from PIL import Image
-from pydantic import BaseModel, ConfigDict
 
 from docling.datamodel.pipeline_options import (
     AcceleratorOptions,
@@ -15,37 +10,7 @@ from docling.datamodel.pipeline_options import (
 )
 from docling.exceptions import OperationNotAllowed
 from docling.models.picture_description_base_model import PictureDescriptionBaseModel
-
-_log = logging.getLogger(__name__)
-
-
-class ChatMessage(BaseModel):
-    role: str
-    content: str
-
-
-class ResponseChoice(BaseModel):
-    index: int
-    message: ChatMessage
-    finish_reason: str
-
-
-class ResponseUsage(BaseModel):
-    prompt_tokens: int
-    completion_tokens: int
-    total_tokens: int
-
-
-class ApiResponse(BaseModel):
-    model_config = ConfigDict(
-        protected_namespaces=(),
-    )
-
-    id: str
-    model: Optional[str] = None  # returned by openai
-    choices: List[ResponseChoice]
-    created: int
-    usage: ResponseUsage
+from docling.utils.utils import openai_image_request
 
 
 class PictureDescriptionApiModel(PictureDescriptionBaseModel):
@@ -83,43 +48,11 @@ class PictureDescriptionApiModel(PictureDescriptionBaseModel):
         # Note: technically we could make a batch request here,
         # but not all APIs will allow for it. For example, vllm won't allow more than 1.
         for image in images:
-            img_io = io.BytesIO()
-            image.save(img_io, "PNG")
-            image_base64 = base64.b64encode(img_io.getvalue()).decode("utf-8")
-
-            messages = [
-                {
-                    "role": "user",
-                    "content": [
-                        {
-                            "type": "text",
-                            "text": self.options.prompt,
-                        },
-                        {
-                            "type": "image_url",
-                            "image_url": {
-                                "url": f"data:image/png;base64,{image_base64}"
-                            },
-                        },
-                    ],
-                }
-            ]
-
-            payload = {
-                "messages": messages,
-                **self.options.params,
-            }
-
-            r = requests.post(
-                str(self.options.url),
-                headers=self.options.headers,
-                json=payload,
-                timeout=self.options.timeout,
-            )
-            if not r.ok:
-                _log.error(f"Error calling the API. Reponse was {r.text}")
-            r.raise_for_status()
-
-            api_resp = ApiResponse.model_validate_json(r.text)
-            generated_text = api_resp.choices[0].message.content.strip()
-            yield generated_text
+            yield openai_image_request(
+                image=image,
+                prompt=self.options.prompt,
+                url=self.options.url,
+                timeout=self.options.timeout,
+                headers=self.options.headers,
+                **self.options.params,
+            )
docling/utils/utils.py
@@ -1,12 +1,19 @@
+import base64
 import hashlib
+import logging
 from io import BytesIO
 from itertools import islice
 from pathlib import Path
 from typing import List, Union
 
 import requests
+from PIL import Image
 from tqdm import tqdm
+
+from docling.datamodel.base_models import OpenAiApiResponse
+
+_log = logging.getLogger(__name__)
 
 
 def chunkify(iterator, chunk_size):
     """Yield successive chunks of chunk_size from the iterable."""
@@ -63,3 +70,57 @@ def download_url_with_progress(url: str, progress: bool = False) -> BytesIO:
 
     buf.seek(0)
     return buf
+
+
+def openai_image_request(
+    image: Image.Image,
+    prompt: str,
+    url: str = "http://localhost:11434/v1/chat/completions",  # Default to ollama
+    apikey: str | None = None,
+    timeout: float = 20,
+    headers: dict[str, str] | None = None,
+    **params,
+) -> str:
+    img_io = BytesIO()
+    image.save(img_io, "PNG")
+    image_base64 = base64.b64encode(img_io.getvalue()).decode("utf-8")
+    messages = [
+        {
+            "role": "user",
+            "content": [
+                {
+                    "type": "image_url",
+                    "image_url": {
+                        "url": f"data:image/png;base64,{image_base64}"
+                    },
+                },
+                {
+                    "type": "text",
+                    "text": prompt,
+                },
+            ],
+        }
+    ]
+
+    payload = {
+        "messages": messages,
+        **params,
+    }
+
+    headers = headers or {}
+    if apikey is not None:
+        headers["Authorization"] = f"Bearer {apikey}"
+
+    r = requests.post(
+        str(url),
+        headers=headers,
+        json=payload,
+        timeout=timeout,
+    )
+    if not r.ok:
+        _log.error(f"Error calling the API. Response was {r.text}")
+    r.raise_for_status()
+
+    api_resp = OpenAiApiResponse.model_validate_json(r.text)
+    generated_text = api_resp.choices[0].message.content.strip()
+    return generated_text
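A minimal usage sketch of the new helper, assuming a local Ollama server with a vision model available; "figure.png" and the model name are illustrative placeholders.

from PIL import Image

from docling.utils.utils import openai_image_request

# Assumes Ollama is running locally with a vision model pulled (e.g. `ollama pull llava`).
# url defaults to http://localhost:11434/v1/chat/completions, the Ollama endpoint.
img = Image.open("figure.png")
description = openai_image_request(
    image=img,
    prompt="Describe this image in two sentences.",
    model="llava",  # forwarded into the JSON payload via **params
    timeout=60,
)
print(description)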