Many layout processing improvements, add document index type

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
This commit is contained in:
Christoph Auer 2024-12-11 17:08:35 +01:00
parent d094c4990a
commit 57d51ede04
6 changed files with 90 additions and 23 deletions

View File

@ -63,7 +63,7 @@ _log = logging.getLogger(__name__)
layout_label_to_ds_type = { layout_label_to_ds_type = {
DocItemLabel.TITLE: "title", DocItemLabel.TITLE: "title",
DocItemLabel.DOCUMENT_INDEX: "table-of-contents", DocItemLabel.DOCUMENT_INDEX: "table",
DocItemLabel.SECTION_HEADER: "subtitle-level-1", DocItemLabel.SECTION_HEADER: "subtitle-level-1",
DocItemLabel.CHECKBOX_SELECTED: "checkbox-selected", DocItemLabel.CHECKBOX_SELECTED: "checkbox-selected",
DocItemLabel.CHECKBOX_UNSELECTED: "checkbox-unselected", DocItemLabel.CHECKBOX_UNSELECTED: "checkbox-unselected",

View File

@ -7,7 +7,7 @@ from typing import Iterable, List
from docling_core.types.doc import CoordOrigin, DocItemLabel from docling_core.types.doc import CoordOrigin, DocItemLabel
from docling_ibm_models.layoutmodel.layout_predictor import LayoutPredictor from docling_ibm_models.layoutmodel.layout_predictor import LayoutPredictor
from PIL import Image, ImageDraw from PIL import Image, ImageDraw, ImageFont
from docling.datamodel.base_models import ( from docling.datamodel.base_models import (
BoundingBox, BoundingBox,
@ -44,7 +44,7 @@ class LayoutModel(BasePageModel):
] ]
PAGE_HEADER_LABELS = [DocItemLabel.PAGE_HEADER, DocItemLabel.PAGE_FOOTER] PAGE_HEADER_LABELS = [DocItemLabel.PAGE_HEADER, DocItemLabel.PAGE_FOOTER]
TABLE_LABEL = DocItemLabel.TABLE TABLE_LABELS = [DocItemLabel.TABLE, DocItemLabel.DOCUMENT_INDEX]
FIGURE_LABEL = DocItemLabel.PICTURE FIGURE_LABEL = DocItemLabel.PICTURE
FORMULA_LABEL = DocItemLabel.FORMULA FORMULA_LABEL = DocItemLabel.FORMULA
CONTAINER_LABELS = [DocItemLabel.FORM, DocItemLabel.KEY_VALUE_REGION] CONTAINER_LABELS = [DocItemLabel.FORM, DocItemLabel.KEY_VALUE_REGION]
@ -62,6 +62,7 @@ class LayoutModel(BasePageModel):
Draws a page image side by side with clusters filtered into two categories: Draws a page image side by side with clusters filtered into two categories:
- Left: Clusters excluding FORM, KEY_VALUE_REGION, and PICTURE. - Left: Clusters excluding FORM, KEY_VALUE_REGION, and PICTURE.
- Right: Clusters including FORM, KEY_VALUE_REGION, and PICTURE. - Right: Clusters including FORM, KEY_VALUE_REGION, and PICTURE.
Includes label names and confidence scores for each cluster.
""" """
label_to_color = { label_to_color = {
DocItemLabel.TEXT: (255, 255, 153), # Light Yellow DocItemLabel.TEXT: (255, 255, 153), # Light Yellow
@ -103,9 +104,18 @@ class LayoutModel(BasePageModel):
# Function to draw clusters on an image # Function to draw clusters on an image
def draw_clusters(image, clusters): def draw_clusters(image, clusters):
draw = ImageDraw.Draw(image, "RGBA") draw = ImageDraw.Draw(image, "RGBA")
# Create a smaller font for the labels
try:
font = ImageFont.truetype("arial.ttf", 12)
except OSError:
# Fallback to default font if arial is not available
font = ImageFont.load_default()
for c_tl in clusters: for c_tl in clusters:
all_clusters = [c_tl, *c_tl.children] all_clusters = [c_tl, *c_tl.children]
for c in all_clusters: for c in all_clusters:
# Draw cells first (underneath)
cell_color = (0, 0, 0, 40) # Transparent black for cells cell_color = (0, 0, 0, 40) # Transparent black for cells
for tc in c.cells: for tc in c.cells:
cx0, cy0, cx1, cy1 = tc.bbox.as_tuple() cx0, cy0, cx1, cy1 = tc.bbox.as_tuple()
@ -115,21 +125,44 @@ class LayoutModel(BasePageModel):
fill=cell_color, fill=cell_color,
) )
# Draw cluster rectangle
x0, y0, x1, y1 = c.bbox.as_tuple() x0, y0, x1, y1 = c.bbox.as_tuple()
cluster_fill_color = ( cluster_fill_color = (*list(label_to_color.get(c.label)), 70)
*list(label_to_color.get(c.label)), # type: ignore cluster_outline_color = (*list(label_to_color.get(c.label)), 255)
70,
)
cluster_outline_color = (
*list(label_to_color.get(c.label)), # type: ignore
255,
)
draw.rectangle( draw.rectangle(
[(x0, y0), (x1, y1)], [(x0, y0), (x1, y1)],
outline=cluster_outline_color, outline=cluster_outline_color,
fill=cluster_fill_color, fill=cluster_fill_color,
) )
# Add label name and confidence
label_text = f"{c.label.name} ({c.confidence:.2f})"
# Create semi-transparent background for text
text_bbox = draw.textbbox((x0, y0), label_text, font=font)
text_bg_padding = 2
draw.rectangle(
[
(
text_bbox[0] - text_bg_padding,
text_bbox[1] - text_bg_padding,
),
(
text_bbox[2] + text_bg_padding,
text_bbox[3] + text_bg_padding,
),
],
fill=(255, 255, 255, 180), # Semi-transparent white
)
# Draw text
draw.text(
(x0, y0),
label_text,
fill=(0, 0, 0, 255), # Solid black
font=font,
)
# Draw clusters on both images # Draw clusters on both images
draw_clusters(left_image, left_clusters) draw_clusters(left_image, left_clusters)
draw_clusters(right_image, right_clusters) draw_clusters(right_image, right_clusters)
@ -277,8 +310,9 @@ class LayoutModel(BasePageModel):
) )
# Apply postprocessing # Apply postprocessing
processed_clusters, processed_cells = LayoutPostprocessor( processed_clusters, processed_cells = LayoutPostprocessor(
page.cells, clusters page.cells, clusters, page.size
).postprocess() ).postprocess()
# processed_clusters, processed_cells = clusters, page.cells # processed_clusters, processed_cells = clusters, page.cells

