feat: working on a two-stage VLM model

Signed-off-by: Peter Staar <taa@zurich.ibm.com>
This commit is contained in:
Peter Staar 2025-07-08 09:49:39 +02:00
parent 4eceefa47c
commit 810446c8dc
8 changed files with 90 additions and 64 deletions

View File

@ -16,6 +16,7 @@ from docling.datamodel import asr_model_specs
# Import the following for backwards compatibility
from docling.datamodel.accelerator_options import AcceleratorDevice, AcceleratorOptions
from docling.datamodel.asr_model_specs import WHISPER_TINY as whisper_tiny
from docling.datamodel.pipeline_options_asr_model import (
InlineAsrOptions,
)

View File

@ -6,6 +6,7 @@ from pydantic import AnyUrl, BaseModel
from typing_extensions import deprecated
from docling.datamodel.accelerator_options import AcceleratorDevice
from docling.datamodel.pipeline_options import LayoutOptions
class BaseVlmOptions(BaseModel):
@ -90,7 +91,7 @@ class ApiVlmOptions(BaseVlmOptions):
class TwoStageVlmOptions(BaseVlmOptions):
    """Options for the two-stage VLM model.

    Bundles the options of the two collaborating models: a layout model and
    an inline vision-language model.
    """

    # Discriminator tag. The Literal type and the default value must agree,
    # otherwise the default itself is not a valid value for the field.
    kind: Literal["inline_two_stage_model_options"] = "inline_two_stage_model_options"

    # Options for the inline vision-language model.
    vlm_options: InlineVlmOptions
    # Options for the layout model.
    layout_options: LayoutOptions

View File

@ -3,6 +3,7 @@ from collections.abc import Iterable
from typing import Generic, Optional, Protocol, Type

from docling_core.types.doc import BoundingBox, DocItem, DoclingDocument, NodeItem
from PIL import Image
from typing_extensions import TypeVar

from docling.datamodel.base_models import Cluster, ItemAndImageEnrichmentElement, Page
@ -19,12 +20,25 @@ class BaseModelWithOptions(Protocol):
class BasePageModel(ABC):
    """Abstract base for models that are applied to a batch of pages."""

    # Scale with which the page image needs to be created (dpi = 72 * scale).
    scale: float
    # Maximum width/height of the page image.
    max_size: int

    @abstractmethod
    def __call__(
        self, conv_res: ConversionResult, page_batch: Iterable[Page]
    ) -> Iterable[Page]:
        """Process ``page_batch`` in the context of ``conv_res``.

        Args:
            conv_res: Conversion result the pages belong to.
            page_batch: Pages to process.

        Returns:
            An iterable of pages.
        """
        ...
class BaseLayoutModel(BasePageModel):
    """Abstract page model that predicts layout clusters from a page image."""

    # NOTE(review): the LayoutModel touched in the same change set implements
    # `predict_on_page(...)` and still derives from BasePageModel, so it does
    # not satisfy this interface yet -- confirm the intended method name.
    @abstractmethod
    def predict_on_page_image(self, *, page_image: Image.Image) -> list[Cluster]:
        """Return the layout clusters predicted for ``page_image``.

        Args:
            page_image: Page image to run the layout prediction on
                (keyword-only).

        Returns:
            The predicted clusters.
        """
class BaseVlmModel(BasePageModel):
    """Abstract page model that produces text for a page image and a prompt."""

    # NOTE(review): TwoStageVlmModel.__call__ invokes `predict_on_image(...)`
    # and unpacks two return values (text, generation time), which does not
    # match this declaration -- confirm which contract is intended.
    @abstractmethod
    def predict_on_page_image(self, *, page_image: Image.Image, prompt: str) -> str:
        """Return the generated text for ``page_image`` given ``prompt``.

        Args:
            page_image: Page image passed to the model (keyword-only).
            prompt: Prompt passed to the model (keyword-only).

        Returns:
            The generated text.
        """
        ...
EnrichElementT = TypeVar("EnrichElementT", default=NodeItem)

View File

