From 0f438b3a766533b644442b3c1e29a06de591f19a Mon Sep 17 00:00:00 2001
From: Michele Dolfi
Date: Thu, 10 Apr 2025 10:27:41 +0200
Subject: [PATCH] generalize input args for other API providers

Signed-off-by: Michele Dolfi
---
 docling/datamodel/pipeline_options.py |  9 +++++----
 docling/models/openai_vlm_model.py    | 16 +++++++---------
 docling/utils/utils.py                |  7 +------
 3 files changed, 13 insertions(+), 19 deletions(-)

diff --git a/docling/datamodel/pipeline_options.py b/docling/datamodel/pipeline_options.py
index ddafa02c..63e88a66 100644
--- a/docling/datamodel/pipeline_options.py
+++ b/docling/datamodel/pipeline_options.py
@@ -288,9 +288,9 @@ class HuggingFaceVlmOptions(BaseVlmOptions):
 class OpenAiVlmOptions(BaseVlmOptions):
     kind: Literal["openai_model_options"] = "openai_model_options"
 
-    model_id: str
-    base_url: str = "http://localhost:11434/v1"  # Default to ollama
-    apikey: Optional[str] = None
+    url: AnyUrl = AnyUrl("http://localhost:11434/v1/chat/completions")  # Default to ollama
+    headers: Dict[str, str] = {}
+    params: Dict[str, Any] = {}
     scale: float = 2.0
     timeout: float = 60
     response_format: ResponseFormat
@@ -320,7 +320,8 @@ granite_vision_vlm_conversion_options = HuggingFaceVlmOptions(
 )
 
 granite_vision_vlm_ollama_conversion_options = OpenAiVlmOptions(
-    model_id="granite3.2-vision:2b",
+    url=AnyUrl("http://localhost:11434/v1/chat/completions"),
+    params={"model": "granite3.2-vision:2b"},
     prompt="OCR the full page to markdown.",
     scale=1.0,
     timeout=120,
diff --git a/docling/models/openai_vlm_model.py b/docling/models/openai_vlm_model.py
index eb7c45ce..bf06fb2c 100644
--- a/docling/models/openai_vlm_model.py
+++ b/docling/models/openai_vlm_model.py
@@ -18,15 +18,14 @@ class OpenAiVlmModel(BasePageModel):
         self.enabled = enabled
         self.vlm_options = vlm_options
         if self.enabled:
-            self.url = "/".join(
-                [self.vlm_options.base_url.rstrip("/"), "chat/completions"]
-            )
-            self.apikey = self.vlm_options.apikey
-            self.model_id = self.vlm_options.model_id
             self.timeout = self.vlm_options.timeout
             self.prompt_content = (
                 f"This is a page from a document.\n{self.vlm_options.prompt}"
             )
+            self.params = {
+                **self.vlm_options.params,
+                "temperature": 0,
+            }
 
     def __call__(
         self, conv_res: ConversionResult, page_batch: Iterable[Page]
@@ -48,11 +47,10 @@ class OpenAiVlmModel(BasePageModel):
                     page_tags = openai_image_request(
                         image=hi_res_image,
                         prompt=self.prompt_content,
-                        url=self.url,
-                        apikey=self.apikey,
+                        url=self.vlm_options.url,
                         timeout=self.timeout,
-                        model=self.model_id,
-                        temperature=0,
+                        headers=self.vlm_options.headers,
+                        **self.params,
                     )
                     page.predictions.vlm_response = VlmPrediction(text=page_tags)
 
diff --git a/docling/utils/utils.py b/docling/utils/utils.py
index ecc21baf..51dca03a 100644
--- a/docling/utils/utils.py
+++ b/docling/utils/utils.py
@@ -76,10 +76,7 @@ def download_url_with_progress(url: str, progress: bool = False) -> BytesIO:
 def openai_image_request(
     image: Image.Image,
     prompt: str,
-    url: Union[
-        AnyUrl, str
-    ] = "http://localhost:11434/v1/chat/completions",  # Default to ollama
-    apikey: Optional[str] = None,
+    url: AnyUrl,
     timeout: float = 20,
     headers: Optional[Dict[str, str]] = None,
     **params,
@@ -109,8 +106,6 @@ def openai_image_request(
     }
 
     headers = headers or {}
-    if apikey is not None:
-        headers["Authorization"] = f"Bearer {apikey}"
 
     r = requests.post(
         str(url),
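
For context, a minimal sketch of how the generalized options could be configured for a hosted API provider rather than the Ollama default. The endpoint URL and API key below are placeholders, not values from this patch, and the `ResponseFormat` import assumes the existing enum in `docling/datamodel/pipeline_options.py`:

```python
from pydantic import AnyUrl

from docling.datamodel.pipeline_options import OpenAiVlmOptions, ResponseFormat

# Placeholder endpoint and credentials; substitute your provider's values.
hosted_vlm_options = OpenAiVlmOptions(
    # The full chat-completions endpoint replaces the removed `base_url` field:
    url=AnyUrl("https://api.example.com/v1/chat/completions"),
    # Authentication moves from the removed `apikey` field into arbitrary headers:
    headers={"Authorization": "Bearer <YOUR_API_KEY>"},
    # Provider-specific request-body fields (model name, etc.) travel in `params`
    # and are forwarded to the request as keyword arguments by OpenAiVlmModel:
    params={"model": "granite3.2-vision:2b"},
    prompt="OCR the full page to markdown.",
    scale=1.0,
    timeout=120,
    response_format=ResponseFormat.MARKDOWN,
)
```

Note that `OpenAiVlmModel` still merges `"temperature": 0` on top of `params`, so a caller-supplied temperature is always overridden and page transcription stays deterministic.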