View File

@ -95,7 +95,7 @@ class PageAssembleModel(BasePageModel):
headers.append(text_el) headers.append(text_el)
else: else:
body.append(text_el) body.append(text_el)
elif cluster.label == LayoutModel.TABLE_LABEL: elif cluster.label in LayoutModel.TABLE_LABELS:
tbl = None tbl = None
if page.predictions.tablestructure: if page.predictions.tablestructure:
tbl = page.predictions.tablestructure.table_map.get( tbl = page.predictions.tablestructure.table_map.get(

View File

@ -133,7 +133,8 @@ class TableStructureModel(BasePageModel):
], ],
) )
for cluster in page.predictions.layout.clusters for cluster in page.predictions.layout.clusters
if cluster.label == DocItemLabel.TABLE if cluster.label
in [DocItemLabel.TABLE, DocItemLabel.DOCUMENT_INDEX]
] ]
if not len(in_tables): if not len(in_tables):
yield page yield page
@ -198,7 +199,7 @@ class TableStructureModel(BasePageModel):
id=table_cluster.id, id=table_cluster.id,
page_no=page.page_no, page_no=page.page_no,
cluster=table_cluster, cluster=table_cluster,
label=DocItemLabel.TABLE, label=table_cluster.label,
) )
page.predictions.tablestructure.table_map[ page.predictions.tablestructure.table_map[

View File

@ -169,6 +169,8 @@ def to_docling_document(doc_glm, update_name_label=False) -> DoclingDocument:
current_list = None current_list = None
text = "" text = ""
caption_refs = [] caption_refs = []
item_label = DocItemLabel(pelem["name"])
for caption in obj["captions"]: for caption in obj["captions"]:
text += caption["text"] text += caption["text"]
@ -254,7 +256,7 @@ def to_docling_document(doc_glm, update_name_label=False) -> DoclingDocument:
), ),
) )
tbl = doc.add_table(data=tbl_data, prov=prov) tbl = doc.add_table(data=tbl_data, prov=prov, label=item_label)
tbl.captions.extend(caption_refs) tbl.captions.extend(caption_refs)
elif ptype in ["form", "key_value_region"]: elif ptype in ["form", "key_value_region"]:

View File

