Pass nested cluster processing through full pipeline

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
This commit is contained in:
Christoph Auer 2024-12-03 13:08:45 +01:00
parent 7245cc6080
commit 4dcc738b6d
6 changed files with 62 additions and 35 deletions

View File

@ -132,6 +132,12 @@ class LayoutPrediction(BaseModel):
clusters: List[Cluster] = []
class ContainerElement(
BasePageElement
): # Used for Form and Key-Value-Regions, only for typing.
pass
class Table(BasePageElement):
otsl_seq: List[str]
num_rows: int = 0
@ -171,7 +177,7 @@ class PagePredictions(BaseModel):
equations_prediction: Optional[EquationPrediction] = None
PageElement = Union[TextElement, Table, FigureElement]
PageElement = Union[TextElement, Table, FigureElement, ContainerElement]
class AssembledUnit(BaseModel):

View File

@ -77,6 +77,8 @@ layout_label_to_ds_type = {
DocItemLabel.PICTURE: "figure",
DocItemLabel.TEXT: "paragraph",
DocItemLabel.PARAGRAPH: "paragraph",
DocItemLabel.FORM: DocItemLabel.FORM.value,
DocItemLabel.KEY_VALUE_REGION: DocItemLabel.KEY_VALUE_REGION.value,
}
_EMPTY_DOCLING_DOC = DoclingDocument(name="dummy")

View File

@ -24,9 +24,15 @@ from docling_core.types.legacy_doc.document import (
from docling_core.types.legacy_doc.document import CCSFileInfoObject as DsFileInfoObject
from docling_core.types.legacy_doc.document import ExportedCCSDocument as DsDocument
from PIL import ImageDraw
from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel, ConfigDict, TypeAdapter
from docling.datamodel.base_models import Cluster, FigureElement, Table, TextElement
from docling.datamodel.base_models import (
Cluster,
ContainerElement,
FigureElement,
Table,
TextElement,
)
from docling.datamodel.document import ConversionResult, layout_label_to_ds_type
from docling.datamodel.settings import settings
from docling.utils.profiling import ProfilingScope, TimeRecorder
@ -45,7 +51,9 @@ class GlmModel:
if self.options.model_names != "":
load_pretrained_nlp_models()
self.model = init_nlp_model(model_names=self.options.model_names)
self.model = init_nlp_model(
model_names=self.options.model_names, loglevel="ERROR"
)
def _to_legacy_document(self, conv_res) -> DsDocument:
title = ""
@ -207,7 +215,26 @@ class GlmModel:
)
],
obj_type=layout_label_to_ds_type.get(element.label),
# data=[[]],
payload=TypeAdapter(List[Cluster]).dump_python(
element.cluster.children
), # hack to channel child clusters through GLM
)
)
elif isinstance(element, ContainerElement):
main_text.append(
BaseText(
payload=TypeAdapter(List[Cluster]).dump_python(
element.cluster.children
), # hack to channel child clusters through GLM
obj_type=layout_label_to_ds_type.get(element.label),
name=element.label,
prov=[
Prov(
bbox=target_bbox,
page=element.page_no + 1,
span=[0, 0],
)
],
)
)

View File

@ -46,6 +46,8 @@ class LayoutModel(BasePageModel):
FIGURE_LABEL = DocItemLabel.PICTURE
FORMULA_LABEL = DocItemLabel.FORMULA
CONTAINER_LABELS = [DocItemLabel.FORM, DocItemLabel.KEY_VALUE_REGION]
def __init__(self, artifacts_path: Path):
self.layout_predictor = LayoutPredictor(artifacts_path) # TODO temporary
@ -177,24 +179,6 @@ class LayoutModel(BasePageModel):
cells=[],
)
clusters.append(cluster)
#
# # Map cells to clusters
# # TODO: Remove, postprocess should take care of it anyway.
# for cell in page.cells:
# for cluster in clusters:
# if not cell.bbox.area() > 0:
# overlap_frac = 0.0
# else:
# overlap_frac = (
# cell.bbox.intersection_area_with(cluster.bbox)
# / cell.bbox.area()
# )
#
# if overlap_frac > 0.5:
# cluster.cells.append(cell)
# Pre-sort clusters
# clusters = self.sort_clusters_by_cell_order(clusters)
# DEBUG code:
def draw_clusters_and_cells(
@ -299,18 +283,6 @@ class LayoutModel(BasePageModel):
clusters=processed_clusters
)
# def check_for_overlaps(clusters):
# for i, cluster in enumerate(clusters):
# for j, other_cluster in enumerate(clusters):
# if i >= j or cluster.label in [DocItemLabel.FORM, DocItemLabel.KEY_VALUE_REGION, DocItemLabel.PICTURE]:
# continue
#
# overlap_area = cluster.bbox.intersection_area_with(other_cluster.bbox)
# if overlap_area > 0:
# print(f"Overlap detected between cluster {i} and cluster {j}")
# print(f"Cluster {i} bbox: {cluster.bbox}, Cluster {j} bbox: {other_cluster.bbox}")
# check_for_overlaps(processed_clusters)
if settings.debug.visualize_layout:
self.draw_clusters_and_cells_side_by_side(
conv_res, page, processed_clusters, mode_prefix="postprocessed"

View File

@ -6,6 +6,7 @@ from pydantic import BaseModel
from docling.datamodel.base_models import (
AssembledUnit,
ContainerElement,
FigureElement,
Page,
PageElement,
@ -159,6 +160,15 @@ class PageAssembleModel(BasePageModel):
)
elements.append(equation)
body.append(equation)
elif cluster.label in LayoutModel.CONTAINER_LABELS:
container_el = ContainerElement(
label=cluster.label,
id=cluster.id,
page_no=page.page_no,
cluster=cluster,
)
elements.append(container_el)
body.append(container_el)
page.assembled = AssembledUnit(
elements=elements, headers=headers, body=body

View File

@ -1,5 +1,6 @@
import bisect
import logging
import sys
from collections import defaultdict
from typing import Dict, List, Set, Tuple
@ -279,7 +280,16 @@ class LayoutPostprocessor:
contained.append(cluster)
if contained:
# Sort contained clusters by minimum cell ID
contained.sort(
key=lambda cluster: (
min(cell.id for cell in cluster.cells)
if cluster.cells
else sys.maxsize
)
)
special.children = contained
# Adjust bbox only for wrapper types
if special.label in self.WRAPPER_TYPES:
special.bbox = BoundingBox(