feat: working on a two stage VLM model

Signed-off-by: Peter Staar <taa@zurich.ibm.com>
This commit is contained in:
Peter Staar 2025-07-08 09:49:39 +02:00
parent 4eceefa47c
commit 810446c8dc
8 changed files with 90 additions and 64 deletions

View File

@ -16,6 +16,7 @@ from docling.datamodel import asr_model_specs
# Import the following for backwards compatibility # Import the following for backwards compatibility
from docling.datamodel.accelerator_options import AcceleratorDevice, AcceleratorOptions from docling.datamodel.accelerator_options import AcceleratorDevice, AcceleratorOptions
from docling.datamodel.asr_model_specs import WHISPER_TINY as whisper_tiny
from docling.datamodel.pipeline_options_asr_model import ( from docling.datamodel.pipeline_options_asr_model import (
InlineAsrOptions, InlineAsrOptions,
) )

View File

@ -6,6 +6,7 @@ from pydantic import AnyUrl, BaseModel
from typing_extensions import deprecated from typing_extensions import deprecated
from docling.datamodel.accelerator_options import AcceleratorDevice from docling.datamodel.accelerator_options import AcceleratorDevice
from docling.datamodel.pipeline_options import LayoutOptions
class BaseVlmOptions(BaseModel): class BaseVlmOptions(BaseModel):
@ -90,7 +91,7 @@ class ApiVlmOptions(BaseVlmOptions):
class TwoStageVlmOptions(BaseVlmOptions): class TwoStageVlmOptions(BaseVlmOptions):
kind: Literal["inline_model_options"] = "inline_two_stage_model_options" kind: Literal["inline_two_stage_model_options"] = "inline_two_stage_model_options"
vlm_options: UnionInlineVlmOptions vlm_options: InlineVlmOptions
layout_options: LayoutOptions layout_options: LayoutOptions

View File

@ -3,6 +3,7 @@ from collections.abc import Iterable
from typing import Generic, Optional, Protocol, Type from typing import Generic, Optional, Protocol, Type
from docling_core.types.doc import BoundingBox, DocItem, DoclingDocument, NodeItem from docling_core.types.doc import BoundingBox, DocItem, DoclingDocument, NodeItem
from PIL import Image
from typing_extensions import TypeVar from typing_extensions import TypeVar
from docling.datamodel.base_models import ItemAndImageEnrichmentElement, Page from docling.datamodel.base_models import ItemAndImageEnrichmentElement, Page
@ -19,12 +20,25 @@ class BaseModelWithOptions(Protocol):
class BasePageModel(ABC): class BasePageModel(ABC):
scale: float # scale with which the page-image needs to be created (dpi = 72*scale)
max_size: int # max size of width/height of page-image
@abstractmethod @abstractmethod
def __call__( def __call__(
self, conv_res: ConversionResult, page_batch: Iterable[Page] self, conv_res: ConversionResult, page_batch: Iterable[Page]
) -> Iterable[Page]: ) -> Iterable[Page]:
pass pass
class BaseLayoutModel(BasePageModel):
@abstractmethod
def predict_on_page_image(self, *, page_image: Image.Image) -> list(Cluster):
pass
class BaseVlmModel(BasePageModel):
@abstractmethod
def predict_on_page_image(self, *, page_image: Image.Image, prompt: str) -> str:
pass
EnrichElementT = TypeVar("EnrichElementT", default=NodeItem) EnrichElementT = TypeVar("EnrichElementT", default=NodeItem)

View File