@ -4,7 +4,7 @@ import sys
from collections import defaultdict from collections import defaultdict
from typing import Dict, List, Set, Tuple from typing import Dict, List, Set, Tuple
from docling_core.types.doc import DocItemLabel from docling_core.types.doc import DocItemLabel, Size
from rtree import index from rtree import index
from docling.datamodel.base_models import BoundingBox, Cell, Cluster from docling.datamodel.base_models import BoundingBox, Cell, Cluster
@ -152,7 +152,12 @@ class LayoutPostprocessor:
"wrapper": {"area_threshold": 2.0, "conf_threshold": 0.2}, "wrapper": {"area_threshold": 2.0, "conf_threshold": 0.2},
} }
WRAPPER_TYPES = {DocItemLabel.FORM, DocItemLabel.KEY_VALUE_REGION} WRAPPER_TYPES = {
DocItemLabel.FORM,
DocItemLabel.KEY_VALUE_REGION,
DocItemLabel.TABLE,
DocItemLabel.DOCUMENT_INDEX,
}
SPECIAL_TYPES = WRAPPER_TYPES | {DocItemLabel.PICTURE} SPECIAL_TYPES = WRAPPER_TYPES | {DocItemLabel.PICTURE}
CONFIDENCE_THRESHOLDS = { CONFIDENCE_THRESHOLDS = {
@ -164,7 +169,7 @@ class LayoutPostprocessor:
DocItemLabel.PAGE_HEADER: 0.5, DocItemLabel.PAGE_HEADER: 0.5,
DocItemLabel.PICTURE: 0.5, DocItemLabel.PICTURE: 0.5,
DocItemLabel.SECTION_HEADER: 0.45, DocItemLabel.SECTION_HEADER: 0.45,
DocItemLabel.TABLE: 0.35, DocItemLabel.TABLE: 0.5,
DocItemLabel.TEXT: 0.55, # 0.45, DocItemLabel.TEXT: 0.55, # 0.45,
DocItemLabel.TITLE: 0.45, DocItemLabel.TITLE: 0.45,
DocItemLabel.CODE: 0.45, DocItemLabel.CODE: 0.45,
@ -176,14 +181,15 @@ class LayoutPostprocessor:
} }
LABEL_REMAPPING = { LABEL_REMAPPING = {
DocItemLabel.DOCUMENT_INDEX: DocItemLabel.TABLE, # DocItemLabel.DOCUMENT_INDEX: DocItemLabel.TABLE,
DocItemLabel.TITLE: DocItemLabel.SECTION_HEADER, DocItemLabel.TITLE: DocItemLabel.SECTION_HEADER,
} }
def __init__(self, cells: List[Cell], clusters: List[Cluster]): def __init__(self, cells: List[Cell], clusters: List[Cluster], page_size: Size):
"""Initialize processor with cells and clusters.""" """Initialize processor with cells and clusters."""
"""Initialize processor with cells and spatial indices.""" """Initialize processor with cells and spatial indices."""
self.cells = cells self.cells = cells
self.page_size = page_size
self.regular_clusters = [ self.regular_clusters = [
c for c in clusters if c.label not in self.SPECIAL_TYPES c for c in clusters if c.label not in self.SPECIAL_TYPES
] ]
@ -281,6 +287,19 @@ class LayoutPostprocessor:
special_clusters = self._handle_cross_type_overlaps(special_clusters) special_clusters = self._handle_cross_type_overlaps(special_clusters)
# Calculate page area from known page size
page_area = self.page_size.width * self.page_size.height
if page_area > 0:
# Filter out full-page pictures
special_clusters = [
cluster
for cluster in special_clusters
if not (
cluster.label == DocItemLabel.PICTURE
and cluster.bbox.area() / page_area > 0.90
)
]
for special in special_clusters: for special in special_clusters:
contained = [] contained = []
for cluster in self.regular_clusters: for cluster in self.regular_clusters:
@ -313,6 +332,13 @@ class LayoutPostprocessor:
b=max(c.bbox.b for c in contained), b=max(c.bbox.b for c in contained),
) )
# Collect all cells from children
all_cells = []
for child in contained:
all_cells.extend(child.cells)
special.cells = self._deduplicate_cells(all_cells)
special.cells = self._sort_cells(special.cells)
picture_clusters = [ picture_clusters = [
c for c in special_clusters if c.label == DocItemLabel.PICTURE c for c in special_clusters if c.label == DocItemLabel.PICTURE
] ]
@ -338,7 +364,7 @@ class LayoutPostprocessor:
wrappers_to_remove = set() wrappers_to_remove = set()
for wrapper in special_clusters: for wrapper in special_clusters:
if wrapper.label != DocItemLabel.KEY_VALUE_REGION: if wrapper.label not in self.WRAPPER_TYPES:
continue # only treat KEY_VALUE_REGION for now. continue # only treat KEY_VALUE_REGION for now.
for regular in self.regular_clusters: for regular in self.regular_clusters:
@ -348,8 +374,12 @@ class LayoutPostprocessor:
wrapper_area = wrapper.bbox.area() wrapper_area = wrapper.bbox.area()
overlap_ratio = overlap / wrapper_area overlap_ratio = overlap / wrapper_area
conf_diff = wrapper.confidence - regular.confidence
# If wrapper is mostly overlapping with a TABLE, remove the wrapper # If wrapper is mostly overlapping with a TABLE, remove the wrapper
if overlap_ratio > 0.8: # 80% overlap threshold if (
overlap_ratio > 0.9 and conf_diff < 0.1
): # self.OVERLAP_PARAMS["wrapper"]["conf_threshold"]): # 80% overlap threshold
wrappers_to_remove.add(wrapper.id) wrappers_to_remove.add(wrapper.id)
break break