mirror of https://github.com/DS4SD/docling.git
synced 2025-07-26 20:14:47 +00:00

refactoring redundant code and fixing mypy errors

Signed-off-by: Peter Staar <taa@zurich.ibm.com>

parent b5479ab971
commit c10e2920a4
@@ -11,6 +11,7 @@ from docling.datamodel.base_models import (
     ItemAndImageEnrichmentElement,
     Page,
     TextCell,
+    VlmPredictionToken,
 )
 from docling.datamodel.document import ConversionResult
 from docling.datamodel.pipeline_options import BaseOptions
@@ -49,7 +50,13 @@ class BaseLayoutModel(BasePageModel):
 
 class BaseVlmModel(BasePageModel):
     @abstractmethod
-    def predict_on_page_image(self, *, page_image: Image.Image, prompt: str) -> str:
+    def get_user_prompt(self, page: Optional[Page]) -> str:
+        pass
+
+    @abstractmethod
+    def predict_on_page_image(
+        self, *, page_image: Image.Image, prompt: str, output_tokens: bool = False
+    ) -> tuple[str, Optional[list[VlmPredictionToken]]]:
         pass
 
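For orientation, a minimal concrete subclass of the widened interface might look like the sketch below. `EchoVlmModel` and its canned output are illustrative only, not part of this commit; a real subclass would also implement the remaining `BasePageModel` hooks such as `__call__`.

    from typing import Optional

    from PIL import Image

    from docling.datamodel.base_models import Page, VlmPredictionToken

    # BaseVlmModel as defined in the hunk above; its module path is not
    # shown in this diff, so the import is left implicit here.


    class EchoVlmModel(BaseVlmModel):
        """Illustrative stub: returns a fixed DocTags string instead of running a model."""

        def get_user_prompt(self, page: Optional[Page]) -> str:
            # A real model would build a chat-templated prompt here.
            return "Convert this page to DocTags."

        def predict_on_page_image(
            self, *, page_image: Image.Image, prompt: str, output_tokens: bool = False
        ) -> tuple[str, Optional[list[VlmPredictionToken]]]:
            # The new signature returns the generated text plus optional
            # per-token metadata when output_tokens=True.
            tokens: Optional[list[VlmPredictionToken]] = [] if output_tokens else None
            return "<doctag></doctag>", tokens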
@@ -38,7 +38,7 @@ class HuggingFaceTransformersVlmModel(BaseVlmModel, HuggingFaceModelDownloadMixin):
         self.vlm_options = vlm_options
 
         self.scale = self.vlm_options.scale
-        self.max_size = self.vlm_options.max_size
+        # self.max_size = self.vlm_options.max_size
 
         if self.enabled:
             import torch
@@ -4,6 +4,8 @@ from collections.abc import Iterable
 from pathlib import Path
 from typing import Optional
 
+from PIL import Image
+
 from docling.datamodel.accelerator_options import (
     AcceleratorOptions,
 )
@@ -33,7 +35,7 @@ class HuggingFaceMlxModel(BaseVlmModel, HuggingFaceModelDownloadMixin):
         self.max_tokens = vlm_options.max_new_tokens
         self.temperature = vlm_options.temperature
         self.scale = self.vlm_options.scale
-        self.max_size = self.vlm_options.max_size
+        # self.max_size = self.vlm_options.max_size
 
         if self.enabled:
             try:
@@ -62,6 +64,55 @@ class HuggingFaceMlxModel(BaseVlmModel, HuggingFaceModelDownloadMixin):
                 self.vlm_model, self.processor = load(artifacts_path)
                 self.config = load_config(artifacts_path)
 
+    def get_user_prompt(self, page: Optional[Page]) -> str:
+        if callable(self.vlm_options.prompt) and page is not None:
+            return self.vlm_options.prompt(page.parsed_page)
+        else:
+            user_prompt = self.vlm_options.prompt
+            prompt = self.apply_chat_template(
+                self.processor, self.config, user_prompt, num_images=1
+            )
+            return prompt
+
+    def predict_on_page_image(
+        self, *, page_image: Image.Image, prompt: str, output_tokens: bool = False
+    ) -> tuple[str, Optional[list[VlmPredictionToken]]]:
+        tokens = []
+        output = ""
+        for token in self.stream_generate(
+            self.vlm_model,
+            self.processor,
+            prompt,
+            [page_image],
+            max_tokens=self.max_tokens,
+            verbose=False,
+            temp=self.temperature,
+        ):
+            if len(token.logprobs.shape) == 1:
+                tokens.append(
+                    VlmPredictionToken(
+                        text=token.text,
+                        token=token.token,
+                        logprob=token.logprobs[token.token],
+                    )
+                )
+            elif len(token.logprobs.shape) == 2 and token.logprobs.shape[0] == 1:
+                tokens.append(
+                    VlmPredictionToken(
+                        text=token.text,
+                        token=token.token,
+                        logprob=token.logprobs[0, token.token],
+                    )
+                )
+            else:
+                _log.warning(f"incompatible shape for logprobs: {token.logprobs.shape}")
+
+            output += token.text
+            if "</doctag>" in token.text:
+                break
+
+        return output, tokens
+
     def __call__(
         self, conv_res: ConversionResult, page_batch: Iterable[Page]
     ) -> Iterable[Page]:
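Because `predict_on_page_image` now hands back the `VlmPredictionToken` list alongside the generated text, callers can derive simple quality signals from the logprobs. A small sketch, assuming `logprob` is a plain float per token; the helper name is ours, not the commit's:

    from typing import Optional

    from docling.datamodel.base_models import VlmPredictionToken


    def mean_logprob(tokens: Optional[list[VlmPredictionToken]]) -> Optional[float]:
        # Average per-token log-probability: a rough confidence proxy
        # for the generated DocTags sequence.
        if not tokens:
            return None
        return sum(t.logprob for t in tokens) / len(tokens)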
@@ -73,19 +124,23 @@ class HuggingFaceMlxModel(BaseVlmModel, HuggingFaceModelDownloadMixin):
             with TimeRecorder(conv_res, f"vlm-mlx-{self.vlm_options.repo_id}"):
                 assert page.size is not None
 
-                hi_res_image = page.get_image(
+                page_image = page.get_image(
                     scale=self.vlm_options.scale, max_size=self.vlm_options.max_size
                 )
-                if hi_res_image is not None:
-                    im_width, im_height = hi_res_image.size
+                """
+                if page_image is not None:
+                    im_width, im_height = page_image.size
+                """
+                assert page_image is not None
 
                 # populate page_tags with predicted doc tags
                 page_tags = ""
 
-                if hi_res_image:
-                    if hi_res_image.mode != "RGB":
-                        hi_res_image = hi_res_image.convert("RGB")
+                if page_image:
+                    if page_image.mode != "RGB":
+                        page_image = page_image.convert("RGB")
 
+                    """
                     if callable(self.vlm_options.prompt):
                         user_prompt = self.vlm_options.prompt(page.parsed_page)
                     else:
@@ -93,11 +148,12 @@ class HuggingFaceMlxModel(BaseVlmModel, HuggingFaceModelDownloadMixin):
                     prompt = self.apply_chat_template(
                         self.processor, self.config, user_prompt, num_images=1
                     )
-                    start_time = time.time()
-                    _log.debug("start generating ...")
+                    """
+                    prompt = self.get_user_prompt(page)
 
                     # Call model to generate:
+                    start_time = time.time()
+                    """
                     tokens: list[VlmPredictionToken] = []
 
                     output = ""
@@ -105,7 +161,7 @@ class HuggingFaceMlxModel(BaseVlmModel, HuggingFaceModelDownloadMixin):
                         self.vlm_model,
                         self.processor,
                         prompt,
-                        [hi_res_image],
+                        [page_image],
                         max_tokens=self.max_tokens,
                         verbose=False,
                         temp=self.temperature,
@@ -137,13 +193,20 @@ class HuggingFaceMlxModel(BaseVlmModel, HuggingFaceModelDownloadMixin):
                         output += token.text
                         if "</doctag>" in token.text:
                             break
+                    """
+                    output, tokens = self.predict_on_page_image(
+                        page_image=page_image, prompt=prompt, output_tokens=True
+                    )
+
                     generation_time = time.time() - start_time
                     page_tags = output
 
+                    """
                     _log.debug(
                         f"{generation_time:.2f} seconds for {len(tokens)} tokens ({len(tokens) / generation_time} tokens/sec)."
                     )
+                    """
 
                     page.predictions.vlm_response = VlmPrediction(
                         text=page_tags,
                         generation_time=generation_time,
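Taken together, the hunks above reduce the MLX generation path to prompt, predict, record. A hedged sketch of the same flow as a standalone helper, with the dead triple-quoted blocks elided; the name `run_mlx_page` is ours, not the commit's:

    import time

    from docling.datamodel.base_models import Page, VlmPrediction


    def run_mlx_page(model: "HuggingFaceMlxModel", page: Page) -> Page:
        # Condensed paraphrase of the __call__ body after this refactor;
        # it mirrors the diff, it is not a drop-in replacement.
        page_image = page.get_image(
            scale=model.vlm_options.scale, max_size=model.vlm_options.max_size
        )
        assert page_image is not None
        if page_image.mode != "RGB":
            page_image = page_image.convert("RGB")

        prompt = model.get_user_prompt(page)
        start_time = time.time()
        output, tokens = model.predict_on_page_image(
            page_image=page_image, prompt=prompt, output_tokens=True
        )
        page.predictions.vlm_response = VlmPrediction(
            text=output, generation_time=time.time() - start_time
        )
        return page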
@@ -62,63 +62,23 @@ class TwoStageVlmModel(BasePageModel, HuggingFaceModelDownloadMixin):
                     )
                 )
 
-                # Define prompt structure
-                if callable(self.vlm_options.prompt):
-                    user_prompt = self.vlm_options.prompt(page.parsed_page)
-                else:
-                    user_prompt = self.vlm_options.prompt
-
-                prompt = self.formulate_prompt(user_prompt, processed_clusters)
-
-                generated_text, generation_time = self.vlm_model.predict_on_image(
+                user_prompt = self.vlm_model.get_user_prompt(page=page)
+                prompt = self.formulate_prompt(
+                    user_prompt=user_prompt, clusters=processed_clusters
+                )
+
+                start_time = time.time()
+                generated_text = self.vlm_model.predict_on_page_image(
                     page_image=page_image, prompt=prompt
                 )
 
                 page.predictions.vlm_response = VlmPrediction(
-                    text=generated_text,
-                    generation_time=generation_time,
+                    text=generated_text, generation_time=time.time() - start_time
                 )
 
                 yield page
 
-    def formulate_prompt(self, user_prompt: str, clusters: list[Cluster]) -> str:
+    def formulate_prompt(self, *, user_prompt: str, clusters: list[Cluster]) -> str:
         """Formulate a prompt for the VLM."""
 
-        if self.vlm_options.transformers_prompt_style == TransformersPromptStyle.RAW:
-            return user_prompt
-
-        elif self.vlm_options.repo_id == "microsoft/Phi-4-multimodal-instruct":
-            _log.debug("Using specialized prompt for Phi-4")
-            # more info here: https://huggingface.co/microsoft/Phi-4-multimodal-instruct#loading-the-model-locally
-
-            user_prompt = "<|user|>"
-            assistant_prompt = "<|assistant|>"
-            prompt_suffix = "<|end|>"
-
-            prompt = f"{user_prompt}<|image_1|>{user_prompt}{prompt_suffix}{assistant_prompt}"
-            _log.debug(f"prompt for {self.vlm_options.repo_id}: {prompt}")
-
-            return prompt
-
-        elif self.vlm_options.transformers_prompt_style == TransformersPromptStyle.CHAT:
-            messages = [
-                {
-                    "role": "user",
-                    "content": [
-                        {
-                            "type": "text",
-                            "text": "This is a page from a document.",
-                        },
-                        {"type": "image"},
-                        {"type": "text", "text": user_prompt},
-                    ],
-                }
-            ]
-            prompt = self.processor.apply_chat_template(
-                messages, add_generation_prompt=False
-            )
-            return prompt
-
-        raise RuntimeError(
-            f"Uknown prompt style `{self.vlm_options.transformers_prompt_style}`. Valid values are {', '.join(s.value for s in TransformersPromptStyle)}."
-        )
+        return user_prompt
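The `*` in the new `formulate_prompt` signature makes both arguments keyword-only, so the old positional call no longer type-checks or runs. A short illustration of the calling convention, with `model` standing in for a `TwoStageVlmModel` instance:

    # Before this commit, positional arguments were accepted:
    # prompt = model.formulate_prompt(user_prompt, clusters)

    # After it, both arguments must be named, or Python raises TypeError:
    prompt = model.formulate_prompt(user_prompt=user_prompt, clusters=clusters)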
@@ -26,12 +26,13 @@ from docling.backend.md_backend import MarkdownDocumentBackend
 from docling.backend.pdf_backend import PdfDocumentBackend
 from docling.datamodel.base_models import InputFormat, Page
 from docling.datamodel.document import ConversionResult, InputDocument
-from docling.datamodel.pipeline_options import TwoStageVlmOptions, VlmPipelineOptions
+from docling.datamodel.pipeline_options import VlmPipelineOptions
 from docling.datamodel.pipeline_options_vlm_model import (
     ApiVlmOptions,
     InferenceFramework,
     InlineVlmOptions,
     ResponseFormat,
+    TwoStageVlmOptions,
 )
 from docling.datamodel.settings import settings
 from docling.models.api_vlm_model import ApiVlmModel