mirror of
https://github.com/DS4SD/docling.git
synced 2025-07-27 04:24:45 +00:00
Many layout processing improvements, add document index type
Signed-off-by: Christoph Auer <cau@zurich.ibm.com> Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
This commit is contained in:
parent
d094c4990a
commit
57d51ede04
@ -63,7 +63,7 @@ _log = logging.getLogger(__name__)
|
|||||||
|
|
||||||
layout_label_to_ds_type = {
|
layout_label_to_ds_type = {
|
||||||
DocItemLabel.TITLE: "title",
|
DocItemLabel.TITLE: "title",
|
||||||
DocItemLabel.DOCUMENT_INDEX: "table-of-contents",
|
DocItemLabel.DOCUMENT_INDEX: "table",
|
||||||
DocItemLabel.SECTION_HEADER: "subtitle-level-1",
|
DocItemLabel.SECTION_HEADER: "subtitle-level-1",
|
||||||
DocItemLabel.CHECKBOX_SELECTED: "checkbox-selected",
|
DocItemLabel.CHECKBOX_SELECTED: "checkbox-selected",
|
||||||
DocItemLabel.CHECKBOX_UNSELECTED: "checkbox-unselected",
|
DocItemLabel.CHECKBOX_UNSELECTED: "checkbox-unselected",
|
||||||
|
@ -7,7 +7,7 @@ from typing import Iterable, List
|
|||||||
|
|
||||||
from docling_core.types.doc import CoordOrigin, DocItemLabel
|
from docling_core.types.doc import CoordOrigin, DocItemLabel
|
||||||
from docling_ibm_models.layoutmodel.layout_predictor import LayoutPredictor
|
from docling_ibm_models.layoutmodel.layout_predictor import LayoutPredictor
|
||||||
from PIL import Image, ImageDraw
|
from PIL import Image, ImageDraw, ImageFont
|
||||||
|
|
||||||
from docling.datamodel.base_models import (
|
from docling.datamodel.base_models import (
|
||||||
BoundingBox,
|
BoundingBox,
|
||||||
@ -44,7 +44,7 @@ class LayoutModel(BasePageModel):
|
|||||||
]
|
]
|
||||||
PAGE_HEADER_LABELS = [DocItemLabel.PAGE_HEADER, DocItemLabel.PAGE_FOOTER]
|
PAGE_HEADER_LABELS = [DocItemLabel.PAGE_HEADER, DocItemLabel.PAGE_FOOTER]
|
||||||
|
|
||||||
TABLE_LABEL = DocItemLabel.TABLE
|
TABLE_LABELS = [DocItemLabel.TABLE, DocItemLabel.DOCUMENT_INDEX]
|
||||||
FIGURE_LABEL = DocItemLabel.PICTURE
|
FIGURE_LABEL = DocItemLabel.PICTURE
|
||||||
FORMULA_LABEL = DocItemLabel.FORMULA
|
FORMULA_LABEL = DocItemLabel.FORMULA
|
||||||
CONTAINER_LABELS = [DocItemLabel.FORM, DocItemLabel.KEY_VALUE_REGION]
|
CONTAINER_LABELS = [DocItemLabel.FORM, DocItemLabel.KEY_VALUE_REGION]
|
||||||
@ -62,6 +62,7 @@ class LayoutModel(BasePageModel):
|
|||||||
Draws a page image side by side with clusters filtered into two categories:
|
Draws a page image side by side with clusters filtered into two categories:
|
||||||
- Left: Clusters excluding FORM, KEY_VALUE_REGION, and PICTURE.
|
- Left: Clusters excluding FORM, KEY_VALUE_REGION, and PICTURE.
|
||||||
- Right: Clusters including FORM, KEY_VALUE_REGION, and PICTURE.
|
- Right: Clusters including FORM, KEY_VALUE_REGION, and PICTURE.
|
||||||
|
Includes label names and confidence scores for each cluster.
|
||||||
"""
|
"""
|
||||||
label_to_color = {
|
label_to_color = {
|
||||||
DocItemLabel.TEXT: (255, 255, 153), # Light Yellow
|
DocItemLabel.TEXT: (255, 255, 153), # Light Yellow
|
||||||
@ -103,9 +104,18 @@ class LayoutModel(BasePageModel):
|
|||||||
# Function to draw clusters on an image
|
# Function to draw clusters on an image
|
||||||
def draw_clusters(image, clusters):
|
def draw_clusters(image, clusters):
|
||||||
draw = ImageDraw.Draw(image, "RGBA")
|
draw = ImageDraw.Draw(image, "RGBA")
|
||||||
|
|
||||||
|
# Create a smaller font for the labels
|
||||||
|
try:
|
||||||
|
font = ImageFont.truetype("arial.ttf", 12)
|
||||||
|
except OSError:
|
||||||
|
# Fallback to default font if arial is not available
|
||||||
|
font = ImageFont.load_default()
|
||||||
|
|
||||||
for c_tl in clusters:
|
for c_tl in clusters:
|
||||||
all_clusters = [c_tl, *c_tl.children]
|
all_clusters = [c_tl, *c_tl.children]
|
||||||
for c in all_clusters:
|
for c in all_clusters:
|
||||||
|
# Draw cells first (underneath)
|
||||||
cell_color = (0, 0, 0, 40) # Transparent black for cells
|
cell_color = (0, 0, 0, 40) # Transparent black for cells
|
||||||
for tc in c.cells:
|
for tc in c.cells:
|
||||||
cx0, cy0, cx1, cy1 = tc.bbox.as_tuple()
|
cx0, cy0, cx1, cy1 = tc.bbox.as_tuple()
|
||||||
@ -115,21 +125,44 @@ class LayoutModel(BasePageModel):
|
|||||||
fill=cell_color,
|
fill=cell_color,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Draw cluster rectangle
|
||||||
x0, y0, x1, y1 = c.bbox.as_tuple()
|
x0, y0, x1, y1 = c.bbox.as_tuple()
|
||||||
cluster_fill_color = (
|
cluster_fill_color = (*list(label_to_color.get(c.label)), 70)
|
||||||
*list(label_to_color.get(c.label)), # type: ignore
|
cluster_outline_color = (*list(label_to_color.get(c.label)), 255)
|
||||||
70,
|
|
||||||
)
|
|
||||||
cluster_outline_color = (
|
|
||||||
*list(label_to_color.get(c.label)), # type: ignore
|
|
||||||
255,
|
|
||||||
)
|
|
||||||
draw.rectangle(
|
draw.rectangle(
|
||||||
[(x0, y0), (x1, y1)],
|
[(x0, y0), (x1, y1)],
|
||||||
outline=cluster_outline_color,
|
outline=cluster_outline_color,
|
||||||
fill=cluster_fill_color,
|
fill=cluster_fill_color,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Add label name and confidence
|
||||||
|
label_text = f"{c.label.name} ({c.confidence:.2f})"
|
||||||
|
|
||||||
|
# Create semi-transparent background for text
|
||||||
|
text_bbox = draw.textbbox((x0, y0), label_text, font=font)
|
||||||
|
text_bg_padding = 2
|
||||||
|
draw.rectangle(
|
||||||
|
[
|
||||||
|
(
|
||||||
|
text_bbox[0] - text_bg_padding,
|
||||||
|
text_bbox[1] - text_bg_padding,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
text_bbox[2] + text_bg_padding,
|
||||||
|
text_bbox[3] + text_bg_padding,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
fill=(255, 255, 255, 180), # Semi-transparent white
|
||||||
|
)
|
||||||
|
|
||||||
|
# Draw text
|
||||||
|
draw.text(
|
||||||
|
(x0, y0),
|
||||||
|
label_text,
|
||||||
|
fill=(0, 0, 0, 255), # Solid black
|
||||||
|
font=font,
|
||||||
|
)
|
||||||
|
|
||||||
# Draw clusters on both images
|
# Draw clusters on both images
|
||||||
draw_clusters(left_image, left_clusters)
|
draw_clusters(left_image, left_clusters)
|
||||||
draw_clusters(right_image, right_clusters)
|
draw_clusters(right_image, right_clusters)
|
||||||
@ -277,8 +310,9 @@ class LayoutModel(BasePageModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Apply postprocessing
|
# Apply postprocessing
|
||||||
|
|
||||||
processed_clusters, processed_cells = LayoutPostprocessor(
|
processed_clusters, processed_cells = LayoutPostprocessor(
|
||||||
page.cells, clusters
|
page.cells, clusters, page.size
|
||||||
).postprocess()
|
).postprocess()
|
||||||
# processed_clusters, processed_cells = clusters, page.cells
|
# processed_clusters, processed_cells = clusters, page.cells
|
||||||
|
|
||||||
|
@ -95,7 +95,7 @@ class PageAssembleModel(BasePageModel):
|
|||||||
headers.append(text_el)
|
headers.append(text_el)
|
||||||
else:
|
else:
|
||||||
body.append(text_el)
|
body.append(text_el)
|
||||||
elif cluster.label == LayoutModel.TABLE_LABEL:
|
elif cluster.label in LayoutModel.TABLE_LABELS:
|
||||||
tbl = None
|
tbl = None
|
||||||
if page.predictions.tablestructure:
|
if page.predictions.tablestructure:
|
||||||
tbl = page.predictions.tablestructure.table_map.get(
|
tbl = page.predictions.tablestructure.table_map.get(
|
||||||
|
@ -133,7 +133,8 @@ class TableStructureModel(BasePageModel):
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
for cluster in page.predictions.layout.clusters
|
for cluster in page.predictions.layout.clusters
|
||||||
if cluster.label == DocItemLabel.TABLE
|
if cluster.label
|
||||||
|
in [DocItemLabel.TABLE, DocItemLabel.DOCUMENT_INDEX]
|
||||||
]
|
]
|
||||||
if not len(in_tables):
|
if not len(in_tables):
|
||||||
yield page
|
yield page
|
||||||
@ -198,7 +199,7 @@ class TableStructureModel(BasePageModel):
|
|||||||
id=table_cluster.id,
|
id=table_cluster.id,
|
||||||
page_no=page.page_no,
|
page_no=page.page_no,
|
||||||
cluster=table_cluster,
|
cluster=table_cluster,
|
||||||
label=DocItemLabel.TABLE,
|
label=table_cluster.label,
|
||||||
)
|
)
|
||||||
|
|
||||||
page.predictions.tablestructure.table_map[
|
page.predictions.tablestructure.table_map[
|
||||||
|
@ -169,6 +169,8 @@ def to_docling_document(doc_glm, update_name_label=False) -> DoclingDocument:
|
|||||||
current_list = None
|
current_list = None
|
||||||
text = ""
|
text = ""
|
||||||
caption_refs = []
|
caption_refs = []
|
||||||
|
item_label = DocItemLabel(pelem["name"])
|
||||||
|
|
||||||
for caption in obj["captions"]:
|
for caption in obj["captions"]:
|
||||||
text += caption["text"]
|
text += caption["text"]
|
||||||
|
|
||||||
@ -254,7 +256,7 @@ def to_docling_document(doc_glm, update_name_label=False) -> DoclingDocument:
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
tbl = doc.add_table(data=tbl_data, prov=prov)
|
tbl = doc.add_table(data=tbl_data, prov=prov, label=item_label)
|
||||||
tbl.captions.extend(caption_refs)
|
tbl.captions.extend(caption_refs)
|
||||||
|
|
||||||
elif ptype in ["form", "key_value_region"]:
|
elif ptype in ["form", "key_value_region"]:
|
||||||
|
@ -4,7 +4,7 @@ import sys
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import Dict, List, Set, Tuple
|
from typing import Dict, List, Set, Tuple
|
||||||
|
|
||||||
from docling_core.types.doc import DocItemLabel
|
from docling_core.types.doc import DocItemLabel, Size
|
||||||
from rtree import index
|
from rtree import index
|
||||||
|
|
||||||
from docling.datamodel.base_models import BoundingBox, Cell, Cluster
|
from docling.datamodel.base_models import BoundingBox, Cell, Cluster
|
||||||
@ -152,7 +152,12 @@ class LayoutPostprocessor:
|
|||||||
"wrapper": {"area_threshold": 2.0, "conf_threshold": 0.2},
|
"wrapper": {"area_threshold": 2.0, "conf_threshold": 0.2},
|
||||||
}
|
}
|
||||||
|
|
||||||
WRAPPER_TYPES = {DocItemLabel.FORM, DocItemLabel.KEY_VALUE_REGION}
|
WRAPPER_TYPES = {
|
||||||
|
DocItemLabel.FORM,
|
||||||
|
DocItemLabel.KEY_VALUE_REGION,
|
||||||
|
DocItemLabel.TABLE,
|
||||||
|
DocItemLabel.DOCUMENT_INDEX,
|
||||||
|
}
|
||||||
SPECIAL_TYPES = WRAPPER_TYPES | {DocItemLabel.PICTURE}
|
SPECIAL_TYPES = WRAPPER_TYPES | {DocItemLabel.PICTURE}
|
||||||
|
|
||||||
CONFIDENCE_THRESHOLDS = {
|
CONFIDENCE_THRESHOLDS = {
|
||||||
@ -164,7 +169,7 @@ class LayoutPostprocessor:
|
|||||||
DocItemLabel.PAGE_HEADER: 0.5,
|
DocItemLabel.PAGE_HEADER: 0.5,
|
||||||
DocItemLabel.PICTURE: 0.5,
|
DocItemLabel.PICTURE: 0.5,
|
||||||
DocItemLabel.SECTION_HEADER: 0.45,
|
DocItemLabel.SECTION_HEADER: 0.45,
|
||||||
DocItemLabel.TABLE: 0.35,
|
DocItemLabel.TABLE: 0.5,
|
||||||
DocItemLabel.TEXT: 0.55, # 0.45,
|
DocItemLabel.TEXT: 0.55, # 0.45,
|
||||||
DocItemLabel.TITLE: 0.45,
|
DocItemLabel.TITLE: 0.45,
|
||||||
DocItemLabel.CODE: 0.45,
|
DocItemLabel.CODE: 0.45,
|
||||||
@ -176,14 +181,15 @@ class LayoutPostprocessor:
|
|||||||
}
|
}
|
||||||
|
|
||||||
LABEL_REMAPPING = {
|
LABEL_REMAPPING = {
|
||||||
DocItemLabel.DOCUMENT_INDEX: DocItemLabel.TABLE,
|
# DocItemLabel.DOCUMENT_INDEX: DocItemLabel.TABLE,
|
||||||
DocItemLabel.TITLE: DocItemLabel.SECTION_HEADER,
|
DocItemLabel.TITLE: DocItemLabel.SECTION_HEADER,
|
||||||
}
|
}
|
||||||
|
|
||||||
def __init__(self, cells: List[Cell], clusters: List[Cluster]):
|
def __init__(self, cells: List[Cell], clusters: List[Cluster], page_size: Size):
|
||||||
"""Initialize processor with cells and clusters."""
|
"""Initialize processor with cells and clusters."""
|
||||||
"""Initialize processor with cells and spatial indices."""
|
"""Initialize processor with cells and spatial indices."""
|
||||||
self.cells = cells
|
self.cells = cells
|
||||||
|
self.page_size = page_size
|
||||||
self.regular_clusters = [
|
self.regular_clusters = [
|
||||||
c for c in clusters if c.label not in self.SPECIAL_TYPES
|
c for c in clusters if c.label not in self.SPECIAL_TYPES
|
||||||
]
|
]
|
||||||
@ -281,6 +287,19 @@ class LayoutPostprocessor:
|
|||||||
|
|
||||||
special_clusters = self._handle_cross_type_overlaps(special_clusters)
|
special_clusters = self._handle_cross_type_overlaps(special_clusters)
|
||||||
|
|
||||||
|
# Calculate page area from known page size
|
||||||
|
page_area = self.page_size.width * self.page_size.height
|
||||||
|
if page_area > 0:
|
||||||
|
# Filter out full-page pictures
|
||||||
|
special_clusters = [
|
||||||
|
cluster
|
||||||
|
for cluster in special_clusters
|
||||||
|
if not (
|
||||||
|
cluster.label == DocItemLabel.PICTURE
|
||||||
|
and cluster.bbox.area() / page_area > 0.90
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
for special in special_clusters:
|
for special in special_clusters:
|
||||||
contained = []
|
contained = []
|
||||||
for cluster in self.regular_clusters:
|
for cluster in self.regular_clusters:
|
||||||
@ -313,6 +332,13 @@ class LayoutPostprocessor:
|
|||||||
b=max(c.bbox.b for c in contained),
|
b=max(c.bbox.b for c in contained),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Collect all cells from children
|
||||||
|
all_cells = []
|
||||||
|
for child in contained:
|
||||||
|
all_cells.extend(child.cells)
|
||||||
|
special.cells = self._deduplicate_cells(all_cells)
|
||||||
|
special.cells = self._sort_cells(special.cells)
|
||||||
|
|
||||||
picture_clusters = [
|
picture_clusters = [
|
||||||
c for c in special_clusters if c.label == DocItemLabel.PICTURE
|
c for c in special_clusters if c.label == DocItemLabel.PICTURE
|
||||||
]
|
]
|
||||||
@ -338,7 +364,7 @@ class LayoutPostprocessor:
|
|||||||
wrappers_to_remove = set()
|
wrappers_to_remove = set()
|
||||||
|
|
||||||
for wrapper in special_clusters:
|
for wrapper in special_clusters:
|
||||||
if wrapper.label != DocItemLabel.KEY_VALUE_REGION:
|
if wrapper.label not in self.WRAPPER_TYPES:
|
||||||
continue # only treat KEY_VALUE_REGION for now.
|
continue # only treat KEY_VALUE_REGION for now.
|
||||||
|
|
||||||
for regular in self.regular_clusters:
|
for regular in self.regular_clusters:
|
||||||
@ -348,8 +374,12 @@ class LayoutPostprocessor:
|
|||||||
wrapper_area = wrapper.bbox.area()
|
wrapper_area = wrapper.bbox.area()
|
||||||
overlap_ratio = overlap / wrapper_area
|
overlap_ratio = overlap / wrapper_area
|
||||||
|
|
||||||
|
conf_diff = wrapper.confidence - regular.confidence
|
||||||
|
|
||||||
# If wrapper is mostly overlapping with a TABLE, remove the wrapper
|
# If wrapper is mostly overlapping with a TABLE, remove the wrapper
|
||||||
if overlap_ratio > 0.8: # 80% overlap threshold
|
if (
|
||||||
|
overlap_ratio > 0.9 and conf_diff < 0.1
|
||||||
|
): # self.OVERLAP_PARAMS["wrapper"]["conf_threshold"]): # 80% overlap threshold
|
||||||
wrappers_to_remove.add(wrapper.id)
|
wrappers_to_remove.add(wrapper.id)
|
||||||
break
|
break
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user