mirror of https://github.com/DS4SD/docling.git (synced 2025-07-26 20:14:47 +00:00)

feat: add TwoStageVlmModel

Signed-off-by: Peter Staar <taa@zurich.ibm.com>

This commit is contained in: parent a07ba863c4, commit 4eceefa47c
@@ -269,6 +269,7 @@ class VlmPipelineOptions(PaginatedPipelineOptions):
 
 class LayoutOptions(BaseModel):
     """Options for layout processing."""
 
+    repo_id: str = "ds4sd/docling-layout-heron"
     create_orphan_clusters: bool = True  # Whether to create clusters for orphaned cells
 
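For context, a hedged sketch of how the extended options could be instantiated; the values below are just the defaults from this hunk, nothing more is implied by the commit:

from docling.datamodel.pipeline_options import LayoutOptions

# repo_id now selects the layout checkpoint used by the layout stage.
layout_options = LayoutOptions(
    repo_id="ds4sd/docling-layout-heron",
    create_orphan_clusters=True,
)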
@@ -87,3 +87,10 @@ class ApiVlmOptions(BaseVlmOptions):
     timeout: float = 60
     concurrency: int = 1
     response_format: ResponseFormat
+
+
+class TwoStageVlmOptions(BaseVlmOptions):
+    kind: Literal["inline_two_stage_model_options"] = "inline_two_stage_model_options"
+
+    vlm_options: UnionInlineVlmOptions
+    layout_options: LayoutOptions
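A minimal configuration sketch for the new options class, assuming one of docling's existing inline VLM presets (the preset name SMOLDOCLING_TRANSFORMERS is an assumption, not part of this diff):

from docling.datamodel import vlm_model_specs
from docling.datamodel.pipeline_options import LayoutOptions, TwoStageVlmOptions

two_stage_options = TwoStageVlmOptions(
    vlm_options=vlm_model_specs.SMOLDOCLING_TRANSFORMERS,  # assumed inner-VLM preset
    layout_options=LayoutOptions(),  # defaults to the ds4sd/docling-layout-heron checkpoint
)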
@@ -156,6 +156,7 @@ class LayoutModel(BasePageModel):
                 page_image = page.get_image(scale=1.0)
                 assert page_image is not None
 
+                """
                 clusters = []
                 for ix, pred_item in enumerate(
                     self.layout_predictor.predict(page_image)
@@ -174,6 +175,8 @@ class LayoutModel(BasePageModel):
                         cells=[],
                     )
                     clusters.append(cluster)
+                """
+                clusters = self.predict_on_page(page_image)
 
                 if settings.debug.visualize_raw_layout:
                     self.draw_clusters_and_cells_side_by_side(
@@ -181,7 +184,7 @@ class LayoutModel(BasePageModel):
                     )
 
                 # Apply postprocessing
+                """
                 processed_clusters, processed_cells = LayoutPostprocessor(
                     page, clusters, self.options
                 ).postprocess()
@@ -208,6 +211,8 @@ class LayoutModel(BasePageModel):
                 page.predictions.layout = LayoutPrediction(
                     clusters=processed_clusters
                 )
+                """
+                page, processed_clusters, processed_cells = self.postprocess_on_page(conv_res, page, clusters)
 
                 if settings.debug.visualize_layout:
                     self.draw_clusters_and_cells_side_by_side(
@@ -215,3 +220,57 @@ class LayoutModel(BasePageModel):
                     )
 
                 yield page
+
+    def predict_on_page(self, page_image: Image) -> list[Cluster]:
+        pred_items = self.layout_predictor.predict(page_image)
+
+        clusters = []
+        for ix, pred_item in enumerate(pred_items):
+            label = DocItemLabel(
+                pred_item["label"]
+                .lower()
+                .replace(" ", "_")
+                .replace("-", "_")
+            )  # Temporary, until docling-ibm-model uses docling-core types
+            cluster = Cluster(
+                id=ix,
+                label=label,
+                confidence=pred_item["confidence"],
+                bbox=BoundingBox.model_validate(pred_item),
+                cells=[],
+            )
+            clusters.append(cluster)
+
+        return clusters
+
+    def postprocess_on_page(self, conv_res: ConversionResult, page: Page, clusters: list[Cluster]):
+        processed_clusters, processed_cells = LayoutPostprocessor(
+            page, clusters, self.options
+        ).postprocess()
+        # Note: LayoutPostprocessor updates page.cells and page.parsed_page internally
+
+        with warnings.catch_warnings():
+            warnings.filterwarnings(
+                "ignore",
+                "Mean of empty slice|invalid value encountered in scalar divide",
+                RuntimeWarning,
+                "numpy",
+            )
+
+            conv_res.confidence.pages[page.page_no].layout_score = float(
+                np.mean([c.confidence for c in processed_clusters])
+            )
+
+            conv_res.confidence.pages[page.page_no].ocr_score = float(
+                np.mean(
+                    [c.confidence for c in processed_cells if c.from_ocr]
+                )
+            )
+
+        page.predictions.layout = LayoutPrediction(
+            clusters=processed_clusters
+        )
+
+        return page, processed_clusters, processed_cells
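Taken together, the refactor splits layout prediction and postprocessing into separately callable steps. A hedged sketch of the intended call sequence; conv_res, page, and layout_model are assumed to come from a running pipeline:

# Hypothetical driver illustrating the new two-step API on a single page.
page_image = page.get_image(scale=1.0)
clusters = layout_model.predict_on_page(page_image)
page, processed_clusters, processed_cells = layout_model.postprocess_on_page(
    conv_res, page, clusters
)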
docling/models/vlm_models_inline/TwoStageVlmModel.py (new file, 115 lines)
@@ -0,0 +1,115 @@
+import importlib.metadata
+import logging
+import time
+from collections.abc import Iterable
+from pathlib import Path
+from typing import Any, Optional
+
+from docling.datamodel.accelerator_options import (
+    AcceleratorOptions,
+)
+from docling.datamodel.base_models import Cluster, Page, VlmPrediction
+from docling.datamodel.document import ConversionResult
+from docling.datamodel.pipeline_options_vlm_model import (
+    InlineVlmOptions,
+    TransformersModelType,
+    TransformersPromptStyle,
+)
+from docling.models.base_model import BasePageModel
+from docling.models.layout_model import LayoutModel
+from docling.models.utils.hf_model_download import (
+    HuggingFaceModelDownloadMixin,
+)
+from docling.utils.accelerator_utils import decide_device
+from docling.utils.profiling import TimeRecorder
+
+_log = logging.getLogger(__name__)
+
+
+class TwoStageVlmModel(BasePageModel, HuggingFaceModelDownloadMixin):
+    def __init__(
+        self,
+        *,
+        layout_model: LayoutModel,
+        vlm_model: BasePageModel,
+    ):
+        self.layout_model = layout_model
+        self.vlm_model = vlm_model
+
+    def __call__(
+        self, conv_res: ConversionResult, page_batch: Iterable[Page]
+    ) -> Iterable[Page]:
+        for page in page_batch:
+            assert page._backend is not None
+            if not page._backend.is_valid():
+                yield page
+            else:
+                with TimeRecorder(conv_res, "two-staged-vlm"):
+                    assert page.size is not None
+
+                    # The wrapped VLM is assumed to expose its InlineVlmOptions
+                    # as `.vlm_options`.
+                    vlm_options = self.vlm_model.vlm_options
+                    page_image = page.get_image(
+                        scale=vlm_options.scale, max_size=vlm_options.max_size
+                    )
+
+                    pred_clusters = self.layout_model.predict_on_page(page_image)
+                    page, processed_clusters, processed_cells = (
+                        self.layout_model.postprocess_on_page(
+                            conv_res, page, pred_clusters
+                        )
+                    )
+
+                    # Define prompt structure
+                    if callable(vlm_options.prompt):
+                        user_prompt = vlm_options.prompt(page.parsed_page)
+                    else:
+                        user_prompt = vlm_options.prompt
+
+                    prompt = self.formulate_prompt(user_prompt, processed_clusters)
+
+                    generated_text, generation_time = self.vlm_model.predict_on_image(
+                        page_image=page_image, prompt=prompt
+                    )
+
+                    page.predictions.vlm_response = VlmPrediction(
+                        text=generated_text,
+                        generation_time=generation_time,
+                    )
+
+                yield page
+
+    def formulate_prompt(self, user_prompt: str, clusters: list[Cluster]) -> str:
+        """Formulate a prompt for the VLM."""
+        vlm_options = self.vlm_model.vlm_options
+
+        if vlm_options.transformers_prompt_style == TransformersPromptStyle.RAW:
+            return user_prompt
+
+        elif vlm_options.repo_id == "microsoft/Phi-4-multimodal-instruct":
+            _log.debug("Using specialized prompt for Phi-4")
+            # more info here: https://huggingface.co/microsoft/Phi-4-multimodal-instruct#loading-the-model-locally
+
+            user_token = "<|user|>"
+            assistant_token = "<|assistant|>"
+            end_token = "<|end|>"
+
+            prompt = f"{user_token}<|image_1|>{user_prompt}{end_token}{assistant_token}"
+            _log.debug(f"prompt for {vlm_options.repo_id}: {prompt}")
+
+            return prompt
+
+        elif vlm_options.transformers_prompt_style == TransformersPromptStyle.CHAT:
+            messages = [
+                {
+                    "role": "user",
+                    "content": [
+                        {
+                            "type": "text",
+                            "text": "This is a page from a document.",
+                        },
+                        {"type": "image"},
+                        {"type": "text", "text": user_prompt},
+                    ],
+                }
+            ]
+            # Assumes the wrapped transformers-based VLM exposes its processor.
+            prompt = self.vlm_model.processor.apply_chat_template(
+                messages, add_generation_prompt=False
+            )
+            return prompt
+
+        raise RuntimeError(
+            f"Unknown prompt style `{vlm_options.transformers_prompt_style}`. Valid values are {', '.join(s.value for s in TransformersPromptStyle)}."
+        )
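TwoStageVlmModel assumes the wrapped VLM exposes `vlm_options` and a `predict_on_image(page_image, prompt)` method returning `(text, generation_time)`; neither is defined in this diff. A purely hypothetical stand-in that satisfies that assumed interface, e.g. for wiring tests:

import time

from PIL.Image import Image

from docling.datamodel.pipeline_options_vlm_model import TransformersPromptStyle


class EchoVlm:
    """Hypothetical stub satisfying the interface TwoStageVlmModel relies on."""

    class _Options:
        scale = 2.0
        max_size = 1024
        prompt = "Convert this page to DocTags."
        transformers_prompt_style = TransformersPromptStyle.RAW

    vlm_options = _Options()

    def predict_on_image(self, page_image: Image, prompt: str) -> tuple[str, float]:
        start = time.monotonic()
        # A real model would run generation here; the stub returns a constant.
        return "<doctag></doctag>", time.monotonic() - start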
@@ -44,6 +44,8 @@ from docling.models.vlm_models_inline.mlx_model import HuggingFaceMlxModel
 from docling.pipeline.base_pipeline import PaginatedPipeline
 from docling.utils.profiling import ProfilingScope, TimeRecorder
 
+from docling.models.layout_model import LayoutModel
+
 _log = logging.getLogger(__name__)
 
 
||||||
@ -107,6 +109,42 @@ class VlmPipeline(PaginatedPipeline):
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Could not instantiate the right type of VLM pipeline: {vlm_options.inference_framework}"
|
f"Could not instantiate the right type of VLM pipeline: {vlm_options.inference_framework}"
|
||||||
)
|
)
|
||||||
|
elif isinstance(self.pipeline_options.vlm_options, TwoStageVlmOptions):
|
||||||
|
twostagevlm_options = cast(TwoStageVlmOptions, self.pipeline_options.vlm_options)
|
||||||
|
|
||||||
|
layout_options = twostagevlm_options.lay_options
|
||||||
|
vlm_options = twostagevlm_options.vlm_options
|
||||||
|
|
||||||
|
layout_model = LayoutModel(
|
||||||
|
artifacts_path=artifacts_path,
|
||||||
|
accelerator_options=pipeline_options.accelerator_options,
|
||||||
|
options=layout_options,
|
||||||
|
)
|
||||||
|
|
||||||
|
if vlm_options.inference_framework == InferenceFramework.MLX:
|
||||||
|
vlm_model = HuggingFaceMlxModel(
|
||||||
|
enabled=True, # must be always enabled for this pipeline to make sense.
|
||||||
|
artifacts_path=artifacts_path,
|
||||||
|
accelerator_options=pipeline_options.accelerator_options,
|
||||||
|
vlm_options=vlm_options,
|
||||||
|
)
|
||||||
|
self.build_pipe = [
|
||||||
|
TwoStageVlmModel(layout_model=layout_model, vlm_model=vlm_model)
|
||||||
|
]
|
||||||
|
elif vlm_options.inference_framework == InferenceFramework.TRANSFORMERS:
|
||||||
|
vlm_model = HuggingFaceTransformersVlmModel(
|
||||||
|
enabled=True, # must be always enabled for this pipeline to make sense.
|
||||||
|
artifacts_path=artifacts_path,
|
||||||
|
accelerator_options=pipeline_options.accelerator_options,
|
||||||
|
vlm_options=vlm_options,
|
||||||
|
)
|
||||||
|
self.build_pipe = [
|
||||||
|
TwoStageVlmModel(layout_model=layout_model, vlm_model=vlm_model)
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Could not instantiate the right type of VLM pipeline: {vlm_options.inference_framework}"
|
||||||
|
)
|
||||||
|
|
||||||
self.enrichment_pipe = [
|
self.enrichment_pipe = [
|
||||||
# Other models working on `NodeItem` elements in the DoclingDocument
|
# Other models working on `NodeItem` elements in the DoclingDocument
|
||||||
|
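End to end, the new branch means a converter can be pointed at the two-stage options. A hedged usage sketch, reusing `two_stage_options` from the earlier sketch; the input file name is an assumption:

from docling.datamodel.base_models import InputFormat
from docling.datamodel.pipeline_options import VlmPipelineOptions
from docling.document_converter import DocumentConverter, PdfFormatOption
from docling.pipeline.vlm_pipeline import VlmPipeline

pipeline_options = VlmPipelineOptions(vlm_options=two_stage_options)
converter = DocumentConverter(
    format_options={
        InputFormat.PDF: PdfFormatOption(
            pipeline_cls=VlmPipeline, pipeline_options=pipeline_options
        )
    }
)
result = converter.convert("document.pdf")  # hypothetical input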