@ -7,6 +7,7 @@ from typing import Optional
import numpy as np
from docling_core.types.doc import DocItemLabel
from docling_core.types.doc.page import TextCell
from PIL import Image
from docling.datamodel.accelerator_options import AcceleratorOptions
@ -176,11 +177,11 @@ class LayoutModel(BasePageModel):
)
clusters.append(cluster)
"""
clusters = self.predict_on_page(page_image)
predicted_clusters = self.predict_on_page(page_image=page_image)
if settings.debug.visualize_raw_layout:
self.draw_clusters_and_cells_side_by_side(
conv_res, page, clusters, mode_prefix="raw"
conv_res, page, predicted_clusters, mode_prefix="raw"
)
# Apply postprocessing
@ -212,8 +213,28 @@ class LayoutModel(BasePageModel):
clusters=processed_clusters
)
"""
page, processed_clusters, processed_cells = self.postprocess_on_page(page, cluster)
page, processed_clusters, processed_cells = (
self.postprocess_on_page(page=page, clusters=predicted_clusters)
)
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
"Mean of empty slice|invalid value encountered in scalar divide",
RuntimeWarning,
"numpy",
)
conv_res.confidence.pages[page.page_no].layout_score = float(
np.mean([c.confidence for c in processed_clusters])
)
conv_res.confidence.pages[page.page_no].ocr_score = float(
np.mean(
[c.confidence for c in processed_cells if c.from_ocr]
)
)
if settings.debug.visualize_layout:
self.draw_clusters_and_cells_side_by_side(
conv_res, page, processed_clusters, mode_prefix="postprocessed"
@ -221,17 +242,13 @@ class LayoutModel(BasePageModel):
yield page
def predict_on_page(self, page_image: Image) -> list[Cluster]:
def predict_on_page(self, *, page_image: Image.Image) -> list[Cluster]:
pred_items = self.layout_predictor.predict(page_image)
clusters = []
for ix, pred_item in enumerate(pred_items):
label = DocItemLabel(
pred_item["label"]
.lower()
.replace(" ", "_")
.replace("-", "_")
pred_item["label"].lower().replace(" ", "_").replace("-", "_")
) # Temporary, until docling-ibm-model uses docling-core types
cluster = Cluster(
id=ix,
@ -241,36 +258,17 @@ class LayoutModel(BasePageModel):
cells=[],
)
clusters.append(cluster)
return clusters
def postprocess_on_page(
    self, *, page: Page, clusters: list[Cluster]
) -> tuple[Page, list[Cluster], list[TextCell]]:
    """Post-process raw layout clusters and store the result on the page.

    Confidence scores are intentionally not computed here: they need the
    ``ConversionResult``, which this method does not receive. That
    bookkeeping is done by the caller (``__call__``).

    Args:
        page: Page the clusters were predicted for (keyword-only).
        clusters: Raw clusters as returned by ``predict_on_page``
            (keyword-only).

    Returns:
        A tuple ``(page, processed_clusters, processed_cells)``.
    """
    processed_clusters, processed_cells = LayoutPostprocessor(
        page, clusters, self.options
    ).postprocess()
    # Note: LayoutPostprocessor updates page.cells and page.parsed_page internally

    page.predictions.layout = LayoutPrediction(clusters=processed_clusters)

    return page, processed_clusters, processed_cells

View File

@ -15,7 +15,7 @@ from docling.datamodel.pipeline_options_vlm_model import (
TransformersModelType,
TransformersPromptStyle,
)
from docling.models.base_model import BasePageModel
from docling.models.base_model import BasePageModel, BaseVlmModel
from docling.models.utils.hf_model_download import (
HuggingFaceModelDownloadMixin,
)
@ -25,7 +25,7 @@ from docling.utils.profiling import TimeRecorder
_log = logging.getLogger(__name__)
class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMixin):
class HuggingFaceTransformersVlmModel(BaseVlmModel, HuggingFaceModelDownloadMixin):
def __init__(
self,
enabled: bool,
@ -37,6 +37,9 @@ class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMix
self.vlm_options = vlm_options
self.scale = self.vlm_options.scale
self.max_size = self.vlm_options.max_size
if self.enabled:
import torch
from transformers import (

View File

@ -10,7 +10,7 @@ from docling.datamodel.accelerator_options import (
from docling.datamodel.base_models import Page, VlmPrediction, VlmPredictionToken
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options_vlm_model import InlineVlmOptions
from docling.models.base_model import BasePageModel
from docling.models.base_model import BasePageModel, BaseVlmModel
from docling.models.utils.hf_model_download import (
HuggingFaceModelDownloadMixin,
)
@ -19,7 +19,7 @@ from docling.utils.profiling import TimeRecorder
_log = logging.getLogger(__name__)
class HuggingFaceMlxModel(BasePageModel, HuggingFaceModelDownloadMixin):
class HuggingFaceMlxModel(BaseVlmModel, HuggingFaceModelDownloadMixin):
def __init__(
self,
enabled: bool,
@ -28,10 +28,12 @@ class HuggingFaceMlxModel(BasePageModel, HuggingFaceModelDownloadMixin):
vlm_options: InlineVlmOptions,
):
self.enabled = enabled
self.vlm_options = vlm_options
self.max_tokens = vlm_options.max_new_tokens
self.temperature = vlm_options.temperature
self.scale = self.vlm_options.scale
self.max_size = self.vlm_options.max_size
if self.enabled:
try:

View File

@ -15,7 +15,8 @@ from docling.datamodel.pipeline_options_vlm_model import (
TransformersModelType,
TransformersPromptStyle,
)
from docling.models.base_model import BasePageModel
from docling.models.base_model import BasePageModel, BaseVlmModel
from docling.models.layout_model import LayoutModel
from docling.models.utils.hf_model_download import (
HuggingFaceModelDownloadMixin,
)
@ -29,11 +30,11 @@ class TwoStageVlmModel(BasePageModel, HuggingFaceModelDownloadMixin):
def __init__(
    self,
    *,
    layout_model: LayoutModel,
    vlm_model: BaseVlmModel,
):
    """Store the two models this two-stage model delegates to.

    Args:
        layout_model: Model used to predict and post-process layout
            clusters on the page image (keyword-only).
        vlm_model: Vision-language model that is prompted with the page
            image (keyword-only).
    """
    # NOTE(review): other methods of this class read `self.vlm_options`,
    # which is not assigned here -- confirm where it should come from.
    self.layout_model = layout_model
    self.vlm_model = vlm_model
def __call__(
self, conv_res: ConversionResult, page_batch: Iterable[Page]
@ -47,23 +48,27 @@ class TwoStageVlmModel(BasePageModel, HuggingFaceModelDownloadMixin):
assert page.size is not None
page_image = page.get_image(
scale=self.vlm_options.scale, max_size=self.vlm_options.max_size
scale=self.vlm_model.scale, max_size=self.vlm_model.max_size
)
pred_clusters = self.layout_model.predict_on_page(page_image=page_image)
page, processed_clusters, processed_cells = (
self.layout_model.postprocess_on_page(
page=page, clusters=pred_clusters
)
)
pred_clusters = self.layout_model.predict_on_page(page_image)
page, processed_clusters, processed_cells = self.layout_model.postprocess_on_page(page=page,
page_image=page_image)
# Define prompt structure
if callable(self.vlm_options.prompt):
user_prompt = self.vlm_options.prompt(page.parsed_page)
else:
user_prompt = self.vlm_options.prompt
prompt = self.formulate_prompt(user_prompt, processed_clusters)
generated_text, generation_time = self.vlm_model.predict_on_image(page_image=page_image,
prompt=prompt)
generated_text, generation_time = self.vlm_model.predict_on_image(
page_image=page_image, prompt=prompt
)
page.predictions.vlm_response = VlmPrediction(
text=generated_text,
@ -72,7 +77,7 @@ class TwoStageVlmModel(BasePageModel, HuggingFaceModelDownloadMixin):
yield page
def formulate_prompt(self, user_prompt: str, clusters:list[Cluster]) -> str:
def formulate_prompt(self, user_prompt: str, clusters: list[Cluster]) -> str:
"""Formulate a prompt for the VLM."""
if self.vlm_options.transformers_prompt_style == TransformersPromptStyle.RAW:

View File

@ -26,9 +26,7 @@ from docling.backend.md_backend import MarkdownDocumentBackend
from docling.backend.pdf_backend import PdfDocumentBackend
from docling.datamodel.base_models import InputFormat, Page
from docling.datamodel.document import ConversionResult, InputDocument
from docling.datamodel.pipeline_options import (
VlmPipelineOptions,
)
from docling.datamodel.pipeline_options import TwoStageVlmOptions, VlmPipelineOptions
from docling.datamodel.pipeline_options_vlm_model import (
ApiVlmOptions,
InferenceFramework,
@ -37,15 +35,17 @@ from docling.datamodel.pipeline_options_vlm_model import (
)
from docling.datamodel.settings import settings
from docling.models.api_vlm_model import ApiVlmModel
from docling.models.layout_model import LayoutModel
from docling.models.vlm_models_inline.hf_transformers_model import (
HuggingFaceTransformersVlmModel,
)
from docling.models.vlm_models_inline.mlx_model import HuggingFaceMlxModel
from docling.models.vlm_models_inline.two_stage_vlm_model import (
TwoStageVlmModel,
)
from docling.pipeline.base_pipeline import PaginatedPipeline
from docling.utils.profiling import ProfilingScope, TimeRecorder
from docling.models.layout_model import LayoutModel
_log = logging.getLogger(__name__)
@ -110,7 +110,9 @@ class VlmPipeline(PaginatedPipeline):
f"Could not instantiate the right type of VLM pipeline: {vlm_options.inference_framework}"
)
elif isinstance(self.pipeline_options.vlm_options, TwoStageVlmOptions):
twostagevlm_options = cast(TwoStageVlmOptions, self.pipeline_options.vlm_options)
twostagevlm_options = cast(
TwoStageVlmOptions, self.pipeline_options.vlm_options
)
layout_options = twostagevlm_options.lay_options
vlm_options = twostagevlm_options.vlm_options
@ -120,7 +122,7 @@ class VlmPipeline(PaginatedPipeline):
accelerator_options=pipeline_options.accelerator_options,
options=layout_options,
)
if vlm_options.inference_framework == InferenceFramework.MLX:
vlm_model = HuggingFaceMlxModel(
enabled=True, # must be always enabled for this pipeline to make sense.
@ -145,7 +147,7 @@ class VlmPipeline(PaginatedPipeline):
raise ValueError(
f"Could not instantiate the right type of VLM pipeline: {vlm_options.inference_framework}"
)
self.enrichment_pipe = [
# Other models working on `NodeItem` elements in the DoclingDocument
]