@ -7,6 +7,7 @@ from typing import Optional
import numpy as np import numpy as np
from docling_core.types.doc import DocItemLabel from docling_core.types.doc import DocItemLabel
from docling_core.types.doc.page import TextCell
from PIL import Image from PIL import Image
from docling.datamodel.accelerator_options import AcceleratorOptions from docling.datamodel.accelerator_options import AcceleratorOptions
@ -176,11 +177,11 @@ class LayoutModel(BasePageModel):
) )
clusters.append(cluster) clusters.append(cluster)
""" """
clusters = self.predict_on_page(page_image) predicted_clusters = self.predict_on_page(page_image=page_image)
if settings.debug.visualize_raw_layout: if settings.debug.visualize_raw_layout:
self.draw_clusters_and_cells_side_by_side( self.draw_clusters_and_cells_side_by_side(
conv_res, page, clusters, mode_prefix="raw" conv_res, page, predicted_clusters, mode_prefix="raw"
) )
# Apply postprocessing # Apply postprocessing
@ -212,45 +213,10 @@ class LayoutModel(BasePageModel):
clusters=processed_clusters clusters=processed_clusters
) )
""" """
page, processed_clusters, processed_cells = self.postprocess_on_page(page, cluster) page, processed_clusters, processed_cells = (
self.postprocess_on_page(page=page, clusters=predicted_clusters)
if settings.debug.visualize_layout:
self.draw_clusters_and_cells_side_by_side(
conv_res, page, processed_clusters, mode_prefix="postprocessed"
) )
yield page
def predict_on_page(self, page_image: Image) -> list[Cluster]:
pred_items = self.layout_predictor.predict(page_image)
clusters = []
for ix, pred_item in enumerate(pred_items):
label = DocItemLabel(
pred_item["label"]
.lower()
.replace(" ", "_")
.replace("-", "_")
) # Temporary, until docling-ibm-model uses docling-core types
cluster = Cluster(
id=ix,
label=label,
confidence=pred_item["confidence"],
bbox=BoundingBox.model_validate(pred_item),
cells=[],
)
clusters.append(cluster)
return clusters
def postprocess_on_page(self, page: Page, cluster: list(Cluster)):
processed_clusters, processed_cells = LayoutPostprocessor(
page, clusters, self.options
).postprocess()
# Note: LayoutPostprocessor updates page.cells and page.parsed_page internally
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.filterwarnings( warnings.filterwarnings(
"ignore", "ignore",
@ -269,8 +235,40 @@ class LayoutModel(BasePageModel):
) )
) )
page.predictions.layout = LayoutPrediction( if settings.debug.visualize_layout:
clusters=processed_clusters self.draw_clusters_and_cells_side_by_side(
conv_res, page, processed_clusters, mode_prefix="postprocessed"
) )
yield page
def predict_on_page(self, *, page_image: Image.Image) -> list[Cluster]:
pred_items = self.layout_predictor.predict(page_image)
clusters = []
for ix, pred_item in enumerate(pred_items):
label = DocItemLabel(
pred_item["label"].lower().replace(" ", "_").replace("-", "_")
) # Temporary, until docling-ibm-model uses docling-core types
cluster = Cluster(
id=ix,
label=label,
confidence=pred_item["confidence"],
bbox=BoundingBox.model_validate(pred_item),
cells=[],
)
clusters.append(cluster)
return clusters
def postprocess_on_page(
self, *, page: Page, clusters: list[Cluster]
) -> tuple[Page, list[Cluster], list[TextCell]]:
processed_clusters, processed_cells = LayoutPostprocessor(
page, clusters, self.options
).postprocess()
# Note: LayoutPostprocessor updates page.cells and page.parsed_page internally
page.predictions.layout = LayoutPrediction(clusters=processed_clusters)
return page, processed_clusters, processed_cells return page, processed_clusters, processed_cells

View File

