This commit is contained in:
Shkarupa Alex 2025-07-23 20:39:41 +03:00 committed by GitHub
commit b87a1d9ccb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 19 additions and 0 deletions

View File

@ -14,6 +14,7 @@ class BaseVlmOptions(BaseModel):
scale: float = 2.0
max_size: Optional[int] = None
temperature: float = 0.0
decode_response: Optional[Callable[[str], str]] = None
class ResponseFormat(str, Enum):

View File

@ -67,6 +67,8 @@ class ApiVlmModel(BasePageModel):
**self.params,
)
if self.vlm_options.decode_response:
page_tags = self.vlm_options.decode_response(page_tags)
page.predictions.vlm_response = VlmPrediction(text=page_tags)
return page

View File

@ -166,6 +166,10 @@ class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMix
_log.debug(
f"Generated {num_tokens} tokens in time {generation_time:.2f} seconds."
)
if self.vlm_options.decode_response:
generated_texts = self.vlm_options.decode_response(
generated_texts
)
page.predictions.vlm_response = VlmPrediction(
text=generated_texts,
generation_time=generation_time,

View File

@ -142,6 +142,8 @@ class HuggingFaceMlxModel(BasePageModel, HuggingFaceModelDownloadMixin):
_log.debug(
f"{generation_time:.2f} seconds for {len(tokens)} tokens ({len(tokens) / generation_time} tokens/sec)."
)
if self.vlm_options.decode_response:
page_tags = self.vlm_options.decode_response(page_tags)
page.predictions.vlm_response = VlmPrediction(
text=page_tags,
generation_time=generation_time,

View File

@ -1,3 +1,4 @@
import json
import logging
import os
from pathlib import Path
@ -83,6 +84,14 @@ def lms_olmocr_vlm_options(model: str):
f"RAW_TEXT_START\n{base_text}\nRAW_TEXT_END"
)
def _decode_olmocr_response(generated_text: str) -> str:
try:
generated_json = json.loads(generated_text)
except json.decoder.JSONDecodeError:
return ""
return generated_json["natural_text"]
options = ApiVlmOptions(
url="http://localhost:1234/v1/chat/completions",
params=dict(
@ -92,6 +101,7 @@ def lms_olmocr_vlm_options(model: str):
timeout=90,
scale=1.0,
max_size=1024, # from OlmOcr pipeline
decode_response=_decode_olmocr_response,
response_format=ResponseFormat.MARKDOWN,
)
return options