Mirror of https://github.com/DS4SD/docling.git (synced 2025-12-08 12:48:28 +00:00)
feat(vlm): Ability to preprocess VLM response (#1907)
* Add ability to preprocess VLM response

* Move response decoding to vlm options (requires inheritance to override). Per-page prompt formulation also moved to vlm options to keep the API consistent.

Signed-off-by: Shkarupa Alex <shkarupa.alex@gmail.com>
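In practice this means VLM response preprocessing and per-page prompting are now customized by subclassing the options class and overriding the new hooks, rather than by passing a callable prompt. A minimal sketch of that extension point (illustrative only, not part of the diff; the ApiVlmOptions import path and the fence-stripping logic are assumptions):

from typing import Optional

from docling_core.types.doc.page import SegmentedPage

# Import path assumed from the surrounding docling codebase; it is not shown in this diff.
from docling.datamodel.pipeline_options_vlm_model import ApiVlmOptions


class FencedMarkdownVlmOptions(ApiVlmOptions):  # hypothetical subclass for illustration
    def build_prompt(self, page: Optional[SegmentedPage]) -> str:
        # Per-page prompt formulation: append the page size when a parsed page is available.
        if page is None:
            return self.prompt
        return (
            f"{self.prompt}\n"
            f"(Page size: {int(page.dimension.width)}x{int(page.dimension.height)})"
        )

    def decode_response(self, text: str) -> str:
        # Post-process the raw reply before it is stored as the page's VLM prediction:
        # here, strip a markdown code fence if the model wrapped its answer in one.
        text = text.strip()
        if text.startswith("```"):
            text = text.strip("`\n").removeprefix("markdown\n")
        return text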
@@ -1,5 +1,5 @@
 from enum import Enum
-from typing import Any, Callable, Dict, List, Literal, Optional, Union
+from typing import Any, Dict, List, Literal, Optional

 from docling_core.types.doc.page import SegmentedPage
 from pydantic import AnyUrl, BaseModel
@@ -10,11 +10,17 @@ from docling.datamodel.accelerator_options import AcceleratorDevice


 class BaseVlmOptions(BaseModel):
     kind: str
-    prompt: Union[str, Callable[[Optional[SegmentedPage]], str]]
+    prompt: str
     scale: float = 2.0
     max_size: Optional[int] = None
     temperature: float = 0.0

+    def build_prompt(self, page: Optional[SegmentedPage]) -> str:
+        return self.prompt
+
+    def decode_response(self, text: str) -> str:
+        return text
+

 class ResponseFormat(str, Enum):
     DOCTAGS = "doctags"
@@ -53,11 +53,7 @@ class ApiVlmModel(BasePageModel):
         if hi_res_image.mode != "RGB":
             hi_res_image = hi_res_image.convert("RGB")

-        if callable(self.vlm_options.prompt):
-            prompt = self.vlm_options.prompt(page.parsed_page)
-        else:
-            prompt = self.vlm_options.prompt
-
+        prompt = self.vlm_options.build_prompt(page.parsed_page)
         page_tags = api_image_request(
             image=hi_res_image,
             prompt=prompt,
@@ -67,6 +63,7 @@ class ApiVlmModel(BasePageModel):
             **self.params,
         )

+        page_tags = self.vlm_options.decode_response(page_tags)
         page.predictions.vlm_response = VlmPrediction(text=page_tags)

         return page
@@ -135,10 +135,7 @@ class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMix
         )

         # Define prompt structure
-        if callable(self.vlm_options.prompt):
-            user_prompt = self.vlm_options.prompt(page.parsed_page)
-        else:
-            user_prompt = self.vlm_options.prompt
+        user_prompt = self.vlm_options.build_prompt(page.parsed_page)
         prompt = self.formulate_prompt(user_prompt)

         inputs = self.processor(
@@ -166,6 +163,7 @@ class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMix
         _log.debug(
             f"Generated {num_tokens} tokens in time {generation_time:.2f} seconds."
         )
+        generated_texts = self.vlm_options.decode_response(generated_texts)
         page.predictions.vlm_response = VlmPrediction(
             text=generated_texts,
             generation_time=generation_time,
@@ -84,10 +84,7 @@ class HuggingFaceMlxModel(BasePageModel, HuggingFaceModelDownloadMixin):
         if hi_res_image.mode != "RGB":
             hi_res_image = hi_res_image.convert("RGB")

-        if callable(self.vlm_options.prompt):
-            user_prompt = self.vlm_options.prompt(page.parsed_page)
-        else:
-            user_prompt = self.vlm_options.prompt
+        user_prompt = self.vlm_options.build_prompt(page.parsed_page)
         prompt = self.apply_chat_template(
             self.processor, self.config, user_prompt, num_images=1
         )
@@ -142,6 +139,7 @@ class HuggingFaceMlxModel(BasePageModel, HuggingFaceModelDownloadMixin):
         _log.debug(
             f"{generation_time:.2f} seconds for {len(tokens)} tokens ({len(tokens) / generation_time} tokens/sec)."
         )
+        page_tags = self.vlm_options.decode_response(page_tags)
         page.predictions.vlm_response = VlmPrediction(
             text=page_tags,
             generation_time=generation_time,
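All three page-model backends touched above (ApiVlmModel, HuggingFaceTransformersVlmModel, HuggingFaceMlxModel) end up with the same flow around their inference call. A condensed, illustrative sketch of that shared pattern (run_vlm_step and run_model are placeholders for this sketch, not docling functions):

def run_vlm_step(vlm_options, page, run_model):
    # 1. The per-page prompt now comes from the options object instead of a callable prompt field.
    prompt = vlm_options.build_prompt(page.parsed_page)
    # 2. The backend runs its own inference (HTTP request, transformers generate, or mlx generate).
    raw_text = run_model(prompt)
    # 3. The new decode_response hook post-processes the reply before it becomes VlmPrediction.text.
    return vlm_options.decode_response(raw_text)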
docs/examples/vlm_pipeline_api_model.py (vendored, 85 changed lines)
@@ -1,3 +1,4 @@
+import json
 import logging
 import os
 from pathlib import Path
@@ -38,57 +39,63 @@ def lms_vlm_options(model: str, prompt: str, format: ResponseFormat):


 def lms_olmocr_vlm_options(model: str):
-    def _dynamic_olmocr_prompt(page: Optional[SegmentedPage]):
-        if page is None:
-            return (
-                "Below is the image of one page of a document. Just return the plain text"
-                " representation of this document as if you were reading it naturally.\n"
-                "Do not hallucinate.\n"
-            )
+    class OlmocrVlmOptions(ApiVlmOptions):
+        def build_prompt(self, page: Optional[SegmentedPage]) -> str:
+            if page is None:
+                return self.prompt.replace("#RAW_TEXT#", "")

             anchor = [
                 f"Page dimensions: {int(page.dimension.width)}x{int(page.dimension.height)}"
             ]

             for text_cell in page.textline_cells:
                 if not text_cell.text.strip():
                     continue
                 bbox = text_cell.rect.to_bounding_box().to_bottom_left_origin(
                     page.dimension.height
                 )
                 anchor.append(f"[{int(bbox.l)}x{int(bbox.b)}] {text_cell.text}")

             for image_cell in page.bitmap_resources:
                 bbox = image_cell.rect.to_bounding_box().to_bottom_left_origin(
                     page.dimension.height
                 )
                 anchor.append(
                     f"[Image {int(bbox.l)}x{int(bbox.b)} to {int(bbox.r)}x{int(bbox.t)}]"
                 )

             if len(anchor) == 1:
                 anchor.append(
                     f"[Image 0x0 to {int(page.dimension.width)}x{int(page.dimension.height)}]"
                 )

-        # Original prompt uses cells sorting. We are skipping it in this demo.
+            # Original prompt uses cells sorting. We are skipping it for simplicity.

-        base_text = "\n".join(anchor)
+            raw_text = "\n".join(anchor)

-        return (
-            f"Below is the image of one page of a document, as well as some raw textual"
-            f" content that was previously extracted for it. Just return the plain text"
-            f" representation of this document as if you were reading it naturally.\n"
-            f"Do not hallucinate.\n"
-            f"RAW_TEXT_START\n{base_text}\nRAW_TEXT_END"
-        )
+            return self.prompt.replace("#RAW_TEXT#", raw_text)

-    options = ApiVlmOptions(
+        def decode_response(self, text: str) -> str:
+            # OlmOcr trained to generate json response with language, rotation and other info
+            try:
+                generated_json = json.loads(text)
+            except json.decoder.JSONDecodeError:
+                return ""
+
+            return generated_json["natural_text"]
+
+    options = OlmocrVlmOptions(
         url="http://localhost:1234/v1/chat/completions",
         params=dict(
             model=model,
         ),
-        prompt=_dynamic_olmocr_prompt,
+        prompt=(
+            "Below is the image of one page of a document, as well as some raw textual"
+            " content that was previously extracted for it. Just return the plain text"
+            " representation of this document as if you were reading it naturally.\n"
+            "Do not hallucinate.\n"
+            "RAW_TEXT_START\n#RAW_TEXT#\nRAW_TEXT_END"
+        ),
         timeout=90,
         scale=1.0,
         max_size=1024,  # from OlmOcr pipeline
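The example above only builds the options object; wiring it into a conversion run is unchanged by this commit. A rough usage sketch, assuming the standard docling VLM-pipeline setup (the imports, the enable_remote_services flag, the model name, and the input path are assumptions, not shown in this diff):

from docling.datamodel.base_models import InputFormat
from docling.datamodel.pipeline_options import VlmPipelineOptions
from docling.document_converter import DocumentConverter, PdfFormatOption
from docling.pipeline.vlm_pipeline import VlmPipeline

pipeline_options = VlmPipelineOptions(enable_remote_services=True)  # remote/API VLMs must be enabled explicitly
# lms_olmocr_vlm_options is the helper defined in the example above; the model name here is illustrative.
pipeline_options.vlm_options = lms_olmocr_vlm_options(model="allenai/olmOCR-7B-0225-preview")

converter = DocumentConverter(
    format_options={
        InputFormat.PDF: PdfFormatOption(
            pipeline_cls=VlmPipeline,
            pipeline_options=pipeline_options,
        )
    }
)
result = converter.convert("input.pdf")
print(result.document.export_to_markdown())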