Pass nested cluster processing through full pipeline

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
This commit is contained in:
Christoph Auer 2024-12-03 13:08:45 +01:00
parent 7245cc6080
commit 4dcc738b6d
6 changed files with 62 additions and 35 deletions

View File

@ -132,6 +132,12 @@ class LayoutPrediction(BaseModel):
clusters: List[Cluster] = [] clusters: List[Cluster] = []
# Page element type for container-style layout regions (forms and
# key-value regions). It adds no fields or behavior on top of
# BasePageElement; the subclass exists so these regions have a distinct
# type, e.g. as a member of the PageElement union and for isinstance()
# dispatch on assembled elements.
class ContainerElement(
    BasePageElement
):  # Used for Form and Key-Value-Regions, only for typing.
    pass
class Table(BasePageElement): class Table(BasePageElement):
otsl_seq: List[str] otsl_seq: List[str]
num_rows: int = 0 num_rows: int = 0
@ -171,7 +177,7 @@ class PagePredictions(BaseModel):
equations_prediction: Optional[EquationPrediction] = None equations_prediction: Optional[EquationPrediction] = None
PageElement = Union[TextElement, Table, FigureElement] PageElement = Union[TextElement, Table, FigureElement, ContainerElement]
class AssembledUnit(BaseModel): class AssembledUnit(BaseModel):

View File

@ -77,6 +77,8 @@ layout_label_to_ds_type = {
DocItemLabel.PICTURE: "figure", DocItemLabel.PICTURE: "figure",
DocItemLabel.TEXT: "paragraph", DocItemLabel.TEXT: "paragraph",
DocItemLabel.PARAGRAPH: "paragraph", DocItemLabel.PARAGRAPH: "paragraph",
DocItemLabel.FORM: DocItemLabel.FORM.value,
DocItemLabel.KEY_VALUE_REGION: DocItemLabel.KEY_VALUE_REGION.value,
} }
_EMPTY_DOCLING_DOC = DoclingDocument(name="dummy") _EMPTY_DOCLING_DOC = DoclingDocument(name="dummy")

View File

@ -24,9 +24,15 @@ from docling_core.types.legacy_doc.document import (
from docling_core.types.legacy_doc.document import CCSFileInfoObject as DsFileInfoObject from docling_core.types.legacy_doc.document import CCSFileInfoObject as DsFileInfoObject
from docling_core.types.legacy_doc.document import ExportedCCSDocument as DsDocument from docling_core.types.legacy_doc.document import ExportedCCSDocument as DsDocument
from PIL import ImageDraw from PIL import ImageDraw
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict, TypeAdapter
from docling.datamodel.base_models import Cluster, FigureElement, Table, TextElement from docling.datamodel.base_models import (
Cluster,
ContainerElement,
FigureElement,
Table,
TextElement,
)
from docling.datamodel.document import ConversionResult, layout_label_to_ds_type from docling.datamodel.document import ConversionResult, layout_label_to_ds_type
from docling.datamodel.settings import settings from docling.datamodel.settings import settings
from docling.utils.profiling import ProfilingScope, TimeRecorder from docling.utils.profiling import ProfilingScope, TimeRecorder
@ -45,7 +51,9 @@ class GlmModel:
if self.options.model_names != "": if self.options.model_names != "":
load_pretrained_nlp_models() load_pretrained_nlp_models()
self.model = init_nlp_model(model_names=self.options.model_names) self.model = init_nlp_model(
model_names=self.options.model_names, loglevel="ERROR"
)
def _to_legacy_document(self, conv_res) -> DsDocument: def _to_legacy_document(self, conv_res) -> DsDocument:
title = "" title = ""
@ -207,7 +215,26 @@ class GlmModel:
) )
], ],
obj_type=layout_label_to_ds_type.get(element.label), obj_type=layout_label_to_ds_type.get(element.label),
# data=[[]], payload=TypeAdapter(List[Cluster]).dump_python(
element.cluster.children
), # hack to channel child clusters through GLM
)
)
elif isinstance(element, ContainerElement):
main_text.append(
BaseText(
payload=TypeAdapter(List[Cluster]).dump_python(
element.cluster.children
), # hack to channel child clusters through GLM
obj_type=layout_label_to_ds_type.get(element.label),
name=element.label,
prov=[
Prov(
bbox=target_bbox,
page=element.page_no + 1,
span=[0, 0],
)
],
) )
) )

