mirror of
https://github.com/DS4SD/docling.git
synced 2025-07-25 19:44:34 +00:00
feat: add TwoStageVlmModel
Signed-off-by: Peter Staar <taa@zurich.ibm.com>
This commit is contained in:
parent
a07ba863c4
commit
4eceefa47c
@ -269,6 +269,7 @@ class VlmPipelineOptions(PaginatedPipelineOptions):
|
||||
class LayoutOptions(BaseModel):
|
||||
"""Options for layout processing."""
|
||||
|
||||
repo_id: str = "ds4sd/docling-layout-heron"
|
||||
create_orphan_clusters: bool = True # Whether to create clusters for orphaned cells
|
||||
|
||||
|
||||
|
@ -87,3 +87,10 @@ class ApiVlmOptions(BaseVlmOptions):
|
||||
timeout: float = 60
|
||||
concurrency: int = 1
|
||||
response_format: ResponseFormat
|
||||
|
||||
|
||||
class TwoStageVlmOptions(BaseVlmOptions):
|
||||
kind: Literal["inline_model_options"] = "inline_two_stage_model_options"
|
||||
|
||||
vlm_options: UnionInlineVlmOptions
|
||||
layout_options: LayoutOptions
|
||||
|
@ -156,6 +156,7 @@ class LayoutModel(BasePageModel):
|
||||
page_image = page.get_image(scale=1.0)
|
||||
assert page_image is not None
|
||||
|
||||
"""
|
||||
clusters = []
|
||||
for ix, pred_item in enumerate(
|
||||
self.layout_predictor.predict(page_image)
|
||||
@ -174,14 +175,16 @@ class LayoutModel(BasePageModel):
|
||||
cells=[],
|
||||
)
|
||||
clusters.append(cluster)
|
||||
|
||||
"""
|
||||
clusters = self.predict_on_page(page_image)
|
||||
|
||||
if settings.debug.visualize_raw_layout:
|
||||
self.draw_clusters_and_cells_side_by_side(
|
||||
conv_res, page, clusters, mode_prefix="raw"
|
||||
)
|
||||
|
||||
# Apply postprocessing
|
||||
|
||||
"""
|
||||
processed_clusters, processed_cells = LayoutPostprocessor(
|
||||
page, clusters, self.options
|
||||
).postprocess()
|
||||
@ -208,10 +211,66 @@ class LayoutModel(BasePageModel):
|
||||
page.predictions.layout = LayoutPrediction(
|
||||
clusters=processed_clusters
|
||||
)
|
||||
|
||||
"""
|
||||
page, processed_clusters, processed_cells = self.postprocess_on_page(page, cluster)
|
||||
|
||||
if settings.debug.visualize_layout:
|
||||
self.draw_clusters_and_cells_side_by_side(
|
||||
conv_res, page, processed_clusters, mode_prefix="postprocessed"
|
||||
)
|
||||
|
||||
yield page
|
||||
|
||||
def predict_on_page(self, page_image: Image) -> list[Cluster]:
|
||||
|
||||
pred_items = self.layout_predictor.predict(page_image)
|
||||
|
||||
clusters = []
|
||||
for ix, pred_item in enumerate(pred_items):
|
||||
label = DocItemLabel(
|
||||
pred_item["label"]
|
||||
.lower()
|
||||
.replace(" ", "_")
|
||||
.replace("-", "_")
|
||||
) # Temporary, until docling-ibm-model uses docling-core types
|
||||
cluster = Cluster(
|
||||
id=ix,
|
||||
label=label,
|
||||
confidence=pred_item["confidence"],
|
||||
bbox=BoundingBox.model_validate(pred_item),
|
||||
cells=[],
|
||||
)
|
||||
clusters.append(cluster)
|
||||
|
||||
return clusters
|
||||
|
||||
def postprocess_on_page(self, page: Page, cluster: list(Cluster)):
|
||||
|
||||
processed_clusters, processed_cells = LayoutPostprocessor(
|
||||
page, clusters, self.options
|
||||
).postprocess()
|
||||
# Note: LayoutPostprocessor updates page.cells and page.parsed_page internally
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings(
|
||||
"ignore",
|
||||
"Mean of empty slice|invalid value encountered in scalar divide",
|
||||
RuntimeWarning,
|
||||
"numpy",
|
||||
)
|
||||
|
||||
conv_res.confidence.pages[page.page_no].layout_score = float(
|
||||
np.mean([c.confidence for c in processed_clusters])
|
||||
)
|
||||
|
||||
conv_res.confidence.pages[page.page_no].ocr_score = float(
|
||||
np.mean(
|
||||
[c.confidence for c in processed_cells if c.from_ocr]
|
||||
)
|
||||
)
|
||||
|
||||
page.predictions.layout = LayoutPrediction(
|
||||
clusters=processed_clusters
|
||||
)
|
||||
|
||||
return page, processed_clusters, processed_cells
|
||||
|
115
docling/models/vlm_models_inline/TwoStageVlmModel.py
Normal file
115
docling/models/vlm_models_inline/TwoStageVlmModel.py
Normal file
@ -0,0 +1,115 @@
|
||||
import importlib.metadata
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Iterable
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
from docling.datamodel.accelerator_options import (
|
||||
AcceleratorOptions,
|
||||
)
|
||||
from docling.datamodel.base_models import Page, VlmPrediction
|
||||
from docling.datamodel.document import ConversionResult
|
||||
from docling.datamodel.pipeline_options_vlm_model import (
|
||||
InlineVlmOptions,
|
||||
TransformersModelType,
|
||||
TransformersPromptStyle,
|
||||
)
|
||||
from docling.models.base_model import BasePageModel
|
||||
from docling.models.utils.hf_model_download import (
|
||||
HuggingFaceModelDownloadMixin,
|
||||
)
|
||||
from docling.utils.accelerator_utils import decide_device
|
||||
from docling.utils.profiling import TimeRecorder
|
||||
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TwoStageVlmModel(BasePageModel, HuggingFaceModelDownloadMixin):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
layout_model: LayoutModelModel,
|
||||
vlm_model: BasePageModel,
|
||||
):
|
||||
self.layout_model = layout_options
|
||||
self.vlm_model = vlm_options
|
||||
|
||||
def __call__(
|
||||
self, conv_res: ConversionResult, page_batch: Iterable[Page]
|
||||
) -> Iterable[Page]:
|
||||
for page in page_batch:
|
||||
assert page._backend is not None
|
||||
if not page._backend.is_valid():
|
||||
yield page
|
||||
else:
|
||||
with TimeRecorder(conv_res, "two-staged-vlm"):
|
||||
assert page.size is not None
|
||||
|
||||
page_image = page.get_image(
|
||||
scale=self.vlm_options.scale, max_size=self.vlm_options.max_size
|
||||
)
|
||||
|
||||
pred_clusters = self.layout_model.predict_on_page(page_image)
|
||||
page, processed_clusters, processed_cells = self.layout_model.postprocess_on_page(page=page,
|
||||
page_image=page_image)
|
||||
|
||||
# Define prompt structure
|
||||
if callable(self.vlm_options.prompt):
|
||||
user_prompt = self.vlm_options.prompt(page.parsed_page)
|
||||
else:
|
||||
user_prompt = self.vlm_options.prompt
|
||||
|
||||
prompt = self.formulate_prompt(user_prompt, processed_clusters)
|
||||
|
||||
generated_text, generation_time = self.vlm_model.predict_on_image(page_image=page_image,
|
||||
prompt=prompt)
|
||||
|
||||
page.predictions.vlm_response = VlmPrediction(
|
||||
text=generated_text,
|
||||
generation_time=generation_time,
|
||||
)
|
||||
|
||||
yield page
|
||||
|
||||
def formulate_prompt(self, user_prompt: str, clusters:list[Cluster]) -> str:
|
||||
"""Formulate a prompt for the VLM."""
|
||||
|
||||
if self.vlm_options.transformers_prompt_style == TransformersPromptStyle.RAW:
|
||||
return user_prompt
|
||||
|
||||
elif self.vlm_options.repo_id == "microsoft/Phi-4-multimodal-instruct":
|
||||
_log.debug("Using specialized prompt for Phi-4")
|
||||
# more info here: https://huggingface.co/microsoft/Phi-4-multimodal-instruct#loading-the-model-locally
|
||||
|
||||
user_prompt = "<|user|>"
|
||||
assistant_prompt = "<|assistant|>"
|
||||
prompt_suffix = "<|end|>"
|
||||
|
||||
prompt = f"{user_prompt}<|image_1|>{user_prompt}{prompt_suffix}{assistant_prompt}"
|
||||
_log.debug(f"prompt for {self.vlm_options.repo_id}: {prompt}")
|
||||
|
||||
return prompt
|
||||
|
||||
elif self.vlm_options.transformers_prompt_style == TransformersPromptStyle.CHAT:
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "This is a page from a document.",
|
||||
},
|
||||
{"type": "image"},
|
||||
{"type": "text", "text": user_prompt},
|
||||
],
|
||||
}
|
||||
]
|
||||
prompt = self.processor.apply_chat_template(
|
||||
messages, add_generation_prompt=False
|
||||
)
|
||||
return prompt
|
||||
|
||||
raise RuntimeError(
|
||||
f"Uknown prompt style `{self.vlm_options.transformers_prompt_style}`. Valid values are {', '.join(s.value for s in TransformersPromptStyle)}."
|
||||
)
|
@ -44,6 +44,8 @@ from docling.models.vlm_models_inline.mlx_model import HuggingFaceMlxModel
|
||||
from docling.pipeline.base_pipeline import PaginatedPipeline
|
||||
from docling.utils.profiling import ProfilingScope, TimeRecorder
|
||||
|
||||
from docling.models.layout_model import LayoutModel
|
||||
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -107,7 +109,43 @@ class VlmPipeline(PaginatedPipeline):
|
||||
raise ValueError(
|
||||
f"Could not instantiate the right type of VLM pipeline: {vlm_options.inference_framework}"
|
||||
)
|
||||
elif isinstance(self.pipeline_options.vlm_options, TwoStageVlmOptions):
|
||||
twostagevlm_options = cast(TwoStageVlmOptions, self.pipeline_options.vlm_options)
|
||||
|
||||
layout_options = twostagevlm_options.lay_options
|
||||
vlm_options = twostagevlm_options.vlm_options
|
||||
|
||||
layout_model = LayoutModel(
|
||||
artifacts_path=artifacts_path,
|
||||
accelerator_options=pipeline_options.accelerator_options,
|
||||
options=layout_options,
|
||||
)
|
||||
|
||||
if vlm_options.inference_framework == InferenceFramework.MLX:
|
||||
vlm_model = HuggingFaceMlxModel(
|
||||
enabled=True, # must be always enabled for this pipeline to make sense.
|
||||
artifacts_path=artifacts_path,
|
||||
accelerator_options=pipeline_options.accelerator_options,
|
||||
vlm_options=vlm_options,
|
||||
)
|
||||
self.build_pipe = [
|
||||
TwoStageVlmModel(layout_model=layout_model, vlm_model=vlm_model)
|
||||
]
|
||||
elif vlm_options.inference_framework == InferenceFramework.TRANSFORMERS:
|
||||
vlm_model = HuggingFaceTransformersVlmModel(
|
||||
enabled=True, # must be always enabled for this pipeline to make sense.
|
||||
artifacts_path=artifacts_path,
|
||||
accelerator_options=pipeline_options.accelerator_options,
|
||||
vlm_options=vlm_options,
|
||||
)
|
||||
self.build_pipe = [
|
||||
TwoStageVlmModel(layout_model=layout_model, vlm_model=vlm_model)
|
||||
]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Could not instantiate the right type of VLM pipeline: {vlm_options.inference_framework}"
|
||||
)
|
||||
|
||||
self.enrichment_pipe = [
|
||||
# Other models working on `NodeItem` elements in the DoclingDocument
|
||||
]
|
||||
|
Loading…
Reference in New Issue
Block a user