mirror of
https://github.com/DS4SD/docling.git
synced 2025-07-27 04:24:45 +00:00
Many layout processing improvements, add document index type
Signed-off-by: Christoph Auer <cau@zurich.ibm.com> Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
This commit is contained in:
parent
d094c4990a
commit
57d51ede04
@ -63,7 +63,7 @@ _log = logging.getLogger(__name__)
|
||||
|
||||
layout_label_to_ds_type = {
|
||||
DocItemLabel.TITLE: "title",
|
||||
DocItemLabel.DOCUMENT_INDEX: "table-of-contents",
|
||||
DocItemLabel.DOCUMENT_INDEX: "table",
|
||||
DocItemLabel.SECTION_HEADER: "subtitle-level-1",
|
||||
DocItemLabel.CHECKBOX_SELECTED: "checkbox-selected",
|
||||
DocItemLabel.CHECKBOX_UNSELECTED: "checkbox-unselected",
|
||||
|
@ -7,7 +7,7 @@ from typing import Iterable, List
|
||||
|
||||
from docling_core.types.doc import CoordOrigin, DocItemLabel
|
||||
from docling_ibm_models.layoutmodel.layout_predictor import LayoutPredictor
|
||||
from PIL import Image, ImageDraw
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
|
||||
from docling.datamodel.base_models import (
|
||||
BoundingBox,
|
||||
@ -44,7 +44,7 @@ class LayoutModel(BasePageModel):
|
||||
]
|
||||
PAGE_HEADER_LABELS = [DocItemLabel.PAGE_HEADER, DocItemLabel.PAGE_FOOTER]
|
||||
|
||||
TABLE_LABEL = DocItemLabel.TABLE
|
||||
TABLE_LABELS = [DocItemLabel.TABLE, DocItemLabel.DOCUMENT_INDEX]
|
||||
FIGURE_LABEL = DocItemLabel.PICTURE
|
||||
FORMULA_LABEL = DocItemLabel.FORMULA
|
||||
CONTAINER_LABELS = [DocItemLabel.FORM, DocItemLabel.KEY_VALUE_REGION]
|
||||
@ -62,6 +62,7 @@ class LayoutModel(BasePageModel):
|
||||
Draws a page image side by side with clusters filtered into two categories:
|
||||
- Left: Clusters excluding FORM, KEY_VALUE_REGION, and PICTURE.
|
||||
- Right: Clusters including FORM, KEY_VALUE_REGION, and PICTURE.
|
||||
Includes label names and confidence scores for each cluster.
|
||||
"""
|
||||
label_to_color = {
|
||||
DocItemLabel.TEXT: (255, 255, 153), # Light Yellow
|
||||
@ -103,9 +104,18 @@ class LayoutModel(BasePageModel):
|
||||
# Function to draw clusters on an image
|
||||
def draw_clusters(image, clusters):
|
||||
draw = ImageDraw.Draw(image, "RGBA")
|
||||
|
||||
# Create a smaller font for the labels
|
||||
try:
|
||||
font = ImageFont.truetype("arial.ttf", 12)
|
||||
except OSError:
|
||||
# Fallback to default font if arial is not available
|
||||
font = ImageFont.load_default()
|
||||
|
||||
for c_tl in clusters:
|
||||
all_clusters = [c_tl, *c_tl.children]
|
||||
for c in all_clusters:
|
||||
# Draw cells first (underneath)
|
||||
cell_color = (0, 0, 0, 40) # Transparent black for cells
|
||||
for tc in c.cells:
|
||||
cx0, cy0, cx1, cy1 = tc.bbox.as_tuple()
|
||||
@ -115,21 +125,44 @@ class LayoutModel(BasePageModel):
|
||||
fill=cell_color,
|
||||
)
|
||||
|
||||
# Draw cluster rectangle
|
||||
x0, y0, x1, y1 = c.bbox.as_tuple()
|
||||
cluster_fill_color = (
|
||||
*list(label_to_color.get(c.label)), # type: ignore
|
||||
70,
|
||||
)
|
||||
cluster_outline_color = (
|
||||
*list(label_to_color.get(c.label)), # type: ignore
|
||||
255,
|
||||
)
|
||||
cluster_fill_color = (*list(label_to_color.get(c.label)), 70)
|
||||
cluster_outline_color = (*list(label_to_color.get(c.label)), 255)
|
||||
draw.rectangle(
|
||||
[(x0, y0), (x1, y1)],
|
||||
outline=cluster_outline_color,
|
||||
fill=cluster_fill_color,
|
||||
)
|
||||
|
||||
# Add label name and confidence
|
||||
label_text = f"{c.label.name} ({c.confidence:.2f})"
|
||||
|
||||
# Create semi-transparent background for text
|
||||
text_bbox = draw.textbbox((x0, y0), label_text, font=font)
|
||||
text_bg_padding = 2
|
||||
draw.rectangle(
|
||||
[
|
||||
(
|
||||
text_bbox[0] - text_bg_padding,
|
||||
text_bbox[1] - text_bg_padding,
|
||||
),
|
||||
(
|
||||
text_bbox[2] + text_bg_padding,
|
||||
text_bbox[3] + text_bg_padding,
|
||||
),
|
||||
],
|
||||
fill=(255, 255, 255, 180), # Semi-transparent white
|
||||
)
|
||||
|
||||
# Draw text
|
||||
draw.text(
|
||||
(x0, y0),
|
||||
label_text,
|
||||
fill=(0, 0, 0, 255), # Solid black
|
||||
font=font,
|
||||
)
|
||||
|
||||
# Draw clusters on both images
|
||||
draw_clusters(left_image, left_clusters)
|
||||
draw_clusters(right_image, right_clusters)
|
||||
@ -277,8 +310,9 @@ class LayoutModel(BasePageModel):
|
||||
)
|
||||
|
||||
# Apply postprocessing
|
||||
|
||||
processed_clusters, processed_cells = LayoutPostprocessor(
|
||||
page.cells, clusters
|
||||
page.cells, clusters, page.size
|
||||
).postprocess()
|
||||
# processed_clusters, processed_cells = clusters, page.cells
|
||||
|
||||
|
@ -95,7 +95,7 @@ class PageAssembleModel(BasePageModel):
|
||||
headers.append(text_el)
|
||||
else:
|
||||
body.append(text_el)
|
||||
elif cluster.label == LayoutModel.TABLE_LABEL:
|
||||
elif cluster.label in LayoutModel.TABLE_LABELS:
|
||||
tbl = None
|
||||
if page.predictions.tablestructure:
|
||||
tbl = page.predictions.tablestructure.table_map.get(
|
||||
|
@ -133,7 +133,8 @@ class TableStructureModel(BasePageModel):
|
||||
],
|
||||
)
|
||||
for cluster in page.predictions.layout.clusters
|
||||
if cluster.label == DocItemLabel.TABLE
|
||||
if cluster.label
|
||||
in [DocItemLabel.TABLE, DocItemLabel.DOCUMENT_INDEX]
|
||||
]
|
||||
if not len(in_tables):
|
||||
yield page
|
||||
@ -198,7 +199,7 @@ class TableStructureModel(BasePageModel):
|
||||
id=table_cluster.id,
|
||||
page_no=page.page_no,
|
||||
cluster=table_cluster,
|
||||
label=DocItemLabel.TABLE,
|
||||
label=table_cluster.label,
|
||||
)
|
||||
|
||||
page.predictions.tablestructure.table_map[
|
||||
|
@ -169,6 +169,8 @@ def to_docling_document(doc_glm, update_name_label=False) -> DoclingDocument:
|
||||
current_list = None
|
||||
text = ""
|
||||
caption_refs = []
|
||||
item_label = DocItemLabel(pelem["name"])
|
||||
|
||||
for caption in obj["captions"]:
|
||||
text += caption["text"]
|
||||
|
||||
@ -254,7 +256,7 @@ def to_docling_document(doc_glm, update_name_label=False) -> DoclingDocument:
|
||||
),
|
||||
)
|
||||
|
||||
tbl = doc.add_table(data=tbl_data, prov=prov)
|
||||
tbl = doc.add_table(data=tbl_data, prov=prov, label=item_label)
|
||||
tbl.captions.extend(caption_refs)
|
||||
|
||||
elif ptype in ["form", "key_value_region"]:
|
||||
|
@ -4,7 +4,7 @@ import sys
|
||||
from collections import defaultdict
|
||||
from typing import Dict, List, Set, Tuple
|
||||
|
||||
from docling_core.types.doc import DocItemLabel
|
||||
from docling_core.types.doc import DocItemLabel, Size
|
||||
from rtree import index
|
||||
|
||||
from docling.datamodel.base_models import BoundingBox, Cell, Cluster
|
||||
@ -152,7 +152,12 @@ class LayoutPostprocessor:
|
||||
"wrapper": {"area_threshold": 2.0, "conf_threshold": 0.2},
|
||||
}
|
||||
|
||||
WRAPPER_TYPES = {DocItemLabel.FORM, DocItemLabel.KEY_VALUE_REGION}
|
||||
WRAPPER_TYPES = {
|
||||
DocItemLabel.FORM,
|
||||
DocItemLabel.KEY_VALUE_REGION,
|
||||
DocItemLabel.TABLE,
|
||||
DocItemLabel.DOCUMENT_INDEX,
|
||||
}
|
||||
SPECIAL_TYPES = WRAPPER_TYPES | {DocItemLabel.PICTURE}
|
||||
|
||||
CONFIDENCE_THRESHOLDS = {
|
||||
@ -164,7 +169,7 @@ class LayoutPostprocessor:
|
||||
DocItemLabel.PAGE_HEADER: 0.5,
|
||||
DocItemLabel.PICTURE: 0.5,
|
||||
DocItemLabel.SECTION_HEADER: 0.45,
|
||||
DocItemLabel.TABLE: 0.35,
|
||||
DocItemLabel.TABLE: 0.5,
|
||||
DocItemLabel.TEXT: 0.55, # 0.45,
|
||||
DocItemLabel.TITLE: 0.45,
|
||||
DocItemLabel.CODE: 0.45,
|
||||
@ -176,14 +181,15 @@ class LayoutPostprocessor:
|
||||
}
|
||||
|
||||
LABEL_REMAPPING = {
|
||||
DocItemLabel.DOCUMENT_INDEX: DocItemLabel.TABLE,
|
||||
# DocItemLabel.DOCUMENT_INDEX: DocItemLabel.TABLE,
|
||||
DocItemLabel.TITLE: DocItemLabel.SECTION_HEADER,
|
||||
}
|
||||
|
||||
def __init__(self, cells: List[Cell], clusters: List[Cluster]):
|
||||
def __init__(self, cells: List[Cell], clusters: List[Cluster], page_size: Size):
|
||||
"""Initialize processor with cells and clusters."""
|
||||
"""Initialize processor with cells and spatial indices."""
|
||||
self.cells = cells
|
||||
self.page_size = page_size
|
||||
self.regular_clusters = [
|
||||
c for c in clusters if c.label not in self.SPECIAL_TYPES
|
||||
]
|
||||
@ -281,6 +287,19 @@ class LayoutPostprocessor:
|
||||
|
||||
special_clusters = self._handle_cross_type_overlaps(special_clusters)
|
||||
|
||||
# Calculate page area from known page size
|
||||
page_area = self.page_size.width * self.page_size.height
|
||||
if page_area > 0:
|
||||
# Filter out full-page pictures
|
||||
special_clusters = [
|
||||
cluster
|
||||
for cluster in special_clusters
|
||||
if not (
|
||||
cluster.label == DocItemLabel.PICTURE
|
||||
and cluster.bbox.area() / page_area > 0.90
|
||||
)
|
||||
]
|
||||
|
||||
for special in special_clusters:
|
||||
contained = []
|
||||
for cluster in self.regular_clusters:
|
||||
@ -313,6 +332,13 @@ class LayoutPostprocessor:
|
||||
b=max(c.bbox.b for c in contained),
|
||||
)
|
||||
|
||||
# Collect all cells from children
|
||||
all_cells = []
|
||||
for child in contained:
|
||||
all_cells.extend(child.cells)
|
||||
special.cells = self._deduplicate_cells(all_cells)
|
||||
special.cells = self._sort_cells(special.cells)
|
||||
|
||||
picture_clusters = [
|
||||
c for c in special_clusters if c.label == DocItemLabel.PICTURE
|
||||
]
|
||||
@ -338,7 +364,7 @@ class LayoutPostprocessor:
|
||||
wrappers_to_remove = set()
|
||||
|
||||
for wrapper in special_clusters:
|
||||
if wrapper.label != DocItemLabel.KEY_VALUE_REGION:
|
||||
if wrapper.label not in self.WRAPPER_TYPES:
|
||||
continue # only treat KEY_VALUE_REGION for now.
|
||||
|
||||
for regular in self.regular_clusters:
|
||||
@ -348,8 +374,12 @@ class LayoutPostprocessor:
|
||||
wrapper_area = wrapper.bbox.area()
|
||||
overlap_ratio = overlap / wrapper_area
|
||||
|
||||
conf_diff = wrapper.confidence - regular.confidence
|
||||
|
||||
# If wrapper is mostly overlapping with a TABLE, remove the wrapper
|
||||
if overlap_ratio > 0.8: # 80% overlap threshold
|
||||
if (
|
||||
overlap_ratio > 0.9 and conf_diff < 0.1
|
||||
): # self.OVERLAP_PARAMS["wrapper"]["conf_threshold"]): # 80% overlap threshold
|
||||
wrappers_to_remove.add(wrapper.id)
|
||||
break
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user