View File

@ -46,6 +46,8 @@ class LayoutModel(BasePageModel):
FIGURE_LABEL = DocItemLabel.PICTURE FIGURE_LABEL = DocItemLabel.PICTURE
FORMULA_LABEL = DocItemLabel.FORMULA FORMULA_LABEL = DocItemLabel.FORMULA
CONTAINER_LABELS = [DocItemLabel.FORM, DocItemLabel.KEY_VALUE_REGION]
def __init__(self, artifacts_path: Path): def __init__(self, artifacts_path: Path):
self.layout_predictor = LayoutPredictor(artifacts_path) # TODO temporary self.layout_predictor = LayoutPredictor(artifacts_path) # TODO temporary
@ -177,24 +179,6 @@ class LayoutModel(BasePageModel):
cells=[], cells=[],
) )
clusters.append(cluster) clusters.append(cluster)
#
# # Map cells to clusters
# # TODO: Remove, postprocess should take care of it anyway.
# for cell in page.cells:
# for cluster in clusters:
# if not cell.bbox.area() > 0:
# overlap_frac = 0.0
# else:
# overlap_frac = (
# cell.bbox.intersection_area_with(cluster.bbox)
# / cell.bbox.area()
# )
#
# if overlap_frac > 0.5:
# cluster.cells.append(cell)
# Pre-sort clusters
# clusters = self.sort_clusters_by_cell_order(clusters)
# DEBUG code: # DEBUG code:
def draw_clusters_and_cells( def draw_clusters_and_cells(
@ -299,18 +283,6 @@ class LayoutModel(BasePageModel):
clusters=processed_clusters clusters=processed_clusters
) )
# def check_for_overlaps(clusters):
# for i, cluster in enumerate(clusters):
# for j, other_cluster in enumerate(clusters):
# if i >= j or cluster.label in [DocItemLabel.FORM, DocItemLabel.KEY_VALUE_REGION, DocItemLabel.PICTURE]:
# continue
#
# overlap_area = cluster.bbox.intersection_area_with(other_cluster.bbox)
# if overlap_area > 0:
# print(f"Overlap detected between cluster {i} and cluster {j}")
# print(f"Cluster {i} bbox: {cluster.bbox}, Cluster {j} bbox: {other_cluster.bbox}")
# check_for_overlaps(processed_clusters)
if settings.debug.visualize_layout: if settings.debug.visualize_layout:
self.draw_clusters_and_cells_side_by_side( self.draw_clusters_and_cells_side_by_side(
conv_res, page, processed_clusters, mode_prefix="postprocessed" conv_res, page, processed_clusters, mode_prefix="postprocessed"

View File

@ -6,6 +6,7 @@ from pydantic import BaseModel
from docling.datamodel.base_models import ( from docling.datamodel.base_models import (
AssembledUnit, AssembledUnit,
ContainerElement,
FigureElement, FigureElement,
Page, Page,
PageElement, PageElement,
@ -159,6 +160,15 @@ class PageAssembleModel(BasePageModel):
) )
elements.append(equation) elements.append(equation)
body.append(equation) body.append(equation)
elif cluster.label in LayoutModel.CONTAINER_LABELS:
container_el = ContainerElement(
label=cluster.label,
id=cluster.id,
page_no=page.page_no,
cluster=cluster,
)
elements.append(container_el)
body.append(container_el)
page.assembled = AssembledUnit( page.assembled = AssembledUnit(
elements=elements, headers=headers, body=body elements=elements, headers=headers, body=body

View File

@ -1,5 +1,6 @@
import bisect import bisect
import logging import logging
import sys
from collections import defaultdict from collections import defaultdict
from typing import Dict, List, Set, Tuple from typing import Dict, List, Set, Tuple
@ -279,7 +280,16 @@ class LayoutPostprocessor:
contained.append(cluster) contained.append(cluster)
if contained: if contained:
# Sort contained clusters by minimum cell ID
contained.sort(
key=lambda cluster: (
min(cell.id for cell in cluster.cells)
if cluster.cells
else sys.maxsize
)
)
special.children = contained special.children = contained
# Adjust bbox only for wrapper types # Adjust bbox only for wrapper types
if special.label in self.WRAPPER_TYPES: if special.label in self.WRAPPER_TYPES:
special.bbox = BoundingBox( special.bbox = BoundingBox(