mirror of https://github.com/DS4SD/docling.git (synced 2025-07-25 19:44:34 +00:00)

refactoring redundant code and fixing mypy errors

Signed-off-by: Peter Staar <taa@zurich.ibm.com>

This commit is contained in:
parent b5479ab971
commit c10e2920a4
@@ -11,6 +11,7 @@ from docling.datamodel.base_models import (
     ItemAndImageEnrichmentElement,
     Page,
     TextCell,
+    VlmPredictionToken,
 )
 from docling.datamodel.document import ConversionResult
 from docling.datamodel.pipeline_options import BaseOptions
@@ -49,7 +50,13 @@ class BaseLayoutModel(BasePageModel):
 
 class BaseVlmModel(BasePageModel):
     @abstractmethod
-    def predict_on_page_image(self, *, page_image: Image.Image, prompt: str) -> str:
+    def get_user_prompt(self, page: Optional[Page]) -> str:
         pass
 
+    @abstractmethod
+    def predict_on_page_image(
+        self, *, page_image: Image.Image, prompt: str, output_tokens: bool = False
+    ) -> tuple[str, Optional[list[VlmPredictionToken]]]:
+        pass
+
 
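For orientation, here is a minimal sketch (not part of the commit) of a subclass satisfying the refactored BaseVlmModel interface. Only the two abstract signatures come from the hunk above; the class name, prompt text, return values, and the import path for BaseVlmModel are illustrative assumptions.

    from typing import Optional

    from PIL import Image

    from docling.datamodel.base_models import Page, VlmPredictionToken
    from docling.models.base_model import BaseVlmModel  # module path assumed from the hunk context


    class DummyVlmModel(BaseVlmModel):  # hypothetical, for illustration only
        # BasePageModel's own abstract methods (e.g. __call__) are omitted for brevity.

        def get_user_prompt(self, page: Optional[Page]) -> str:
            # A real backend would build a chat-templated prompt from its options.
            return "Convert this page to DocTags."

        def predict_on_page_image(
            self, *, page_image: Image.Image, prompt: str, output_tokens: bool = False
        ) -> tuple[str, Optional[list[VlmPredictionToken]]]:
            # A real backend would run inference here; per-token data is returned
            # only when output_tokens=True, as in the MLX implementation below.
            text = "<doctag></doctag>"
            return text, ([] if output_tokens else None)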
@@ -38,7 +38,7 @@ class HuggingFaceTransformersVlmModel(BaseVlmModel, HuggingFaceModelDownloadMixin):
         self.vlm_options = vlm_options
 
         self.scale = self.vlm_options.scale
-        self.max_size = self.vlm_options.max_size
+        # self.max_size = self.vlm_options.max_size
 
         if self.enabled:
             import torch
@@ -4,6 +4,8 @@ from collections.abc import Iterable
 from pathlib import Path
 from typing import Optional
 
+from PIL import Image
+
 from docling.datamodel.accelerator_options import (
     AcceleratorOptions,
 )
@@ -33,7 +35,7 @@ class HuggingFaceMlxModel(BaseVlmModel, HuggingFaceModelDownloadMixin):
         self.max_tokens = vlm_options.max_new_tokens
         self.temperature = vlm_options.temperature
         self.scale = self.vlm_options.scale
-        self.max_size = self.vlm_options.max_size
+        # self.max_size = self.vlm_options.max_size
 
         if self.enabled:
             try:
@@ -62,6 +64,55 @@ class HuggingFaceMlxModel(BaseVlmModel, HuggingFaceModelDownloadMixin):
             self.vlm_model, self.processor = load(artifacts_path)
             self.config = load_config(artifacts_path)
 
+    def get_user_prompt(self, page: Optional[Page]) -> str:
+        if callable(self.vlm_options.prompt) and page is not None:
+            return self.vlm_options.prompt(page.parsed_page)
+        else:
+            user_prompt = self.vlm_options.prompt
+            prompt = self.apply_chat_template(
+                self.processor, self.config, user_prompt, num_images=1
+            )
+            return prompt
+
+    def predict_on_page_image(
+        self, *, page_image: Image.Image, prompt: str, output_tokens: bool = False
+    ) -> tuple[str, Optional[list[VlmPredictionToken]]]:
+        tokens = []
+        output = ""
+        for token in self.stream_generate(
+            self.vlm_model,
+            self.processor,
+            prompt,
+            [page_image],
+            max_tokens=self.max_tokens,
+            verbose=False,
+            temp=self.temperature,
+        ):
+            if len(token.logprobs.shape) == 1:
+                tokens.append(
+                    VlmPredictionToken(
+                        text=token.text,
+                        token=token.token,
+                        logprob=token.logprobs[token.token],
+                    )
+                )
+            elif len(token.logprobs.shape) == 2 and token.logprobs.shape[0] == 1:
+                tokens.append(
+                    VlmPredictionToken(
+                        text=token.text,
+                        token=token.token,
+                        logprob=token.logprobs[0, token.token],
+                    )
+                )
+            else:
+                _log.warning(f"incompatible shape for logprobs: {token.logprobs.shape}")
+
+            output += token.text
+            if "</doctag>" in token.text:
+                break
+
+        return output, tokens
+
     def __call__(
         self, conv_res: ConversionResult, page_batch: Iterable[Page]
     ) -> Iterable[Page]:
@@ -73,19 +124,23 @@ class HuggingFaceMlxModel(BaseVlmModel, HuggingFaceModelDownloadMixin):
                 with TimeRecorder(conv_res, f"vlm-mlx-{self.vlm_options.repo_id}"):
                     assert page.size is not None
 
-                    hi_res_image = page.get_image(
+                    page_image = page.get_image(
                         scale=self.vlm_options.scale, max_size=self.vlm_options.max_size
                     )
-                    if hi_res_image is not None:
-                        im_width, im_height = hi_res_image.size
+                    """
+                    if page_image is not None:
+                        im_width, im_height = page_image.size
+                    """
+                    assert page_image is not None
 
                     # populate page_tags with predicted doc tags
                     page_tags = ""
 
-                    if hi_res_image:
-                        if hi_res_image.mode != "RGB":
-                            hi_res_image = hi_res_image.convert("RGB")
+                    if page_image:
+                        if page_image.mode != "RGB":
+                            page_image = page_image.convert("RGB")
 
+                    """
                     if callable(self.vlm_options.prompt):
                         user_prompt = self.vlm_options.prompt(page.parsed_page)
                     else:
@@ -93,11 +148,12 @@ class HuggingFaceMlxModel(BaseVlmModel, HuggingFaceModelDownloadMixin):
                     prompt = self.apply_chat_template(
                         self.processor, self.config, user_prompt, num_images=1
                     )
 
                     start_time = time.time()
                     _log.debug("start generating ...")
-
-                    tokens: list[VlmPredictionToken] = []
-
-                    output = ""
-
+                    """
+                    prompt = self.get_user_prompt(page)
+
+                    # Call model to generate:
+                    start_time = time.time()
+                    """
@@ -105,7 +161,7 @@ class HuggingFaceMlxModel(BaseVlmModel, HuggingFaceModelDownloadMixin):
                         self.vlm_model,
                         self.processor,
                         prompt,
-                        [hi_res_image],
+                        [page_image],
                         max_tokens=self.max_tokens,
                         verbose=False,
                         temp=self.temperature,
@@ -137,13 +193,20 @@ class HuggingFaceMlxModel(BaseVlmModel, HuggingFaceModelDownloadMixin):
                         output += token.text
                         if "</doctag>" in token.text:
                             break
-
+                    """
+                    output, tokens = self.predict_on_page_image(
+                        page_image=page_image, prompt=prompt, output_tokens=True
+                    )
+
                     generation_time = time.time() - start_time
                     page_tags = output
 
+                    """
                     _log.debug(
                         f"{generation_time:.2f} seconds for {len(tokens)} tokens ({len(tokens) / generation_time} tokens/sec)."
                     )
+                    """
+
                     page.predictions.vlm_response = VlmPrediction(
                         text=page_tags,
                         generation_time=generation_time,
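Taken together, the MLX hunks reduce the per-page work in __call__ to roughly the following flow; this is a condensed sketch of the new code above, with the commented-out legacy blocks omitted.

    # Inside the per-page loop of HuggingFaceMlxModel.__call__ (condensed sketch):
    page_image = page.get_image(
        scale=self.vlm_options.scale, max_size=self.vlm_options.max_size
    )
    assert page_image is not None
    if page_image.mode != "RGB":
        page_image = page_image.convert("RGB")

    # Prompt construction and generation are now delegated to the new methods.
    prompt = self.get_user_prompt(page)

    start_time = time.time()
    output, tokens = self.predict_on_page_image(
        page_image=page_image, prompt=prompt, output_tokens=True
    )
    generation_time = time.time() - start_time

    page.predictions.vlm_response = VlmPrediction(
        text=output,
        generation_time=generation_time,
    )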
@@ -61,64 +61,24 @@ class TwoStageVlmModel(BasePageModel, HuggingFaceModelDownloadMixin):
                         page=page, clusters=pred_clusters
                     )
                 )
 
-                # Define prompt structure
-                if callable(self.vlm_options.prompt):
-                    user_prompt = self.vlm_options.prompt(page.parsed_page)
-                else:
-                    user_prompt = self.vlm_options.prompt
-
-                prompt = self.formulate_prompt(user_prompt, processed_clusters)
+                user_prompt = self.vlm_model.get_user_prompt(page=page)
+                prompt = self.formulate_prompt(
+                    user_prompt=user_prompt, clusters=processed_clusters
+                )
 
-                generated_text, generation_time = self.vlm_model.predict_on_image(
+                start_time = time.time()
+                generated_text = self.vlm_model.predict_on_page_image(
                     page_image=page_image, prompt=prompt
                 )
 
                 page.predictions.vlm_response = VlmPrediction(
-                    text=generated_text,
-                    generation_time=generation_time,
+                    text=generated_text, generation_time=time.time() - start_time
                 )
 
                 yield page
 
-    def formulate_prompt(self, user_prompt: str, clusters: list[Cluster]) -> str:
+    def formulate_prompt(self, *, user_prompt: str, clusters: list[Cluster]) -> str:
         """Formulate a prompt for the VLM."""
 
-        if self.vlm_options.transformers_prompt_style == TransformersPromptStyle.RAW:
-            return user_prompt
-
-        elif self.vlm_options.repo_id == "microsoft/Phi-4-multimodal-instruct":
-            _log.debug("Using specialized prompt for Phi-4")
-            # more info here: https://huggingface.co/microsoft/Phi-4-multimodal-instruct#loading-the-model-locally
-
-            user_prompt = "<|user|>"
-            assistant_prompt = "<|assistant|>"
-            prompt_suffix = "<|end|>"
-
-            prompt = f"{user_prompt}<|image_1|>{user_prompt}{prompt_suffix}{assistant_prompt}"
-            _log.debug(f"prompt for {self.vlm_options.repo_id}: {prompt}")
-
-            return prompt
-
-        elif self.vlm_options.transformers_prompt_style == TransformersPromptStyle.CHAT:
-            messages = [
-                {
-                    "role": "user",
-                    "content": [
-                        {
-                            "type": "text",
-                            "text": "This is a page from a document.",
-                        },
-                        {"type": "image"},
-                        {"type": "text", "text": user_prompt},
-                    ],
-                }
-            ]
-            prompt = self.processor.apply_chat_template(
-                messages, add_generation_prompt=False
-            )
-            return prompt
-
-        raise RuntimeError(
-            f"Uknown prompt style `{self.vlm_options.transformers_prompt_style}`. Valid values are {', '.join(s.value for s in TransformersPromptStyle)}."
-        )
+        return user_prompt
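With the redundant prompt logic removed, TwoStageVlmModel delegates prompt construction and prediction to the wrapped backend. Condensed from the hunk above, the per-page flow is now roughly:

    # Inside TwoStageVlmModel.__call__ (condensed sketch of the new code):
    user_prompt = self.vlm_model.get_user_prompt(page=page)
    prompt = self.formulate_prompt(
        user_prompt=user_prompt, clusters=processed_clusters
    )

    start_time = time.time()
    generated_text = self.vlm_model.predict_on_page_image(
        page_image=page_image, prompt=prompt
    )
    page.predictions.vlm_response = VlmPrediction(
        text=generated_text, generation_time=time.time() - start_time
    )

    yield page

Note that formulate_prompt now takes keyword-only arguments and, with the model-specific branches dropped, simply returns the user prompt unchanged.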
@@ -26,12 +26,13 @@ from docling.backend.md_backend import MarkdownDocumentBackend
 from docling.backend.pdf_backend import PdfDocumentBackend
 from docling.datamodel.base_models import InputFormat, Page
 from docling.datamodel.document import ConversionResult, InputDocument
-from docling.datamodel.pipeline_options import TwoStageVlmOptions, VlmPipelineOptions
+from docling.datamodel.pipeline_options import VlmPipelineOptions
 from docling.datamodel.pipeline_options_vlm_model import (
     ApiVlmOptions,
     InferenceFramework,
     InlineVlmOptions,
     ResponseFormat,
+    TwoStageVlmOptions,
 )
 from docling.datamodel.settings import settings
 from docling.models.api_vlm_model import ApiVlmModel
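Finally, TwoStageVlmOptions is now imported from docling.datamodel.pipeline_options_vlm_model rather than from docling.datamodel.pipeline_options; downstream code would pick it up as:

    from docling.datamodel.pipeline_options import VlmPipelineOptions
    from docling.datamodel.pipeline_options_vlm_model import TwoStageVlmOptions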