refactoring redundant code and fixing mypy errors

Signed-off-by: Peter Staar <taa@zurich.ibm.com>
Peter Staar 2025-07-08 16:37:20 +02:00
parent b5479ab971
commit c10e2920a4
5 changed files with 94 additions and 63 deletions

View File

@@ -11,6 +11,7 @@ from docling.datamodel.base_models import (
ItemAndImageEnrichmentElement,
Page,
TextCell,
VlmPredictionToken,
)
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import BaseOptions
@@ -49,7 +50,13 @@ class BaseLayoutModel(BasePageModel):
class BaseVlmModel(BasePageModel):
@abstractmethod
def predict_on_page_image(self, *, page_image: Image.Image, prompt: str) -> str:
def get_user_prompt(self, page: Optional[Page]) -> str:
pass
@abstractmethod
def predict_on_page_image(
self, *, page_image: Image.Image, prompt: str, output_tokens: bool = False
) -> tuple[str, Optional[list[VlmPredictionToken]]]:
pass
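
For orientation, here is a minimal sketch of what a concrete implementation of the revised BaseVlmModel contract could look like. It is illustrative only and not part of this commit: the class EchoVlmModel and its trivial bodies are invented, and it assumes BaseVlmModel is importable from docling.models.base_model while Page and VlmPredictionToken come from docling.datamodel.base_models, as the import hunk above suggests.

# Illustrative sketch only -- not part of this commit.
from typing import Optional

from PIL import Image

from docling.datamodel.base_models import Page, VlmPredictionToken
from docling.models.base_model import BaseVlmModel  # assumed module path


class EchoVlmModel(BaseVlmModel):
    # Toy subclass showing the two abstract methods introduced in this hunk.
    # BasePageModel.__call__ would also need an implementation; omitted here.

    def get_user_prompt(self, page: Optional[Page]) -> str:
        # A real model would apply its chat template to build the prompt.
        return "Convert this page to DocTags."

    def predict_on_page_image(
        self, *, page_image: Image.Image, prompt: str, output_tokens: bool = False
    ) -> tuple[str, Optional[list[VlmPredictionToken]]]:
        # A real model would run inference; tokens are returned only on request.
        tokens: Optional[list[VlmPredictionToken]] = [] if output_tokens else None
        return "", tokens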

View File

@@ -38,7 +38,7 @@ class HuggingFaceTransformersVlmModel(BaseVlmModel, HuggingFaceModelDownloadMixin):
self.vlm_options = vlm_options
self.scale = self.vlm_options.scale
self.max_size = self.vlm_options.max_size
# self.max_size = self.vlm_options.max_size
if self.enabled:
import torch

View File

@@ -4,6 +4,8 @@ from collections.abc import Iterable
from pathlib import Path
from typing import Optional
from PIL import Image
from docling.datamodel.accelerator_options import (
AcceleratorOptions,
)
@@ -33,7 +35,7 @@ class HuggingFaceMlxModel(BaseVlmModel, HuggingFaceModelDownloadMixin):
self.max_tokens = vlm_options.max_new_tokens
self.temperature = vlm_options.temperature
self.scale = self.vlm_options.scale
self.max_size = self.vlm_options.max_size
# self.max_size = self.vlm_options.max_size
if self.enabled:
try:
@@ -62,6 +64,55 @@ class HuggingFaceMlxModel(BaseVlmModel, HuggingFaceModelDownloadMixin):
self.vlm_model, self.processor = load(artifacts_path)
self.config = load_config(artifacts_path)
def get_user_prompt(self, page: Optional[Page]) -> str:
if callable(self.vlm_options.prompt) and page is not None:
return self.vlm_options.prompt(page.parsed_page)
else:
user_prompt = self.vlm_options.prompt
prompt = self.apply_chat_template(
self.processor, self.config, user_prompt, num_images=1
)
return prompt
def predict_on_page_image(
self, *, page_image: Image.Image, prompt: str, output_tokens: bool = False
) -> tuple[str, Optional[list[VlmPredictionToken]]]:
tokens = []
output = ""
for token in self.stream_generate(
self.vlm_model,
self.processor,
prompt,
[page_image],
max_tokens=self.max_tokens,
verbose=False,
temp=self.temperature,
):
if len(token.logprobs.shape) == 1:
tokens.append(
VlmPredictionToken(
text=token.text,
token=token.token,
logprob=token.logprobs[token.token],
)
)
elif len(token.logprobs.shape) == 2 and token.logprobs.shape[0] == 1:
tokens.append(
VlmPredictionToken(
text=token.text,
token=token.token,
logprob=token.logprobs[0, token.token],
)
)
else:
_log.warning(f"incompatible shape for logprobs: {token.logprobs.shape}")
output += token.text
if "</doctag>" in token.text:
break
return output, tokens
def __call__(
self, conv_res: ConversionResult, page_batch: Iterable[Page]
) -> Iterable[Page]:
@@ -73,19 +124,23 @@ class HuggingFaceMlxModel(BaseVlmModel, HuggingFaceModelDownloadMixin):
with TimeRecorder(conv_res, f"vlm-mlx-{self.vlm_options.repo_id}"):
assert page.size is not None
hi_res_image = page.get_image(
page_image = page.get_image(
scale=self.vlm_options.scale, max_size=self.vlm_options.max_size
)
if hi_res_image is not None:
im_width, im_height = hi_res_image.size
"""
if page_image is not None:
im_width, im_height = page_image.size
"""
assert page_image is not None
# populate page_tags with predicted doc tags
page_tags = ""
if hi_res_image:
if hi_res_image.mode != "RGB":
hi_res_image = hi_res_image.convert("RGB")
if page_image:
if page_image.mode != "RGB":
page_image = page_image.convert("RGB")
"""
if callable(self.vlm_options.prompt):
user_prompt = self.vlm_options.prompt(page.parsed_page)
else:
@@ -93,11 +148,12 @@ class HuggingFaceMlxModel(BaseVlmModel, HuggingFaceModelDownloadMixin):
prompt = self.apply_chat_template(
self.processor, self.config, user_prompt, num_images=1
)
start_time = time.time()
_log.debug("start generating ...")
"""
prompt = self.get_user_prompt(page)
# Call model to generate:
start_time = time.time()
"""
tokens: list[VlmPredictionToken] = []
output = ""
@@ -105,7 +161,7 @@ class HuggingFaceMlxModel(BaseVlmModel, HuggingFaceModelDownloadMixin):
self.vlm_model,
self.processor,
prompt,
[hi_res_image],
[page_image],
max_tokens=self.max_tokens,
verbose=False,
temp=self.temperature,
@@ -137,13 +193,20 @@ class HuggingFaceMlxModel(BaseVlmModel, HuggingFaceModelDownloadMixin):
output += token.text
if "</doctag>" in token.text:
break
"""
output, tokens = self.predict_on_page_image(
page_image=page_image, prompt=prompt, output_tokens=True
)
generation_time = time.time() - start_time
page_tags = output
"""
_log.debug(
f"{generation_time:.2f} seconds for {len(tokens)} tokens ({len(tokens) / generation_time} tokens/sec)."
)
"""
page.predictions.vlm_response = VlmPrediction(
text=page_tags,
generation_time=generation_time,
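
Taken together, the refactor above reduces the MLX __call__ body to a prompt/predict pair on the new helpers. A hypothetical call sequence, assuming model is an already initialized HuggingFaceMlxModel and page is a docling Page that can be rasterized, would look roughly like this:

# Hypothetical usage of the helpers introduced above; `model` (an initialized
# HuggingFaceMlxModel) and `page` (a docling Page) are assumed to exist.
prompt = model.get_user_prompt(page)

page_image = page.get_image(
    scale=model.vlm_options.scale, max_size=model.vlm_options.max_size
)
assert page_image is not None
if page_image.mode != "RGB":
    page_image = page_image.convert("RGB")

output, tokens = model.predict_on_page_image(
    page_image=page_image, prompt=prompt, output_tokens=True
)
print(f"generated {len(tokens or [])} tokens; DocTags closed: {'</doctag>' in output}")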

View File

@@ -61,64 +61,24 @@ class TwoStageVlmModel(BasePageModel, HuggingFaceModelDownloadMixin):
page=page, clusters=pred_clusters
)
)
# Define prompt structure
if callable(self.vlm_options.prompt):
user_prompt = self.vlm_options.prompt(page.parsed_page)
else:
user_prompt = self.vlm_options.prompt
prompt = self.formulate_prompt(user_prompt, processed_clusters)
user_prompt = self.vlm_model.get_user_prompt(page=page)
prompt = self.formulate_prompt(
user_prompt=user_prompt, clusters=processed_clusters
)
generated_text, generation_time = self.vlm_model.predict_on_image(
start_time = time.time()
generated_text = self.vlm_model.predict_on_page_image(
page_image=page_image, prompt=prompt
)
page.predictions.vlm_response = VlmPrediction(
text=generated_text,
generation_time=generation_time,
text=generated_text, generation_time=time.time() - start_time
)
yield page
def formulate_prompt(self, user_prompt: str, clusters: list[Cluster]) -> str:
def formulate_prompt(self, *, user_prompt: str, clusters: list[Cluster]) -> str:
"""Formulate a prompt for the VLM."""
if self.vlm_options.transformers_prompt_style == TransformersPromptStyle.RAW:
return user_prompt
elif self.vlm_options.repo_id == "microsoft/Phi-4-multimodal-instruct":
_log.debug("Using specialized prompt for Phi-4")
# more info here: https://huggingface.co/microsoft/Phi-4-multimodal-instruct#loading-the-model-locally
user_prompt = "<|user|>"
assistant_prompt = "<|assistant|>"
prompt_suffix = "<|end|>"
prompt = f"{user_prompt}<|image_1|>{user_prompt}{prompt_suffix}{assistant_prompt}"
_log.debug(f"prompt for {self.vlm_options.repo_id}: {prompt}")
return prompt
elif self.vlm_options.transformers_prompt_style == TransformersPromptStyle.CHAT:
messages = [
{
"role": "user",
"content": [
{
"type": "text",
"text": "This is a page from a document.",
},
{"type": "image"},
{"type": "text", "text": user_prompt},
],
}
]
prompt = self.processor.apply_chat_template(
messages, add_generation_prompt=False
)
return prompt
raise RuntimeError(
f"Uknown prompt style `{self.vlm_options.transformers_prompt_style}`. Valid values are {', '.join(s.value for s in TransformersPromptStyle)}."
)
return user_prompt
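
The two-stage model now delegates prompt construction to the wrapped VLM via get_user_prompt and times the generation itself. Note that, per the BaseVlmModel signature at the top of this commit, predict_on_page_image returns a (text, tokens) pair, so the sketch below unpacks the result explicitly. All names (vlm_model, two_stage, page, page_image, processed_clusters) are assumed to be set up already, and the VlmPrediction import location is assumed to match that of VlmPredictionToken; this is not literal code from the diff.

# Sketch of the simplified two-stage flow under the assumptions stated above.
import time

from docling.datamodel.base_models import VlmPrediction  # assumed location

user_prompt = vlm_model.get_user_prompt(page=page)
prompt = two_stage.formulate_prompt(
    user_prompt=user_prompt, clusters=processed_clusters
)

start_time = time.time()
generated_text, _tokens = vlm_model.predict_on_page_image(
    page_image=page_image, prompt=prompt
)
page.predictions.vlm_response = VlmPrediction(
    text=generated_text, generation_time=time.time() - start_time
)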

View File

@@ -26,12 +26,13 @@ from docling.backend.md_backend import MarkdownDocumentBackend
from docling.backend.pdf_backend import PdfDocumentBackend
from docling.datamodel.base_models import InputFormat, Page
from docling.datamodel.document import ConversionResult, InputDocument
from docling.datamodel.pipeline_options import TwoStageVlmOptions, VlmPipelineOptions
from docling.datamodel.pipeline_options import VlmPipelineOptions
from docling.datamodel.pipeline_options_vlm_model import (
ApiVlmOptions,
InferenceFramework,
InlineVlmOptions,
ResponseFormat,
TwoStageVlmOptions,
)
from docling.datamodel.settings import settings
from docling.models.api_vlm_model import ApiVlmModel
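
With TwoStageVlmOptions now exported from the VLM-specific options module, downstream code on this branch would import it as sketched below. The uses_two_stage helper is hypothetical, and it assumes VlmPipelineOptions exposes a vlm_options field, which this diff does not show.

# Assumed post-commit import locations, following the hunk above.
from docling.datamodel.pipeline_options import VlmPipelineOptions
from docling.datamodel.pipeline_options_vlm_model import TwoStageVlmOptions


def uses_two_stage(options: VlmPipelineOptions) -> bool:
    # Hypothetical helper: report whether the pipeline was configured with the
    # two-stage VLM options rather than a single inline or API model.
    return isinstance(options.vlm_options, TwoStageVlmOptions)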