mirror of
https://github.com/DS4SD/docling.git
synced 2025-07-27 04:24:45 +00:00
generalize input args for other API providers
Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>
This commit is contained in:
parent
66253e8a4b
commit
0f438b3a76
@ -288,9 +288,9 @@ class HuggingFaceVlmOptions(BaseVlmOptions):
|
|||||||
class OpenAiVlmOptions(BaseVlmOptions):
|
class OpenAiVlmOptions(BaseVlmOptions):
|
||||||
kind: Literal["openai_model_options"] = "openai_model_options"
|
kind: Literal["openai_model_options"] = "openai_model_options"
|
||||||
|
|
||||||
model_id: str
|
url: AnyUrl = AnyUrl("http://localhost:11434/v1/chat/completions") # Default to ollama
|
||||||
base_url: str = "http://localhost:11434/v1" # Default to ollama
|
headers: Dict[str, str] = {}
|
||||||
apikey: Optional[str] = None
|
params: Dict[str, Any] = {}
|
||||||
scale: float = 2.0
|
scale: float = 2.0
|
||||||
timeout: float = 60
|
timeout: float = 60
|
||||||
response_format: ResponseFormat
|
response_format: ResponseFormat
|
||||||
@ -320,7 +320,8 @@ granite_vision_vlm_conversion_options = HuggingFaceVlmOptions(
|
|||||||
)
|
)
|
||||||
|
|
||||||
granite_vision_vlm_ollama_conversion_options = OpenAiVlmOptions(
|
granite_vision_vlm_ollama_conversion_options = OpenAiVlmOptions(
|
||||||
model_id="granite3.2-vision:2b",
|
url=AnyUrl("http://localhost:11434/v1/chat/completions"),
|
||||||
|
params={"model": "granite3.2-vision:2b"},
|
||||||
prompt="OCR the full page to markdown.",
|
prompt="OCR the full page to markdown.",
|
||||||
scale=1.0,
|
scale=1.0,
|
||||||
timeout=120,
|
timeout=120,
|
||||||
|
@ -18,15 +18,14 @@ class OpenAiVlmModel(BasePageModel):
|
|||||||
self.enabled = enabled
|
self.enabled = enabled
|
||||||
self.vlm_options = vlm_options
|
self.vlm_options = vlm_options
|
||||||
if self.enabled:
|
if self.enabled:
|
||||||
self.url = "/".join(
|
|
||||||
[self.vlm_options.base_url.rstrip("/"), "chat/completions"]
|
|
||||||
)
|
|
||||||
self.apikey = self.vlm_options.apikey
|
|
||||||
self.model_id = self.vlm_options.model_id
|
|
||||||
self.timeout = self.vlm_options.timeout
|
self.timeout = self.vlm_options.timeout
|
||||||
self.prompt_content = (
|
self.prompt_content = (
|
||||||
f"This is a page from a document.\n{self.vlm_options.prompt}"
|
f"This is a page from a document.\n{self.vlm_options.prompt}"
|
||||||
)
|
)
|
||||||
|
self.params = {
|
||||||
|
**self.vlm_options.params,
|
||||||
|
"temperature": 0,
|
||||||
|
}
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self, conv_res: ConversionResult, page_batch: Iterable[Page]
|
self, conv_res: ConversionResult, page_batch: Iterable[Page]
|
||||||
@ -48,11 +47,10 @@ class OpenAiVlmModel(BasePageModel):
|
|||||||
page_tags = openai_image_request(
|
page_tags = openai_image_request(
|
||||||
image=hi_res_image,
|
image=hi_res_image,
|
||||||
prompt=self.prompt_content,
|
prompt=self.prompt_content,
|
||||||
url=self.url,
|
url=self.vlm_options.url,
|
||||||
apikey=self.apikey,
|
|
||||||
timeout=self.timeout,
|
timeout=self.timeout,
|
||||||
model=self.model_id,
|
headers=self.vlm_options.headers,
|
||||||
temperature=0,
|
**self.params,
|
||||||
)
|
)
|
||||||
|
|
||||||
page.predictions.vlm_response = VlmPrediction(text=page_tags)
|
page.predictions.vlm_response = VlmPrediction(text=page_tags)
|
||||||
|
@ -76,10 +76,7 @@ def download_url_with_progress(url: str, progress: bool = False) -> BytesIO:
|
|||||||
def openai_image_request(
|
def openai_image_request(
|
||||||
image: Image.Image,
|
image: Image.Image,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
url: Union[
|
url: AnyUrl,
|
||||||
AnyUrl, str
|
|
||||||
] = "http://localhost:11434/v1/chat/completions", # Default to ollama
|
|
||||||
apikey: Optional[str] = None,
|
|
||||||
timeout: float = 20,
|
timeout: float = 20,
|
||||||
headers: Optional[Dict[str, str]] = None,
|
headers: Optional[Dict[str, str]] = None,
|
||||||
**params,
|
**params,
|
||||||
@ -109,8 +106,6 @@ def openai_image_request(
|
|||||||
}
|
}
|
||||||
|
|
||||||
headers = headers or {}
|
headers = headers or {}
|
||||||
if apikey is not None:
|
|
||||||
headers["Authorization"] = f"Bearer {apikey}"
|
|
||||||
|
|
||||||
r = requests.post(
|
r = requests.post(
|
||||||
str(url),
|
str(url),
|
||||||
|
Loading…
Reference in New Issue
Block a user