feat: add TwoStageVlmModel

Signed-off-by: Peter Staar <taa@zurich.ibm.com>
Author: Peter Staar
Date: 2025-07-08 07:38:48 +02:00
Parent: a07ba863c4
Commit: 4eceefa47c
5 changed files with 223 additions and 3 deletions


@@ -269,6 +269,7 @@ class VlmPipelineOptions(PaginatedPipelineOptions):

class LayoutOptions(BaseModel):
    """Options for layout processing."""

    repo_id: str = "ds4sd/docling-layout-heron"
    create_orphan_clusters: bool = True  # Whether to create clusters for orphaned cells


@@ -87,3 +87,10 @@ class ApiVlmOptions(BaseVlmOptions):
    timeout: float = 60
    concurrency: int = 1
    response_format: ResponseFormat


class TwoStageVlmOptions(BaseVlmOptions):
    kind: Literal["inline_two_stage_model_options"] = "inline_two_stage_model_options"

    vlm_options: UnionInlineVlmOptions
    layout_options: LayoutOptions
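
Together with the LayoutOptions above, this composes a stage-1 layout configuration with a stage-2 VLM configuration. A minimal usage sketch, assuming the module paths of the files touched in this commit and that the `vlm_model_specs.SMOLDOCLING_TRANSFORMERS` preset is available in this version:

# Sketch only: module paths and the SMOLDOCLING_TRANSFORMERS preset are assumptions.
from docling.datamodel import vlm_model_specs
from docling.datamodel.pipeline_options import LayoutOptions
from docling.datamodel.pipeline_options_vlm_model import TwoStageVlmOptions

two_stage_options = TwoStageVlmOptions(
    vlm_options=vlm_model_specs.SMOLDOCLING_TRANSFORMERS,  # stage 2: the VLM
    layout_options=LayoutOptions(),  # stage 1: layout analysis, default heron model
)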


@@ -156,6 +156,7 @@ class LayoutModel(BasePageModel):
                    page_image = page.get_image(scale=1.0)
                    assert page_image is not None

                    """
                    clusters = []
                    for ix, pred_item in enumerate(
                        self.layout_predictor.predict(page_image)
@@ -174,14 +175,16 @@ class LayoutModel(BasePageModel):
                            cells=[],
                        )
                        clusters.append(cluster)
                    """
                    clusters = self.predict_on_page(page_image)

                    if settings.debug.visualize_raw_layout:
                        self.draw_clusters_and_cells_side_by_side(
                            conv_res, page, clusters, mode_prefix="raw"
                        )

                    # Apply postprocessing
                    """
                    processed_clusters, processed_cells = LayoutPostprocessor(
                        page, clusters, self.options
                    ).postprocess()
@@ -208,10 +211,66 @@ class LayoutModel(BasePageModel):
                        page.predictions.layout = LayoutPrediction(
                            clusters=processed_clusters
                        )
                    """
                    page, processed_clusters, processed_cells = self.postprocess_on_page(
                        conv_res=conv_res, page=page, clusters=clusters
                    )

                    if settings.debug.visualize_layout:
                        self.draw_clusters_and_cells_side_by_side(
                            conv_res, page, processed_clusters, mode_prefix="postprocessed"
                        )

                    yield page
    def predict_on_page(self, page_image: Image.Image) -> list[Cluster]:
        pred_items = self.layout_predictor.predict(page_image)

        clusters = []
        for ix, pred_item in enumerate(pred_items):
            label = DocItemLabel(
                pred_item["label"].lower().replace(" ", "_").replace("-", "_")
            )  # Temporary, until docling-ibm-model uses docling-core types
            cluster = Cluster(
                id=ix,
                label=label,
                confidence=pred_item["confidence"],
                bbox=BoundingBox.model_validate(pred_item),
                cells=[],
            )
            clusters.append(cluster)

        return clusters

    def postprocess_on_page(
        self, conv_res: ConversionResult, page: Page, clusters: list[Cluster]
    ):
        processed_clusters, processed_cells = LayoutPostprocessor(
            page, clusters, self.options
        ).postprocess()
        # Note: LayoutPostprocessor updates page.cells and page.parsed_page internally

        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore",
                "Mean of empty slice|invalid value encountered in scalar divide",
                RuntimeWarning,
                "numpy",
            )
            conv_res.confidence.pages[page.page_no].layout_score = float(
                np.mean([c.confidence for c in processed_clusters])
            )
            conv_res.confidence.pages[page.page_no].ocr_score = float(
                np.mean([c.confidence for c in processed_cells if c.from_ocr])
            )

        page.predictions.layout = LayoutPrediction(clusters=processed_clusters)

        return page, processed_clusters, processed_cells
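
These two methods split the previous monolithic `__call__` body into reusable steps, so other components (such as the TwoStageVlmModel below) can drive layout prediction directly. A sketch of standalone use, where `layout_model`, `conv_res` and `page` are hypothetical stand-ins for objects from a running pipeline:

# Hypothetical driver: layout_model, conv_res and page come from an active conversion.
page_image = page.get_image(scale=1.0)
clusters = layout_model.predict_on_page(page_image)
page, processed_clusters, processed_cells = layout_model.postprocess_on_page(
    conv_res=conv_res, page=page, clusters=clusters
)
print(f"{len(processed_clusters)} clusters after postprocessing")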


@@ -0,0 +1,115 @@
import logging
from collections.abc import Iterable

from docling.datamodel.base_models import Cluster, Page, VlmPrediction
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options_vlm_model import (
    TransformersPromptStyle,
)
from docling.models.base_model import BasePageModel
from docling.models.layout_model import LayoutModel
from docling.models.utils.hf_model_download import (
    HuggingFaceModelDownloadMixin,
)
from docling.utils.profiling import TimeRecorder

_log = logging.getLogger(__name__)

class TwoStageVlmModel(BasePageModel, HuggingFaceModelDownloadMixin):
    def __init__(
        self,
        *,
        layout_model: LayoutModel,
        vlm_model: BasePageModel,
    ):
        self.layout_model = layout_model
        self.vlm_model = vlm_model
        # Convenience alias: prompt/scale settings live on the wrapped VLM model's options
        self.vlm_options = vlm_model.vlm_options
    def __call__(
        self, conv_res: ConversionResult, page_batch: Iterable[Page]
    ) -> Iterable[Page]:
        for page in page_batch:
            assert page._backend is not None
            if not page._backend.is_valid():
                yield page
            else:
                with TimeRecorder(conv_res, "two-staged-vlm"):
                    assert page.size is not None

                    page_image = page.get_image(
                        scale=self.vlm_options.scale, max_size=self.vlm_options.max_size
                    )

                    # Stage 1: predict the layout and post-process it on the page
                    pred_clusters = self.layout_model.predict_on_page(page_image)
                    page, processed_clusters, processed_cells = (
                        self.layout_model.postprocess_on_page(
                            conv_res=conv_res, page=page, clusters=pred_clusters
                        )
                    )

                    # Stage 2: formulate the prompt (optionally from the parsed page)
                    if callable(self.vlm_options.prompt):
                        user_prompt = self.vlm_options.prompt(page.parsed_page)
                    else:
                        user_prompt = self.vlm_options.prompt
                    prompt = self.formulate_prompt(user_prompt, processed_clusters)

                    generated_text, generation_time = self.vlm_model.predict_on_image(
                        page_image=page_image, prompt=prompt
                    )
                    page.predictions.vlm_response = VlmPrediction(
                        text=generated_text,
                        generation_time=generation_time,
                    )

                    yield page
    def formulate_prompt(self, user_prompt: str, clusters: list[Cluster]) -> str:
        """Formulate a prompt for the VLM."""
        if self.vlm_options.transformers_prompt_style == TransformersPromptStyle.RAW:
            return user_prompt

        elif self.vlm_options.repo_id == "microsoft/Phi-4-multimodal-instruct":
            _log.debug("Using specialized prompt for Phi-4")
            # more info here: https://huggingface.co/microsoft/Phi-4-multimodal-instruct#loading-the-model-locally
            user_tag = "<|user|>"
            assistant_tag = "<|assistant|>"
            suffix_tag = "<|end|>"

            prompt = f"{user_tag}<|image_1|>{user_prompt}{suffix_tag}{assistant_tag}"
            _log.debug(f"prompt for {self.vlm_options.repo_id}: {prompt}")

            return prompt

        elif self.vlm_options.transformers_prompt_style == TransformersPromptStyle.CHAT:
            messages = [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": "This is a page from a document.",
                        },
                        {"type": "image"},
                        {"type": "text", "text": user_prompt},
                    ],
                }
            ]
            # The chat template lives on the wrapped VLM model's processor
            prompt = self.vlm_model.processor.apply_chat_template(
                messages, add_generation_prompt=False
            )
            return prompt

        raise RuntimeError(
            f"Unknown prompt style `{self.vlm_options.transformers_prompt_style}`. "
            f"Valid values are {', '.join(s.value for s in TransformersPromptStyle)}."
        )
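
Because `vlm_options.prompt` may be a callable that receives `page.parsed_page` (which stage 1 updates via the LayoutPostprocessor), prompts can react to the layout results. A sketch of such a callable; the `textline_cells` attribute on the parsed page is an assumption about the docling-core type:

# Illustrative layout-aware prompt; attribute names on parsed_page are assumptions.
def layout_aware_prompt(parsed_page) -> str:
    n_lines = len(parsed_page.textline_cells) if parsed_page is not None else 0
    return (
        f"Convert this page to DocTags. The layout stage found {n_lines} "
        "text lines; preserve their reading order."
    )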


@@ -44,6 +44,8 @@ from docling.models.vlm_models_inline.mlx_model import HuggingFaceMlxModel
from docling.pipeline.base_pipeline import PaginatedPipeline
from docling.utils.profiling import ProfilingScope, TimeRecorder

from docling.models.layout_model import LayoutModel

_log = logging.getLogger(__name__)

@@ -107,7 +109,43 @@ class VlmPipeline(PaginatedPipeline):
                raise ValueError(
                    f"Could not instantiate the right type of VLM pipeline: {vlm_options.inference_framework}"
                )
        elif isinstance(self.pipeline_options.vlm_options, TwoStageVlmOptions):
            twostagevlm_options = cast(
                TwoStageVlmOptions, self.pipeline_options.vlm_options
            )

            layout_options = twostagevlm_options.layout_options
            vlm_options = twostagevlm_options.vlm_options

            layout_model = LayoutModel(
                artifacts_path=artifacts_path,
                accelerator_options=pipeline_options.accelerator_options,
                options=layout_options,
            )

            if vlm_options.inference_framework == InferenceFramework.MLX:
                vlm_model = HuggingFaceMlxModel(
                    enabled=True,  # must always be enabled for this pipeline to make sense
                    artifacts_path=artifacts_path,
                    accelerator_options=pipeline_options.accelerator_options,
                    vlm_options=vlm_options,
                )
                self.build_pipe = [
                    TwoStageVlmModel(layout_model=layout_model, vlm_model=vlm_model)
                ]
            elif vlm_options.inference_framework == InferenceFramework.TRANSFORMERS:
                vlm_model = HuggingFaceTransformersVlmModel(
                    enabled=True,  # must always be enabled for this pipeline to make sense
                    artifacts_path=artifacts_path,
                    accelerator_options=pipeline_options.accelerator_options,
                    vlm_options=vlm_options,
                )
                self.build_pipe = [
                    TwoStageVlmModel(layout_model=layout_model, vlm_model=vlm_model)
                ]
            else:
                raise ValueError(
                    f"Could not instantiate the right type of VLM model: {vlm_options.inference_framework}"
                )

        self.enrichment_pipe = [
            # Other models working on `NodeItem` elements in the DoclingDocument
        ]
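
End to end, the new branch is exercised by running the VLM pipeline with two-stage options. A minimal sketch, assuming docling's standard converter API and the `two_stage_options` object from the earlier example:

# Minimal end-to-end sketch; reuses the two_stage_options instance from above.
from docling.datamodel.base_models import InputFormat
from docling.datamodel.pipeline_options import VlmPipelineOptions
from docling.document_converter import DocumentConverter, PdfFormatOption
from docling.pipeline.vlm_pipeline import VlmPipeline

pipeline_options = VlmPipelineOptions(vlm_options=two_stage_options)
converter = DocumentConverter(
    format_options={
        InputFormat.PDF: PdfFormatOption(
            pipeline_cls=VlmPipeline, pipeline_options=pipeline_options
        )
    }
)
result = converter.convert("input.pdf")  # hypothetical input file
print(result.document.export_to_markdown())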