Many layout processing improvements, add document index type

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
This commit is contained in:
Christoph Auer 2024-12-11 17:08:35 +01:00
parent d094c4990a
commit 57d51ede04
6 changed files with 90 additions and 23 deletions

View File

@ -63,7 +63,7 @@ _log = logging.getLogger(__name__)
layout_label_to_ds_type = {
DocItemLabel.TITLE: "title",
DocItemLabel.DOCUMENT_INDEX: "table-of-contents",
DocItemLabel.DOCUMENT_INDEX: "table",
DocItemLabel.SECTION_HEADER: "subtitle-level-1",
DocItemLabel.CHECKBOX_SELECTED: "checkbox-selected",
DocItemLabel.CHECKBOX_UNSELECTED: "checkbox-unselected",

View File

@ -7,7 +7,7 @@ from typing import Iterable, List
from docling_core.types.doc import CoordOrigin, DocItemLabel
from docling_ibm_models.layoutmodel.layout_predictor import LayoutPredictor
from PIL import Image, ImageDraw
from PIL import Image, ImageDraw, ImageFont
from docling.datamodel.base_models import (
BoundingBox,
@ -44,7 +44,7 @@ class LayoutModel(BasePageModel):
]
PAGE_HEADER_LABELS = [DocItemLabel.PAGE_HEADER, DocItemLabel.PAGE_FOOTER]
TABLE_LABEL = DocItemLabel.TABLE
TABLE_LABELS = [DocItemLabel.TABLE, DocItemLabel.DOCUMENT_INDEX]
FIGURE_LABEL = DocItemLabel.PICTURE
FORMULA_LABEL = DocItemLabel.FORMULA
CONTAINER_LABELS = [DocItemLabel.FORM, DocItemLabel.KEY_VALUE_REGION]
@ -62,6 +62,7 @@ class LayoutModel(BasePageModel):
Draws a page image side by side with clusters filtered into two categories:
- Left: Clusters excluding FORM, KEY_VALUE_REGION, and PICTURE.
- Right: Clusters including FORM, KEY_VALUE_REGION, and PICTURE.
Includes label names and confidence scores for each cluster.
"""
label_to_color = {
DocItemLabel.TEXT: (255, 255, 153), # Light Yellow
@ -103,9 +104,18 @@ class LayoutModel(BasePageModel):
# Function to draw clusters on an image
def draw_clusters(image, clusters):
draw = ImageDraw.Draw(image, "RGBA")
# Create a smaller font for the labels
try:
font = ImageFont.truetype("arial.ttf", 12)
except OSError:
# Fallback to default font if arial is not available
font = ImageFont.load_default()
for c_tl in clusters:
all_clusters = [c_tl, *c_tl.children]
for c in all_clusters:
# Draw cells first (underneath)
cell_color = (0, 0, 0, 40) # Transparent black for cells
for tc in c.cells:
cx0, cy0, cx1, cy1 = tc.bbox.as_tuple()
@ -115,21 +125,44 @@ class LayoutModel(BasePageModel):
fill=cell_color,
)
# Draw cluster rectangle
x0, y0, x1, y1 = c.bbox.as_tuple()
cluster_fill_color = (
*list(label_to_color.get(c.label)), # type: ignore
70,
)
cluster_outline_color = (
*list(label_to_color.get(c.label)), # type: ignore
255,
)
cluster_fill_color = (*list(label_to_color.get(c.label)), 70)
cluster_outline_color = (*list(label_to_color.get(c.label)), 255)
draw.rectangle(
[(x0, y0), (x1, y1)],
outline=cluster_outline_color,
fill=cluster_fill_color,
)
# Add label name and confidence
label_text = f"{c.label.name} ({c.confidence:.2f})"
# Create semi-transparent background for text
text_bbox = draw.textbbox((x0, y0), label_text, font=font)
text_bg_padding = 2
draw.rectangle(
[
(
text_bbox[0] - text_bg_padding,
text_bbox[1] - text_bg_padding,
),
(
text_bbox[2] + text_bg_padding,
text_bbox[3] + text_bg_padding,
),
],
fill=(255, 255, 255, 180), # Semi-transparent white
)
# Draw text
draw.text(
(x0, y0),
label_text,
fill=(0, 0, 0, 255), # Solid black
font=font,
)
# Draw clusters on both images
draw_clusters(left_image, left_clusters)
draw_clusters(right_image, right_clusters)
@ -277,8 +310,9 @@ class LayoutModel(BasePageModel):
)
# Apply postprocessing
processed_clusters, processed_cells = LayoutPostprocessor(
page.cells, clusters
page.cells, clusters, page.size
).postprocess()
# processed_clusters, processed_cells = clusters, page.cells

View File

@ -95,7 +95,7 @@ class PageAssembleModel(BasePageModel):
headers.append(text_el)
else:
body.append(text_el)
elif cluster.label == LayoutModel.TABLE_LABEL:
elif cluster.label in LayoutModel.TABLE_LABELS:
tbl = None
if page.predictions.tablestructure:
tbl = page.predictions.tablestructure.table_map.get(

View File

@ -133,7 +133,8 @@ class TableStructureModel(BasePageModel):
],
)
for cluster in page.predictions.layout.clusters
if cluster.label == DocItemLabel.TABLE
if cluster.label
in [DocItemLabel.TABLE, DocItemLabel.DOCUMENT_INDEX]
]
if not len(in_tables):
yield page
@ -198,7 +199,7 @@ class TableStructureModel(BasePageModel):
id=table_cluster.id,
page_no=page.page_no,
cluster=table_cluster,
label=DocItemLabel.TABLE,
label=table_cluster.label,
)
page.predictions.tablestructure.table_map[

View File

@ -169,6 +169,8 @@ def to_docling_document(doc_glm, update_name_label=False) -> DoclingDocument:
current_list = None
text = ""
caption_refs = []
item_label = DocItemLabel(pelem["name"])
for caption in obj["captions"]:
text += caption["text"]
@ -254,7 +256,7 @@ def to_docling_document(doc_glm, update_name_label=False) -> DoclingDocument:
),
)
tbl = doc.add_table(data=tbl_data, prov=prov)
tbl = doc.add_table(data=tbl_data, prov=prov, label=item_label)
tbl.captions.extend(caption_refs)
elif ptype in ["form", "key_value_region"]:

View File

@ -4,7 +4,7 @@ import sys
from collections import defaultdict
from typing import Dict, List, Set, Tuple
from docling_core.types.doc import DocItemLabel
from docling_core.types.doc import DocItemLabel, Size
from rtree import index
from docling.datamodel.base_models import BoundingBox, Cell, Cluster
@ -152,7 +152,12 @@ class LayoutPostprocessor:
"wrapper": {"area_threshold": 2.0, "conf_threshold": 0.2},
}
WRAPPER_TYPES = {DocItemLabel.FORM, DocItemLabel.KEY_VALUE_REGION}
WRAPPER_TYPES = {
DocItemLabel.FORM,
DocItemLabel.KEY_VALUE_REGION,
DocItemLabel.TABLE,
DocItemLabel.DOCUMENT_INDEX,
}
SPECIAL_TYPES = WRAPPER_TYPES | {DocItemLabel.PICTURE}
CONFIDENCE_THRESHOLDS = {
@ -164,7 +169,7 @@ class LayoutPostprocessor:
DocItemLabel.PAGE_HEADER: 0.5,
DocItemLabel.PICTURE: 0.5,
DocItemLabel.SECTION_HEADER: 0.45,
DocItemLabel.TABLE: 0.35,
DocItemLabel.TABLE: 0.5,
DocItemLabel.TEXT: 0.55, # 0.45,
DocItemLabel.TITLE: 0.45,
DocItemLabel.CODE: 0.45,
@ -176,14 +181,15 @@ class LayoutPostprocessor:
}
LABEL_REMAPPING = {
DocItemLabel.DOCUMENT_INDEX: DocItemLabel.TABLE,
# DocItemLabel.DOCUMENT_INDEX: DocItemLabel.TABLE,
DocItemLabel.TITLE: DocItemLabel.SECTION_HEADER,
}
def __init__(self, cells: List[Cell], clusters: List[Cluster]):
def __init__(self, cells: List[Cell], clusters: List[Cluster], page_size: Size):
"""Initialize processor with cells and clusters."""
"""Initialize processor with cells and spatial indices."""
self.cells = cells
self.page_size = page_size
self.regular_clusters = [
c for c in clusters if c.label not in self.SPECIAL_TYPES
]
@ -281,6 +287,19 @@ class LayoutPostprocessor:
special_clusters = self._handle_cross_type_overlaps(special_clusters)
# Calculate page area from known page size
page_area = self.page_size.width * self.page_size.height
if page_area > 0:
# Filter out full-page pictures
special_clusters = [
cluster
for cluster in special_clusters
if not (
cluster.label == DocItemLabel.PICTURE
and cluster.bbox.area() / page_area > 0.90
)
]
for special in special_clusters:
contained = []
for cluster in self.regular_clusters:
@ -313,6 +332,13 @@ class LayoutPostprocessor:
b=max(c.bbox.b for c in contained),
)
# Collect all cells from children
all_cells = []
for child in contained:
all_cells.extend(child.cells)
special.cells = self._deduplicate_cells(all_cells)
special.cells = self._sort_cells(special.cells)
picture_clusters = [
c for c in special_clusters if c.label == DocItemLabel.PICTURE
]
@ -338,7 +364,7 @@ class LayoutPostprocessor:
wrappers_to_remove = set()
for wrapper in special_clusters:
if wrapper.label != DocItemLabel.KEY_VALUE_REGION:
if wrapper.label not in self.WRAPPER_TYPES:
continue # only treat KEY_VALUE_REGION for now.
for regular in self.regular_clusters:
@ -348,8 +374,12 @@ class LayoutPostprocessor:
wrapper_area = wrapper.bbox.area()
overlap_ratio = overlap / wrapper_area
conf_diff = wrapper.confidence - regular.confidence
# If wrapper is mostly overlapping with a TABLE, remove the wrapper
if overlap_ratio > 0.8: # 80% overlap threshold
if (
overlap_ratio > 0.9 and conf_diff < 0.1
): # self.OVERLAP_PARAMS["wrapper"]["conf_threshold"]): # 80% overlap threshold
wrappers_to_remove.add(wrapper.id)
break