feat: Factory and plugin-capability for Layout and Table models (#2637)

* feat: Scaffolding for layout and table model plugin factory

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Add missing files

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Add base options classes for layout and table

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

---------

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
Christoph Auer authored on 2025-11-21 10:26:06 +01:00, committed by GitHub
parent dcb57bf528, commit ad97e52851
11 changed files with 344 additions and 169 deletions

View File

@@ -59,9 +59,14 @@ class TableFormerMode(str, Enum):
ACCURATE = "accurate"
class TableStructureOptions(BaseModel):
class BaseTableStructureOptions(BaseOptions):
"""Base options for table structure models."""
class TableStructureOptions(BaseTableStructureOptions):
"""Options for the table structure."""
kind: ClassVar[str] = "docling_tableformer"
do_cell_matching: bool = (
True
# True: Matches predictions back to PDF cells. Can break table output if PDF cells
@@ -308,19 +313,25 @@ class VlmPipelineOptions(PaginatedPipelineOptions):
)
class LayoutOptions(BaseModel):
"""Options for layout processing."""
class BaseLayoutOptions(BaseOptions):
"""Base options for layout models."""
create_orphan_clusters: bool = True # Whether to create clusters for orphaned cells
keep_empty_clusters: bool = (
False # Whether to keep clusters that contain no text cells
)
model_spec: LayoutModelConfig = DOCLING_LAYOUT_HERON
skip_cell_assignment: bool = (
False # Skip cell-to-cluster assignment for VLM-only processing
)
class LayoutOptions(BaseLayoutOptions):
"""Options for layout processing."""
kind: ClassVar[str] = "docling_layout_default"
create_orphan_clusters: bool = True # Whether to create clusters for orphaned cells
model_spec: LayoutModelConfig = DOCLING_LAYOUT_HERON
class AsrPipelineOptions(PipelineOptions):
asr_options: Union[InlineAsrOptions] = asr_model_specs.WHISPER_TINY
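
Note: the split above gives every engine's options class a distinguishing `kind` (a ClassVar), which the factories introduced below use to select an implementation. A minimal sketch of what a third-party options class could look like, following the same pattern; the class name, kind string, and extra field are hypothetical and not part of this commit:

from typing import ClassVar

from docling.datamodel.pipeline_options import BaseLayoutOptions


class MyLayoutOptions(BaseLayoutOptions):
    """Hypothetical options for a third-party layout engine."""

    kind: ClassVar[str] = "my_layout_engine"  # unique kind used for factory lookup
    score_threshold: float = 0.5  # engine-specific setting (illustrative only)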

View File

@@ -0,0 +1,39 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from collections.abc import Iterable, Sequence
from typing import Type

from docling.datamodel.base_models import LayoutPrediction, Page
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import BaseLayoutOptions
from docling.models.base_model import BaseModelWithOptions, BasePageModel


class BaseLayoutModel(BasePageModel, BaseModelWithOptions, ABC):
    """Shared interface for layout models."""

    @classmethod
    @abstractmethod
    def get_options_type(cls) -> Type[BaseLayoutOptions]:
        """Return the options type supported by this layout model."""

    @abstractmethod
    def predict_layout(
        self,
        conv_res: ConversionResult,
        pages: Sequence[Page],
    ) -> Sequence[LayoutPrediction]:
        """Produce layout predictions for the provided pages."""

    def __call__(
        self,
        conv_res: ConversionResult,
        page_batch: Iterable[Page],
    ) -> Iterable[Page]:
        pages = list(page_batch)
        predictions = self.predict_layout(conv_res, pages)
        for page, prediction in zip(pages, predictions):
            page.predictions.layout = prediction
            yield page
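
For illustration, a minimal plugin layout engine against this interface. MyLayoutModel is hypothetical, MyLayoutOptions is the hypothetical options class sketched after the pipeline-options diff above, and the constructor keywords mirror how the factories instantiate engines later in this commit:

from collections.abc import Sequence

from docling.datamodel.base_models import LayoutPrediction, Page
from docling.datamodel.document import ConversionResult
from docling.models.base_layout_model import BaseLayoutModel


class MyLayoutModel(BaseLayoutModel):
    def __init__(self, *, options, artifacts_path=None, accelerator_options=None):
        self.options = options

    @classmethod
    def get_options_type(cls):
        return MyLayoutOptions  # hypothetical options class from the earlier sketch

    def predict_layout(
        self, conv_res: ConversionResult, pages: Sequence[Page]
    ) -> Sequence[LayoutPrediction]:
        # One prediction per input page; the base __call__ attaches each
        # prediction to page.predictions.layout and yields the page.
        return [LayoutPrediction(clusters=[]) for _ in pages]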

View File

@@ -0,0 +1,45 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from collections.abc import Iterable, Sequence
from typing import Type

from docling.datamodel.base_models import Page, TableStructurePrediction
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import BaseTableStructureOptions
from docling.models.base_model import BaseModelWithOptions, BasePageModel


class BaseTableStructureModel(BasePageModel, BaseModelWithOptions, ABC):
    """Shared interface for table structure models."""

    enabled: bool

    @classmethod
    @abstractmethod
    def get_options_type(cls) -> Type[BaseTableStructureOptions]:
        """Return the options type supported by this table model."""

    @abstractmethod
    def predict_tables(
        self,
        conv_res: ConversionResult,
        pages: Sequence[Page],
    ) -> Sequence[TableStructurePrediction]:
        """Produce table structure predictions for the provided pages."""

    def __call__(
        self,
        conv_res: ConversionResult,
        page_batch: Iterable[Page],
    ) -> Iterable[Page]:
        if not getattr(self, "enabled", True):
            yield from page_batch
            return

        pages = list(page_batch)
        predictions = self.predict_tables(conv_res, pages)
        for page, prediction in zip(pages, predictions):
            page.predictions.tablestructure = prediction
            yield page
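
The analogous sketch for a table-structure plugin is nearly identical; the difference worth noting is the `enabled` attribute, which the base `__call__` above uses to pass pages through untouched when table structure is switched off. MyTableModel and MyTableOptions are hypothetical:

from collections.abc import Sequence

from docling.datamodel.base_models import Page, TableStructurePrediction
from docling.datamodel.document import ConversionResult
from docling.models.base_table_model import BaseTableStructureModel


class MyTableModel(BaseTableStructureModel):
    def __init__(self, *, enabled, options, artifacts_path=None, accelerator_options=None):
        self.enabled = enabled  # checked by the base __call__ before predicting
        self.options = options

    @classmethod
    def get_options_type(cls):
        return MyTableOptions  # hypothetical subclass of BaseTableStructureOptions

    def predict_tables(
        self, conv_res: ConversionResult, pages: Sequence[Page]
    ) -> Sequence[TableStructurePrediction]:
        return [TableStructurePrediction() for _ in pages]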

View File

@@ -1,10 +1,12 @@
import logging
from functools import lru_cache

from docling.models.factories.layout_factory import LayoutFactory
from docling.models.factories.ocr_factory import OcrFactory
from docling.models.factories.picture_description_factory import (
    PictureDescriptionFactory,
)
from docling.models.factories.table_factory import TableStructureFactory

logger = logging.getLogger(__name__)
@@ -25,3 +27,21 @@ def get_picture_description_factory(
    factory.load_from_plugins(allow_external_plugins=allow_external_plugins)
    logger.info("Registered picture descriptions: %r", factory.registered_kind)
    return factory


@lru_cache
def get_layout_factory(allow_external_plugins: bool = False) -> LayoutFactory:
    factory = LayoutFactory()
    factory.load_from_plugins(allow_external_plugins=allow_external_plugins)
    logger.info("Registered layout engines: %r", factory.registered_kind)
    return factory


@lru_cache
def get_table_structure_factory(
    allow_external_plugins: bool = False,
) -> TableStructureFactory:
    factory = TableStructureFactory()
    factory.load_from_plugins(allow_external_plugins=allow_external_plugins)
    logger.info("Registered table structure engines: %r", factory.registered_kind)
    return factory
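
Usage mirrors the existing OCR factory and the pipeline wiring further down: look up the factory, then create an engine instance from an options object whose `kind` selects the implementation. A sketch, assuming the default options classes and an AcceleratorOptions import path that may differ between docling versions:

from docling.datamodel.accelerator_options import AcceleratorOptions  # assumed import path
from docling.datamodel.pipeline_options import LayoutOptions, TableStructureOptions
from docling.models.factories import get_layout_factory, get_table_structure_factory

layout_factory = get_layout_factory(allow_external_plugins=False)
print(layout_factory.registered_kind)  # e.g. includes "docling_layout_default"

layout_model = layout_factory.create_instance(
    options=LayoutOptions(),
    artifacts_path=None,  # or a local model cache directory
    accelerator_options=AcceleratorOptions(),
)

table_factory = get_table_structure_factory(allow_external_plugins=False)
table_model = table_factory.create_instance(
    options=TableStructureOptions(),
    enabled=True,
    artifacts_path=None,
    accelerator_options=AcceleratorOptions(),
)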

View File

@@ -0,0 +1,7 @@
from docling.models.base_layout_model import BaseLayoutModel
from docling.models.factories.base_factory import BaseFactory


class LayoutFactory(BaseFactory[BaseLayoutModel]):
    def __init__(self, *args, **kwargs):
        super().__init__("layout_engines", *args, **kwargs)

View File

@@ -0,0 +1,7 @@
from docling.models.base_table_model import BaseTableStructureModel
from docling.models.factories.base_factory import BaseFactory


class TableStructureFactory(BaseFactory[BaseTableStructureModel]):
    def __init__(self, *args, **kwargs):
        super().__init__("table_structure_engines", *args, **kwargs)

View File

@@ -1,7 +1,7 @@
import copy
import logging
import warnings
from collections.abc import Iterable
from collections.abc import Sequence
from pathlib import Path
from typing import List, Optional, Union
@@ -15,7 +15,7 @@ from docling.datamodel.document import ConversionResult
from docling.datamodel.layout_model_specs import DOCLING_LAYOUT_V2, LayoutModelConfig
from docling.datamodel.pipeline_options import LayoutOptions
from docling.datamodel.settings import settings
from docling.models.base_model import BasePageModel
from docling.models.base_layout_model import BaseLayoutModel
from docling.models.utils.hf_model_download import download_hf_model
from docling.utils.accelerator_utils import decide_device
from docling.utils.layout_postprocessor import LayoutPostprocessor
@@ -25,7 +25,7 @@ from docling.utils.visualization import draw_clusters
_log = logging.getLogger(__name__)
class LayoutModel(BasePageModel):
class LayoutModel(BaseLayoutModel):
TEXT_ELEM_LABELS = [
DocItemLabel.TEXT,
DocItemLabel.FOOTNOTE,
@@ -86,6 +86,10 @@ class LayoutModel(BasePageModel):
num_threads=accelerator_options.num_threads,
)
@classmethod
def get_options_type(cls) -> type[LayoutOptions]:
return LayoutOptions
@staticmethod
def download_models(
local_dir: Optional[Path] = None,
@@ -145,11 +149,13 @@ class LayoutModel(BasePageModel):
out_file = out_path / f"{mode_prefix}_layout_page_{page.page_no:05}.png"
combined_image.save(str(out_file), format="png")
def __call__(
self, conv_res: ConversionResult, page_batch: Iterable[Page]
) -> Iterable[Page]:
# Convert to list to allow multiple iterations
pages = list(page_batch)
def predict_layout(
self,
conv_res: ConversionResult,
pages: Sequence[Page],
) -> Sequence[LayoutPrediction]:
# Convert to list to ensure predictable iteration
pages = list(pages)
# Separate valid and invalid pages
valid_pages = []
@@ -167,12 +173,6 @@ class LayoutModel(BasePageModel):
valid_pages.append(page)
valid_page_images.append(page_image)
_log.debug(f"{len(pages)=}")
if pages:
_log.debug(f"{pages[0].page_no}-{pages[-1].page_no}")
_log.debug(f"{len(valid_pages)=}")
_log.debug(f"{len(valid_page_images)=}")
# Process all valid pages with batch prediction
batch_predictions = []
if valid_page_images:
@@ -182,11 +182,14 @@ class LayoutModel(BasePageModel):
)
# Process each page with its predictions
layout_predictions: list[LayoutPrediction] = []
valid_page_idx = 0
for page in pages:
assert page._backend is not None
if not page._backend.is_valid():
yield page
existing_prediction = page.predictions.layout or LayoutPrediction()
page.predictions.layout = existing_prediction
layout_predictions.append(existing_prediction)
continue
page_predictions = batch_predictions[valid_page_idx]
@@ -233,11 +236,14 @@ class LayoutModel(BasePageModel):
np.mean([c.confidence for c in processed_cells if c.from_ocr])
)
page.predictions.layout = LayoutPrediction(clusters=processed_clusters)
prediction = LayoutPrediction(clusters=processed_clusters)
page.predictions.layout = prediction
if settings.debug.visualize_layout:
self.draw_clusters_and_cells_side_by_side(
conv_res, page, processed_clusters, mode_prefix="postprocessed"
)
yield page
layout_predictions.append(prediction)
return layout_predictions

View File

@@ -28,3 +28,23 @@ def picture_description():
            PictureDescriptionApiModel,
        ]
    }


def layout_engines():
    from docling.models.layout_model import LayoutModel

    return {
        "layout_engines": [
            LayoutModel,
        ]
    }


def table_structure_engines():
    from docling.models.table_structure_model import TableStructureModel

    return {
        "table_structure_engines": [
            TableStructureModel,
        ]
    }
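
These default registrations also define the shape an external plugin module would follow: one function per factory key, each returning a dict from that key to a list of engine classes. A hedged sketch of a third-party plugin module; my_package, MyLayoutModel, and MyTableModel are hypothetical, and the module is assumed to be exposed through docling's plugin entry-point mechanism (the group name is not shown in this diff):

# my_package/docling_plugin.py (hypothetical)


def layout_engines():
    from my_package.layout import MyLayoutModel  # hypothetical engine class

    return {
        "layout_engines": [
            MyLayoutModel,
        ]
    }


def table_structure_engines():
    from my_package.tables import MyTableModel  # hypothetical engine class

    return {
        "table_structure_engines": [
            MyTableModel,
        ]
    }

External plugins are only picked up when the factory is built with allow_external_plugins=True (see load_from_plugins above).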

View File

@@ -1,6 +1,6 @@
import copy
import warnings
from collections.abc import Iterable
from collections.abc import Iterable, Sequence
from pathlib import Path
from typing import Optional
@@ -20,13 +20,13 @@ from docling.datamodel.pipeline_options import (
TableStructureOptions,
)
from docling.datamodel.settings import settings
from docling.models.base_model import BasePageModel
from docling.models.base_table_model import BaseTableStructureModel
from docling.models.utils.hf_model_download import download_hf_model
from docling.utils.accelerator_utils import decide_device
from docling.utils.profiling import TimeRecorder
class TableStructureModel(BasePageModel):
class TableStructureModel(BaseTableStructureModel):
_model_repo_folder = "docling-project--docling-models"
_model_path = "model_artifacts/tableformer"
@@ -88,6 +88,10 @@ class TableStructureModel(BasePageModel):
)
self.scale = 2.0 # Scale up table input images to 144 dpi
@classmethod
def get_options_type(cls) -> type[TableStructureOptions]:
return TableStructureOptions
@staticmethod
def download_models(
local_dir: Optional[Path] = None, force: bool = False, progress: bool = False
@@ -167,138 +171,135 @@ class TableStructureModel(BasePageModel):
out_file = out_path / f"table_struct_page_{page.page_no:05}.png"
image.save(str(out_file), format="png")
def __call__(
self, conv_res: ConversionResult, page_batch: Iterable[Page]
) -> Iterable[Page]:
if not self.enabled:
yield from page_batch
return
def predict_tables(
self,
conv_res: ConversionResult,
pages: Sequence[Page],
) -> Sequence[TableStructurePrediction]:
pages = list(pages)
predictions: list[TableStructurePrediction] = []
for page in page_batch:
for page in pages:
assert page._backend is not None
if not page._backend.is_valid():
yield page
else:
with TimeRecorder(conv_res, "table_structure"):
assert page.predictions.layout is not None
assert page.size is not None
existing_prediction = (
page.predictions.tablestructure or TableStructurePrediction()
)
page.predictions.tablestructure = existing_prediction
predictions.append(existing_prediction)
continue
page.predictions.tablestructure = (
TableStructurePrediction()
) # dummy
with TimeRecorder(conv_res, "table_structure"):
assert page.predictions.layout is not None
assert page.size is not None
in_tables = [
(
cluster,
[
round(cluster.bbox.l) * self.scale,
round(cluster.bbox.t) * self.scale,
round(cluster.bbox.r) * self.scale,
round(cluster.bbox.b) * self.scale,
],
table_prediction = TableStructurePrediction()
page.predictions.tablestructure = table_prediction
in_tables = [
(
cluster,
[
round(cluster.bbox.l) * self.scale,
round(cluster.bbox.t) * self.scale,
round(cluster.bbox.r) * self.scale,
round(cluster.bbox.b) * self.scale,
],
)
for cluster in page.predictions.layout.clusters
if cluster.label
in [DocItemLabel.TABLE, DocItemLabel.DOCUMENT_INDEX]
]
if not in_tables:
predictions.append(table_prediction)
continue
page_input = {
"width": page.size.width * self.scale,
"height": page.size.height * self.scale,
"image": numpy.asarray(page.get_image(scale=self.scale)),
}
for table_cluster, tbl_box in in_tables:
# Check if word-level cells are available from backend:
sp = page._backend.get_segmented_page()
if sp is not None:
tcells = sp.get_cells_in_bbox(
cell_unit=TextCellUnit.WORD,
bbox=table_cluster.bbox,
)
for cluster in page.predictions.layout.clusters
if cluster.label
in [DocItemLabel.TABLE, DocItemLabel.DOCUMENT_INDEX]
]
if not len(in_tables):
yield page
continue
page_input = {
"width": page.size.width * self.scale,
"height": page.size.height * self.scale,
"image": numpy.asarray(page.get_image(scale=self.scale)),
}
table_clusters, table_bboxes = zip(*in_tables)
if len(table_bboxes):
for table_cluster, tbl_box in in_tables:
# Check if word-level cells are available from backend:
sp = page._backend.get_segmented_page()
if sp is not None:
tcells = sp.get_cells_in_bbox(
cell_unit=TextCellUnit.WORD,
bbox=table_cluster.bbox,
)
if len(tcells) == 0:
# In case word-level cells yield empty
tcells = table_cluster.cells
else:
# Otherwise - we use normal (line/phrase) cells
tcells = table_cluster.cells
tokens = []
for c in tcells:
# Only allow non empty strings (spaces) into the cells of a table
if len(c.text.strip()) > 0:
new_cell = copy.deepcopy(c)
new_cell.rect = BoundingRectangle.from_bounding_box(
new_cell.rect.to_bounding_box().scaled(
scale=self.scale
)
)
tokens.append(
{
"id": new_cell.index,
"text": new_cell.text,
"bbox": new_cell.rect.to_bounding_box().model_dump(),
}
)
page_input["tokens"] = tokens
tf_output = self.tf_predictor.multi_table_predict(
page_input, [tbl_box], do_matching=self.do_cell_matching
if len(tcells) == 0:
# In case word-level cells yield empty
tcells = table_cluster.cells
else:
# Otherwise - we use normal (line/phrase) cells
tcells = table_cluster.cells
tokens = []
for c in tcells:
# Only allow non empty strings (spaces) into the cells of a table
if len(c.text.strip()) > 0:
new_cell = copy.deepcopy(c)
new_cell.rect = BoundingRectangle.from_bounding_box(
new_cell.rect.to_bounding_box().scaled(scale=self.scale)
)
table_out = tf_output[0]
table_cells = []
for element in table_out["tf_responses"]:
if not self.do_cell_matching:
the_bbox = BoundingBox.model_validate(
element["bbox"]
).scaled(1 / self.scale)
text_piece = page._backend.get_text_in_rect(
the_bbox
)
element["bbox"]["token"] = text_piece
tc = TableCell.model_validate(element)
if tc.bbox is not None:
tc.bbox = tc.bbox.scaled(1 / self.scale)
table_cells.append(tc)
assert "predict_details" in table_out
# Retrieving cols/rows, after post processing:
num_rows = table_out["predict_details"].get("num_rows", 0)
num_cols = table_out["predict_details"].get("num_cols", 0)
otsl_seq = (
table_out["predict_details"]
.get("prediction", {})
.get("rs_seq", [])
tokens.append(
{
"id": new_cell.index,
"text": new_cell.text,
"bbox": new_cell.rect.to_bounding_box().model_dump(),
}
)
page_input["tokens"] = tokens
tbl = Table(
otsl_seq=otsl_seq,
table_cells=table_cells,
num_rows=num_rows,
num_cols=num_cols,
id=table_cluster.id,
page_no=page.page_no,
cluster=table_cluster,
label=table_cluster.label,
)
tf_output = self.tf_predictor.multi_table_predict(
page_input, [tbl_box], do_matching=self.do_cell_matching
)
table_out = tf_output[0]
table_cells = []
for element in table_out["tf_responses"]:
if not self.do_cell_matching:
the_bbox = BoundingBox.model_validate(
element["bbox"]
).scaled(1 / self.scale)
text_piece = page._backend.get_text_in_rect(the_bbox)
element["bbox"]["token"] = text_piece
page.predictions.tablestructure.table_map[
table_cluster.id
] = tbl
tc = TableCell.model_validate(element)
if tc.bbox is not None:
tc.bbox = tc.bbox.scaled(1 / self.scale)
table_cells.append(tc)
# For debugging purposes:
if settings.debug.visualize_tables:
self.draw_table_and_cells(
conv_res,
page,
page.predictions.tablestructure.table_map.values(),
)
assert "predict_details" in table_out
yield page
# Retrieving cols/rows, after post processing:
num_rows = table_out["predict_details"].get("num_rows", 0)
num_cols = table_out["predict_details"].get("num_cols", 0)
otsl_seq = (
table_out["predict_details"]
.get("prediction", {})
.get("rs_seq", [])
)
tbl = Table(
otsl_seq=otsl_seq,
table_cells=table_cells,
num_rows=num_rows,
num_cols=num_cols,
id=table_cluster.id,
page_no=page.page_no,
cluster=table_cluster,
label=table_cluster.label,
)
table_prediction.table_map[table_cluster.id] = tbl
if settings.debug.visualize_tables:
self.draw_table_and_cells(
conv_res,
page,
page.predictions.tablestructure.table_map.values(),
)
predictions.append(table_prediction)
return predictions

View File

@@ -15,15 +15,17 @@ from docling.datamodel.pipeline_options import PdfPipelineOptions
from docling.datamodel.settings import settings
from docling.models.base_ocr_model import BaseOcrModel
from docling.models.code_formula_model import CodeFormulaModel, CodeFormulaModelOptions
from docling.models.factories import get_ocr_factory
from docling.models.layout_model import LayoutModel
from docling.models.factories import (
get_layout_factory,
get_ocr_factory,
get_table_structure_factory,
)
from docling.models.page_assemble_model import PageAssembleModel, PageAssembleOptions
from docling.models.page_preprocessing_model import (
PagePreprocessingModel,
PagePreprocessingOptions,
)
from docling.models.readingorder_model import ReadingOrderModel, ReadingOrderOptions
from docling.models.table_structure_model import TableStructureModel
from docling.pipeline.base_pipeline import PaginatedPipeline
from docling.utils.model_downloader import download_models
from docling.utils.profiling import ProfilingScope, TimeRecorder
@@ -48,6 +50,24 @@ class LegacyStandardPdfPipeline(PaginatedPipeline):
ocr_model = self.get_ocr_model(artifacts_path=self.artifacts_path)
layout_factory = get_layout_factory(
allow_external_plugins=self.pipeline_options.allow_external_plugins
)
layout_model = layout_factory.create_instance(
options=pipeline_options.layout_options,
artifacts_path=self.artifacts_path,
accelerator_options=pipeline_options.accelerator_options,
)
table_factory = get_table_structure_factory(
allow_external_plugins=self.pipeline_options.allow_external_plugins
)
table_model = table_factory.create_instance(
options=pipeline_options.table_structure_options,
enabled=pipeline_options.do_table_structure,
artifacts_path=self.artifacts_path,
accelerator_options=pipeline_options.accelerator_options,
)
self.build_pipe = [
# Pre-processing
PagePreprocessingModel(
@@ -58,18 +78,9 @@ class LegacyStandardPdfPipeline(PaginatedPipeline):
# OCR
ocr_model,
# Layout model
LayoutModel(
artifacts_path=self.artifacts_path,
accelerator_options=pipeline_options.accelerator_options,
options=pipeline_options.layout_options,
),
layout_model,
# Table structure model
TableStructureModel(
enabled=pipeline_options.do_table_structure,
artifacts_path=self.artifacts_path,
options=pipeline_options.table_structure_options,
accelerator_options=pipeline_options.accelerator_options,
),
table_model,
# Page assemble
PageAssembleModel(options=PageAssembleOptions()),
]

View File

@@ -41,15 +41,17 @@ from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import ThreadedPdfPipelineOptions
from docling.datamodel.settings import settings
from docling.models.code_formula_model import CodeFormulaModel, CodeFormulaModelOptions
from docling.models.factories import get_ocr_factory
from docling.models.layout_model import LayoutModel
from docling.models.factories import (
get_layout_factory,
get_ocr_factory,
get_table_structure_factory,
)
from docling.models.page_assemble_model import PageAssembleModel, PageAssembleOptions
from docling.models.page_preprocessing_model import (
PagePreprocessingModel,
PagePreprocessingOptions,
)
from docling.models.readingorder_model import ReadingOrderModel, ReadingOrderOptions
from docling.models.table_structure_model import TableStructureModel
from docling.pipeline.base_pipeline import ConvertPipeline
from docling.utils.profiling import ProfilingScope, TimeRecorder
from docling.utils.utils import chunkify
@@ -436,15 +438,21 @@ class StandardPdfPipeline(ConvertPipeline):
)
)
self.ocr_model = self._make_ocr_model(art_path)
self.layout_model = LayoutModel(
layout_factory = get_layout_factory(
allow_external_plugins=self.pipeline_options.allow_external_plugins
)
self.layout_model = layout_factory.create_instance(
options=self.pipeline_options.layout_options,
artifacts_path=art_path,
accelerator_options=self.pipeline_options.accelerator_options,
options=self.pipeline_options.layout_options,
)
self.table_model = TableStructureModel(
table_factory = get_table_structure_factory(
allow_external_plugins=self.pipeline_options.allow_external_plugins
)
self.table_model = table_factory.create_instance(
options=self.pipeline_options.table_structure_options,
enabled=self.pipeline_options.do_table_structure,
artifacts_path=art_path,
options=self.pipeline_options.table_structure_options,
accelerator_options=self.pipeline_options.accelerator_options,
)
self.assemble_model = PageAssembleModel(options=PageAssembleOptions())
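
Putting it together, the pipelines above now obtain layout and table engines from the factories based on the options' `kind` rather than instantiating hard-coded classes. A sketch of selecting a plugin engine from user code; DocumentConverter, PdfFormatOption, and InputFormat are existing docling APIs, while MyLayoutOptions is the hypothetical plugin options class from the earlier sketches and assumes its engine is installed and registered:

from docling.datamodel.base_models import InputFormat
from docling.datamodel.pipeline_options import PdfPipelineOptions
from docling.document_converter import DocumentConverter, PdfFormatOption

pipeline_options = PdfPipelineOptions(
    allow_external_plugins=True,       # required so the factories load external engines
    layout_options=MyLayoutOptions(),  # hypothetical; defaults to LayoutOptions otherwise
    do_table_structure=True,
)

converter = DocumentConverter(
    format_options={InputFormat.PDF: PdfFormatOption(pipeline_options=pipeline_options)}
)
result = converter.convert("document.pdf")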