generalize input args for other API providers

Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>
Author: Michele Dolfi
Date: 2025-04-10 10:27:41 +02:00
parent 66253e8a4b
commit 0f438b3a76
3 changed files with 13 additions and 19 deletions

@@ -288,9 +288,9 @@ class HuggingFaceVlmOptions(BaseVlmOptions):
 class OpenAiVlmOptions(BaseVlmOptions):
     kind: Literal["openai_model_options"] = "openai_model_options"
 
-    model_id: str
-    base_url: str = "http://localhost:11434/v1"  # Default to ollama
-    apikey: Optional[str] = None
+    url: AnyUrl = AnyUrl("http://localhost:11434/v1/chat/completions")  # Default to ollama
+    headers: Dict[str, str] = {}
+    params: Dict[str, Any] = {}
     scale: float = 2.0
     timeout: float = 60
     response_format: ResponseFormat
@@ -320,7 +320,8 @@ granite_vision_vlm_conversion_options = HuggingFaceVlmOptions(
 )
 
 granite_vision_vlm_ollama_conversion_options = OpenAiVlmOptions(
-    model_id="granite3.2-vision:2b",
+    url=AnyUrl("http://localhost:11434/v1/chat/completions"),
+    params={"model": "granite3.2-vision:2b"},
     prompt="OCR the full page to markdown.",
     scale=1.0,
     timeout=120,
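The granite/ollama preset above is now just one configuration of the generalized options: endpoint, auth, and model selection all travel through the same three fields. As a minimal sketch, the same class pointed at a hypothetical hosted OpenAI-compatible API could look as follows (the endpoint URL, token, and model name are placeholders, and the import path plus ResponseFormat.MARKDOWN are assumed from the surrounding file):

# Sketch only: URL, token, and model name below are placeholders.
from pydantic import AnyUrl

from docling.datamodel.pipeline_options import (  # assumed module path
    OpenAiVlmOptions,
    ResponseFormat,
)

hosted_vlm_options = OpenAiVlmOptions(
    url=AnyUrl("https://api.example.com/v1/chat/completions"),
    headers={"Authorization": "Bearer <API_KEY>"},  # replaces the removed apikey field
    params={"model": "my-vision-model"},  # replaces the removed model_id field
    prompt="OCR the full page to markdown.",
    scale=1.0,
    timeout=120,
    response_format=ResponseFormat.MARKDOWN,
)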

@@ -18,15 +18,14 @@ class OpenAiVlmModel(BasePageModel):
         self.enabled = enabled
         self.vlm_options = vlm_options
         if self.enabled:
-            self.url = "/".join(
-                [self.vlm_options.base_url.rstrip("/"), "chat/completions"]
-            )
-            self.apikey = self.vlm_options.apikey
-            self.model_id = self.vlm_options.model_id
             self.timeout = self.vlm_options.timeout
             self.prompt_content = (
                 f"This is a page from a document.\n{self.vlm_options.prompt}"
             )
+            self.params = {
+                **self.vlm_options.params,
+                "temperature": 0,
+            }
 
     def __call__(
         self, conv_res: ConversionResult, page_batch: Iterable[Page]
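One behavioral detail of the new self.params dict is worth noting: in a dict literal, later keys override earlier ones, so temperature stays pinned at 0 even if a caller put its own value into vlm_options.params. A quick illustration of the merge:

user_params = {"model": "granite3.2-vision:2b", "temperature": 0.7}
merged = {**user_params, "temperature": 0}  # later key wins
print(merged)  # {'model': 'granite3.2-vision:2b', 'temperature': 0}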
@@ -48,11 +47,10 @@ class OpenAiVlmModel(BasePageModel):
                     page_tags = openai_image_request(
                         image=hi_res_image,
                         prompt=self.prompt_content,
-                        url=self.url,
-                        apikey=self.apikey,
+                        url=self.vlm_options.url,
                         timeout=self.timeout,
-                        model=self.model_id,
-                        temperature=0,
+                        headers=self.vlm_options.headers,
+                        **self.params,
                     )
 
                     page.predictions.vlm_response = VlmPrediction(text=page_tags)

@@ -76,10 +76,7 @@ def download_url_with_progress(url: str, progress: bool = False) -> BytesIO:
 def openai_image_request(
     image: Image.Image,
     prompt: str,
-    url: Union[
-        AnyUrl, str
-    ] = "http://localhost:11434/v1/chat/completions",  # Default to ollama
-    apikey: Optional[str] = None,
+    url: AnyUrl,
     timeout: float = 20,
     headers: Optional[Dict[str, str]] = None,
     **params,
@@ -109,8 +106,6 @@ def openai_image_request(
     }
 
     headers = headers or {}
-    if apikey is not None:
-        headers["Authorization"] = f"Bearer {apikey}"
     r = requests.post(
         str(url),
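Since openai_image_request no longer builds the Authorization header itself, callers that need authentication pass a ready-made headers dict, and any model or sampling options ride along as keyword arguments into **params. A minimal sketch of a direct call (the module path, endpoint, token, and model name are assumptions):

from PIL import Image
from pydantic import AnyUrl

from docling.utils.utils import openai_image_request  # assumed module path

page_image = Image.open("page.png")  # any PIL image of a document page
markdown_text = openai_image_request(
    image=page_image,
    prompt="OCR the full page to markdown.",
    url=AnyUrl("https://api.example.com/v1/chat/completions"),  # placeholder endpoint
    headers={"Authorization": "Bearer <API_KEY>"},  # auth is now the caller's job
    timeout=60,
    model="my-vision-model",  # forwarded into the request payload via **params
)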