@ -15,7 +15,7 @@ from docling.datamodel.pipeline_options_vlm_model import (
TransformersModelType, TransformersModelType,
TransformersPromptStyle, TransformersPromptStyle,
) )
from docling.models.base_model import BasePageModel from docling.models.base_model import BasePageModel, BaseVlmModel
from docling.models.utils.hf_model_download import ( from docling.models.utils.hf_model_download import (
HuggingFaceModelDownloadMixin, HuggingFaceModelDownloadMixin,
) )
@ -25,7 +25,7 @@ from docling.utils.profiling import TimeRecorder
_log = logging.getLogger(__name__) _log = logging.getLogger(__name__)
class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMixin): class HuggingFaceTransformersVlmModel(BaseVlmModel, HuggingFaceModelDownloadMixin):
def __init__( def __init__(
self, self,
enabled: bool, enabled: bool,
@ -37,6 +37,9 @@ class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMix
self.vlm_options = vlm_options self.vlm_options = vlm_options
self.scale = self.vlm_options.scale
self.max_size = self.vlm_options.max_size
if self.enabled: if self.enabled:
import torch import torch
from transformers import ( from transformers import (

View File

@ -10,7 +10,7 @@ from docling.datamodel.accelerator_options import (
from docling.datamodel.base_models import Page, VlmPrediction, VlmPredictionToken from docling.datamodel.base_models import Page, VlmPrediction, VlmPredictionToken
from docling.datamodel.document import ConversionResult from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options_vlm_model import InlineVlmOptions from docling.datamodel.pipeline_options_vlm_model import InlineVlmOptions
from docling.models.base_model import BasePageModel from docling.models.base_model import BasePageModel, BaseVlmModel
from docling.models.utils.hf_model_download import ( from docling.models.utils.hf_model_download import (
HuggingFaceModelDownloadMixin, HuggingFaceModelDownloadMixin,
) )
@ -19,7 +19,7 @@ from docling.utils.profiling import TimeRecorder
_log = logging.getLogger(__name__) _log = logging.getLogger(__name__)
class HuggingFaceMlxModel(BasePageModel, HuggingFaceModelDownloadMixin): class HuggingFaceMlxModel(BaseVlmModel, HuggingFaceModelDownloadMixin):
def __init__( def __init__(
self, self,
enabled: bool, enabled: bool,
@ -28,10 +28,12 @@ class HuggingFaceMlxModel(BasePageModel, HuggingFaceModelDownloadMixin):
vlm_options: InlineVlmOptions, vlm_options: InlineVlmOptions,
): ):
self.enabled = enabled self.enabled = enabled
self.vlm_options = vlm_options self.vlm_options = vlm_options
self.max_tokens = vlm_options.max_new_tokens self.max_tokens = vlm_options.max_new_tokens
self.temperature = vlm_options.temperature self.temperature = vlm_options.temperature
self.scale = self.vlm_options.scale
self.max_size = self.vlm_options.max_size
if self.enabled: if self.enabled:
try: try:

View File

@ -15,7 +15,8 @@ from docling.datamodel.pipeline_options_vlm_model import (
TransformersModelType, TransformersModelType,
TransformersPromptStyle, TransformersPromptStyle,
) )
from docling.models.base_model import BasePageModel from docling.models.base_model import BasePageModel, BaseVlmModel
from docling.models.layout_model import LayoutModel
from docling.models.utils.hf_model_download import ( from docling.models.utils.hf_model_download import (
HuggingFaceModelDownloadMixin, HuggingFaceModelDownloadMixin,
) )
@ -29,11 +30,11 @@ class TwoStageVlmModel(BasePageModel, HuggingFaceModelDownloadMixin):
def __init__( def __init__(
self, self,
*, *,
layout_model: LayoutModelModel, layout_model: LayoutModel,
vlm_model: BasePageModel, vlm_model: BaseVlmModel,
): ):
self.layout_model = layout_options self.layout_model = layout_model
self.vlm_model = vlm_options self.vlm_model = vlm_model
def __call__( def __call__(
self, conv_res: ConversionResult, page_batch: Iterable[Page] self, conv_res: ConversionResult, page_batch: Iterable[Page]
@ -47,12 +48,15 @@ class TwoStageVlmModel(BasePageModel, HuggingFaceModelDownloadMixin):
assert page.size is not None assert page.size is not None
page_image = page.get_image( page_image = page.get_image(
scale=self.vlm_options.scale, max_size=self.vlm_options.max_size scale=self.vlm_model.scale, max_size=self.vlm_model.max_size
) )
pred_clusters = self.layout_model.predict_on_page(page_image) pred_clusters = self.layout_model.predict_on_page(page_image=page_image)
page, processed_clusters, processed_cells = self.layout_model.postprocess_on_page(page=page, page, processed_clusters, processed_cells = (
page_image=page_image) self.layout_model.postprocess_on_page(
page=page, clusters=pred_clusters
)
)
# Define prompt structure # Define prompt structure
if callable(self.vlm_options.prompt): if callable(self.vlm_options.prompt):
@ -62,8 +66,9 @@ class TwoStageVlmModel(BasePageModel, HuggingFaceModelDownloadMixin):
prompt = self.formulate_prompt(user_prompt, processed_clusters) prompt = self.formulate_prompt(user_prompt, processed_clusters)
generated_text, generation_time = self.vlm_model.predict_on_image(page_image=page_image, generated_text, generation_time = self.vlm_model.predict_on_image(
prompt=prompt) page_image=page_image, prompt=prompt
)
page.predictions.vlm_response = VlmPrediction( page.predictions.vlm_response = VlmPrediction(
text=generated_text, text=generated_text,
@ -72,7 +77,7 @@ class TwoStageVlmModel(BasePageModel, HuggingFaceModelDownloadMixin):
yield page yield page
def formulate_prompt(self, user_prompt: str, clusters:list[Cluster]) -> str: def formulate_prompt(self, user_prompt: str, clusters: list[Cluster]) -> str:
"""Formulate a prompt for the VLM.""" """Formulate a prompt for the VLM."""
if self.vlm_options.transformers_prompt_style == TransformersPromptStyle.RAW: if self.vlm_options.transformers_prompt_style == TransformersPromptStyle.RAW:

View File

@ -26,9 +26,7 @@ from docling.backend.md_backend import MarkdownDocumentBackend
from docling.backend.pdf_backend import PdfDocumentBackend from docling.backend.pdf_backend import PdfDocumentBackend
from docling.datamodel.base_models import InputFormat, Page from docling.datamodel.base_models import InputFormat, Page
from docling.datamodel.document import ConversionResult, InputDocument from docling.datamodel.document import ConversionResult, InputDocument
from docling.datamodel.pipeline_options import ( from docling.datamodel.pipeline_options import TwoStageVlmOptions, VlmPipelineOptions
VlmPipelineOptions,
)
from docling.datamodel.pipeline_options_vlm_model import ( from docling.datamodel.pipeline_options_vlm_model import (
ApiVlmOptions, ApiVlmOptions,
InferenceFramework, InferenceFramework,
@ -37,15 +35,17 @@ from docling.datamodel.pipeline_options_vlm_model import (
) )
from docling.datamodel.settings import settings from docling.datamodel.settings import settings
from docling.models.api_vlm_model import ApiVlmModel from docling.models.api_vlm_model import ApiVlmModel
from docling.models.layout_model import LayoutModel
from docling.models.vlm_models_inline.hf_transformers_model import ( from docling.models.vlm_models_inline.hf_transformers_model import (
HuggingFaceTransformersVlmModel, HuggingFaceTransformersVlmModel,
) )
from docling.models.vlm_models_inline.mlx_model import HuggingFaceMlxModel from docling.models.vlm_models_inline.mlx_model import HuggingFaceMlxModel
from docling.models.vlm_models_inline.two_stage_vlm_model import (
TwoStageVlmModel,
)
from docling.pipeline.base_pipeline import PaginatedPipeline from docling.pipeline.base_pipeline import PaginatedPipeline
from docling.utils.profiling import ProfilingScope, TimeRecorder from docling.utils.profiling import ProfilingScope, TimeRecorder
from docling.models.layout_model import LayoutModel
_log = logging.getLogger(__name__) _log = logging.getLogger(__name__)
@ -110,7 +110,9 @@ class VlmPipeline(PaginatedPipeline):
f"Could not instantiate the right type of VLM pipeline: {vlm_options.inference_framework}" f"Could not instantiate the right type of VLM pipeline: {vlm_options.inference_framework}"
) )
elif isinstance(self.pipeline_options.vlm_options, TwoStageVlmOptions): elif isinstance(self.pipeline_options.vlm_options, TwoStageVlmOptions):
twostagevlm_options = cast(TwoStageVlmOptions, self.pipeline_options.vlm_options) twostagevlm_options = cast(
TwoStageVlmOptions, self.pipeline_options.vlm_options
)
layout_options = twostagevlm_options.lay_options layout_options = twostagevlm_options.lay_options
vlm_options = twostagevlm_options.vlm_options vlm_options = twostagevlm_options.vlm_options