Mirror of https://github.com/DS4SD/docling.git (synced 2025-07-25 19:44:34 +00:00)

Merge 5e1e82ab3b into 98e2fcff63

Commit b87a1d9ccb
@@ -14,6 +14,7 @@ class BaseVlmOptions(BaseModel):
     scale: float = 2.0
     max_size: Optional[int] = None
     temperature: float = 0.0
+    decode_response: Optional[Callable[[str], str]] = None


 class ResponseFormat(str, Enum):
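The new `decode_response` field lets callers post-process the raw model output before it is stored on the page. A minimal sketch of a conforming hook; the fence-stripping behavior and the helper name are illustrative, not part of this commit:

```python
import re


def strip_markdown_fences(raw: str) -> str:
    """Hypothetical decode_response hook: unwrap a ```-fenced reply.

    Matches the Callable[[str], str] shape expected by
    BaseVlmOptions.decode_response; returns the input unchanged
    when no fenced block is found.
    """
    match = re.search(r"```(?:\w+)?\n(.*?)```", raw, flags=re.DOTALL)
    return match.group(1) if match else raw
```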
@@ -67,6 +67,8 @@ class ApiVlmModel(BasePageModel):
                     **self.params,
                 )

+                if self.vlm_options.decode_response:
+                    page_tags = self.vlm_options.decode_response(page_tags)
                 page.predictions.vlm_response = VlmPrediction(text=page_tags)

                 return page
@@ -166,6 +166,10 @@ class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMix
                 _log.debug(
                     f"Generated {num_tokens} tokens in time {generation_time:.2f} seconds."
                 )

+                if self.vlm_options.decode_response:
+                    generated_texts = self.vlm_options.decode_response(
+                        generated_texts
+                    )
                 page.predictions.vlm_response = VlmPrediction(
                     text=generated_texts,
                     generation_time=generation_time,
@@ -142,6 +142,8 @@ class HuggingFaceMlxModel(BasePageModel, HuggingFaceModelDownloadMixin):
                 _log.debug(
                     f"{generation_time:.2f} seconds for {len(tokens)} tokens ({len(tokens) / generation_time} tokens/sec)."
                 )

+                if self.vlm_options.decode_response:
+                    page_tags = self.vlm_options.decode_response(page_tags)
                 page.predictions.vlm_response = VlmPrediction(
                     text=page_tags,
                     generation_time=generation_time,
docs/examples/vlm_pipeline_api_model.py (10 changes, vendored)
@@ -1,3 +1,4 @@
+import json
 import logging
 import os
 from pathlib import Path
@@ -83,6 +84,14 @@ def lms_olmocr_vlm_options(model: str):
         f"RAW_TEXT_START\n{base_text}\nRAW_TEXT_END"
     )

+    def _decode_olmocr_response(generated_text: str) -> str:
+        try:
+            generated_json = json.loads(generated_text)
+        except json.decoder.JSONDecodeError:
+            return ""
+
+        return generated_json["natural_text"]
+
     options = ApiVlmOptions(
         url="http://localhost:1234/v1/chat/completions",
         params=dict(
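OlmOcr replies with a JSON document whose `natural_text` field carries the extracted page text; the decoder above returns that field and falls back to an empty string on malformed JSON. A standalone, slightly more defensive sketch of the same logic (the decoder in the diff is a nested helper, so this copy exists only for illustration):

```python
import json


def decode_olmocr(generated_text: str) -> str:
    # Same idea as _decode_olmocr_response, lifted out and also
    # guarding against a missing "natural_text" key.
    try:
        return json.loads(generated_text)["natural_text"]
    except (json.JSONDecodeError, KeyError):
        return ""


raw = '{"natural_text": "Hello, page one."}'
print(decode_olmocr(raw))         # Hello, page one.
print(decode_olmocr("not json"))  # "" (malformed response)
```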
@@ -92,6 +101,7 @@ def lms_olmocr_vlm_options(model: str):
         timeout=90,
         scale=1.0,
         max_size=1024,  # from OlmOcr pipeline
+        decode_response=_decode_olmocr_response,
         response_format=ResponseFormat.MARKDOWN,
     )
     return options
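For context, the options factory above is meant to be dropped into a VLM pipeline configuration. A sketch of how it could be wired up, assuming docling's `VlmPipeline`/`DocumentConverter` plumbing and a local LM Studio endpoint; the model id and input path are placeholders, not values from this commit:

```python
from docling.datamodel.base_models import InputFormat
from docling.datamodel.pipeline_options import VlmPipelineOptions
from docling.document_converter import DocumentConverter, PdfFormatOption
from docling.pipeline.vlm_pipeline import VlmPipeline

pipeline_options = VlmPipelineOptions(enable_remote_services=True)
# Placeholder model id; use whatever LM Studio reports for your olmOCR build.
pipeline_options.vlm_options = lms_olmocr_vlm_options(model="olmocr-model-id")

converter = DocumentConverter(
    format_options={
        InputFormat.PDF: PdfFormatOption(
            pipeline_cls=VlmPipeline,
            pipeline_options=pipeline_options,
        )
    }
)
result = converter.convert("input.pdf")  # placeholder path
print(result.document.export_to_markdown())
```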