mirror of
https://github.com/DS4SD/docling.git
synced 2025-07-27 12:34:22 +00:00
Implement hierachical cluster layout processing
Signed-off-by: Christoph Auer <cau@zurich.ibm.com> Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
This commit is contained in:
parent
e0cf80a919
commit
7245cc6080
@ -97,30 +97,32 @@ class LayoutModel(BasePageModel):
|
||||
# Function to draw clusters on an image
|
||||
def draw_clusters(image, clusters):
|
||||
draw = ImageDraw.Draw(image, "RGBA")
|
||||
for c in clusters:
|
||||
cell_color = (0, 0, 0, 40) # Transparent black for cells
|
||||
for tc in c.cells:
|
||||
cx0, cy0, cx1, cy1 = tc.bbox.as_tuple()
|
||||
draw.rectangle(
|
||||
[(cx0, cy0), (cx1, cy1)],
|
||||
outline=None,
|
||||
fill=cell_color,
|
||||
)
|
||||
for c_tl in clusters:
|
||||
all_clusters = [c_tl, *c_tl.children]
|
||||
for c in all_clusters:
|
||||
cell_color = (0, 0, 0, 40) # Transparent black for cells
|
||||
for tc in c.cells:
|
||||
cx0, cy0, cx1, cy1 = tc.bbox.as_tuple()
|
||||
draw.rectangle(
|
||||
[(cx0, cy0), (cx1, cy1)],
|
||||
outline=None,
|
||||
fill=cell_color,
|
||||
)
|
||||
|
||||
x0, y0, x1, y1 = c.bbox.as_tuple()
|
||||
cluster_fill_color = (
|
||||
*list(label_to_color.get(c.label)), # type: ignore
|
||||
70,
|
||||
)
|
||||
cluster_outline_color = (
|
||||
*list(label_to_color.get(c.label)), # type: ignore
|
||||
255,
|
||||
)
|
||||
draw.rectangle(
|
||||
[(x0, y0), (x1, y1)],
|
||||
outline=cluster_outline_color,
|
||||
fill=cluster_fill_color,
|
||||
)
|
||||
x0, y0, x1, y1 = c.bbox.as_tuple()
|
||||
cluster_fill_color = (
|
||||
*list(label_to_color.get(c.label)), # type: ignore
|
||||
70,
|
||||
)
|
||||
cluster_outline_color = (
|
||||
*list(label_to_color.get(c.label)), # type: ignore
|
||||
255,
|
||||
)
|
||||
draw.rectangle(
|
||||
[(x0, y0), (x1, y1)],
|
||||
outline=cluster_outline_color,
|
||||
fill=cluster_fill_color,
|
||||
)
|
||||
|
||||
# Draw clusters on both images
|
||||
draw_clusters(left_image, left_clusters)
|
||||
|
@ -1,5 +1,6 @@
|
||||
import bisect
|
||||
import logging
|
||||
from bisect import bisect_left
|
||||
from collections import defaultdict
|
||||
from typing import Dict, List, Set, Tuple
|
||||
|
||||
from docling_core.types.doc import DocItemLabel
|
||||
@ -11,7 +12,7 @@ _log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UnionFind:
|
||||
"""Union-Find (Disjoint-Set) data structure."""
|
||||
"""Efficient Union-Find data structure for grouping elements."""
|
||||
|
||||
def __init__(self, elements):
|
||||
self.parent = {elem: elem for elem in elements}
|
||||
@ -23,73 +24,61 @@ class UnionFind:
|
||||
return self.parent[x]
|
||||
|
||||
def union(self, x, y):
|
||||
root_x = self.find(x)
|
||||
root_y = self.find(y)
|
||||
if root_x != root_y:
|
||||
if self.rank[root_x] > self.rank[root_y]:
|
||||
self.parent[root_y] = root_x
|
||||
elif self.rank[root_x] < self.rank[root_y]:
|
||||
self.parent[root_x] = root_y
|
||||
else:
|
||||
self.parent[root_y] = root_x
|
||||
self.rank[root_x] += 1
|
||||
root_x, root_y = self.find(x), self.find(y)
|
||||
if root_x == root_y:
|
||||
return
|
||||
|
||||
def groups(self):
|
||||
"""Return groups as {root: [elements]}."""
|
||||
from collections import defaultdict
|
||||
if self.rank[root_x] > self.rank[root_y]:
|
||||
self.parent[root_y] = root_x
|
||||
elif self.rank[root_x] < self.rank[root_y]:
|
||||
self.parent[root_x] = root_y
|
||||
else:
|
||||
self.parent[root_y] = root_x
|
||||
self.rank[root_x] += 1
|
||||
|
||||
components = defaultdict(list)
|
||||
def get_groups(self) -> Dict[int, List[int]]:
|
||||
"""Returns groups as {root: [elements]}."""
|
||||
groups = defaultdict(list)
|
||||
for elem in self.parent:
|
||||
root = self.find(elem)
|
||||
components[root].append(elem)
|
||||
return components
|
||||
groups[self.find(elem)].append(elem)
|
||||
return groups
|
||||
|
||||
|
||||
class SpatialClusterIndex:
|
||||
"""Helper class to manage spatial indexes for clusters and find overlaps efficiently."""
|
||||
"""Efficient spatial indexing for clusters using R-tree and interval trees."""
|
||||
|
||||
def __init__(self, clusters: List[Cluster]):
|
||||
# Create spatial index
|
||||
p = index.Property()
|
||||
p.dimension = 2
|
||||
self.spatial_index = index.Index(properties=p)
|
||||
|
||||
# Initialize interval trees
|
||||
self.x_intervals = IntervalTree()
|
||||
self.y_intervals = IntervalTree()
|
||||
self.clusters_by_id: Dict[int, Cluster] = {}
|
||||
|
||||
# Map to store clusters by ID
|
||||
self.clusters_by_id = {} # type: ignore
|
||||
|
||||
# Populate indexes and maps
|
||||
for cluster in clusters:
|
||||
self.add_cluster(cluster)
|
||||
|
||||
def add_cluster(self, cluster: Cluster):
|
||||
"""Add a cluster to all indexes."""
|
||||
self.spatial_index.insert(cluster.id, cluster.bbox.as_tuple())
|
||||
self.x_intervals.insert(cluster.bbox.l, cluster.bbox.r, cluster.id)
|
||||
self.y_intervals.insert(cluster.bbox.t, cluster.bbox.b, cluster.id)
|
||||
bbox = cluster.bbox
|
||||
self.spatial_index.insert(cluster.id, bbox.as_tuple())
|
||||
self.x_intervals.insert(bbox.l, bbox.r, cluster.id)
|
||||
self.y_intervals.insert(bbox.t, bbox.b, cluster.id)
|
||||
self.clusters_by_id[cluster.id] = cluster
|
||||
|
||||
def remove_cluster(self, cluster: Cluster):
|
||||
"""Remove a cluster from all indexes."""
|
||||
self.spatial_index.delete(cluster.id, cluster.bbox.as_tuple())
|
||||
# Note: IntervalTree doesn't support deletion, but we handle this
|
||||
# by checking clusters_by_id membership
|
||||
del self.clusters_by_id[cluster.id]
|
||||
|
||||
def find_candidates(self, bbox: BoundingBox) -> Set[int]:
|
||||
"""Find all potential overlapping cluster IDs using all indexes."""
|
||||
bbox_tuple = bbox.as_tuple()
|
||||
spatial_candidates = set(self.spatial_index.intersection(bbox_tuple))
|
||||
"""Find potential overlapping cluster IDs using all indexes."""
|
||||
spatial = set(self.spatial_index.intersection(bbox.as_tuple()))
|
||||
x_candidates = self.x_intervals.find_containing(
|
||||
bbox.l
|
||||
) | self.x_intervals.find_containing(bbox.r)
|
||||
y_candidates = self.y_intervals.find_containing(
|
||||
bbox.t
|
||||
) | self.y_intervals.find_containing(bbox.b)
|
||||
return spatial_candidates | x_candidates | y_candidates
|
||||
return spatial | x_candidates | y_candidates
|
||||
|
||||
def check_overlap(
|
||||
self,
|
||||
@ -99,8 +88,7 @@ class SpatialClusterIndex:
|
||||
containment_threshold: float,
|
||||
) -> bool:
|
||||
"""Check if two bboxes overlap sufficiently."""
|
||||
area1 = bbox1.area()
|
||||
area2 = bbox2.area()
|
||||
area1, area2 = bbox1.area(), bbox2.area()
|
||||
if area1 <= 0 or area2 <= 0:
|
||||
return False
|
||||
|
||||
@ -108,38 +96,47 @@ class SpatialClusterIndex:
|
||||
if overlap_area <= 0:
|
||||
return False
|
||||
|
||||
# Check both IoU and containment
|
||||
iou = overlap_area / (area1 + area2 - overlap_area)
|
||||
containment_ratio1 = overlap_area / area1
|
||||
containment_ratio2 = overlap_area / area2
|
||||
containment1 = overlap_area / area1
|
||||
containment2 = overlap_area / area2
|
||||
|
||||
return (
|
||||
iou > overlap_threshold
|
||||
or containment_ratio1 > containment_threshold
|
||||
or containment_ratio2 > containment_threshold
|
||||
or containment1 > containment_threshold
|
||||
or containment2 > containment_threshold
|
||||
)
|
||||
|
||||
|
||||
class IntervalTree:
|
||||
def __init__(self):
|
||||
self.intervals = [] # List of (min, max, box_id) sorted by min
|
||||
"""Memory-efficient interval tree for 1D overlap queries."""
|
||||
|
||||
def insert(self, min_val: float, max_val: float, box_id: int):
|
||||
self.intervals.append((min_val, max_val, box_id))
|
||||
self.intervals.sort(key=lambda x: x[0])
|
||||
def __init__(self):
|
||||
self.intervals: List[Tuple[float, float, int]] = (
|
||||
[]
|
||||
) # (min, max, id) sorted by min
|
||||
|
||||
def insert(self, min_val: float, max_val: float, id: int):
|
||||
bisect.insort(self.intervals, (min_val, max_val, id), key=lambda x: x[0])
|
||||
|
||||
def find_containing(self, point: float) -> Set[int]:
|
||||
pos = bisect_left(self.intervals, (point, float("-inf"), -1))
|
||||
"""Find all intervals containing the point."""
|
||||
pos = bisect.bisect_left(self.intervals, (point, float("-inf"), -1))
|
||||
result = set()
|
||||
|
||||
i = pos - 1
|
||||
while i >= 0:
|
||||
min_val, max_val, box_id = self.intervals[i]
|
||||
# Check intervals starting before point
|
||||
for min_val, max_val, id in reversed(self.intervals[:pos]):
|
||||
if min_val <= point <= max_val:
|
||||
result.add(box_id)
|
||||
result.add(id)
|
||||
else:
|
||||
break
|
||||
|
||||
# Check intervals starting at/after point
|
||||
for min_val, max_val, id in self.intervals[pos:]:
|
||||
if point <= max_val:
|
||||
if min_val <= point:
|
||||
result.add(id)
|
||||
else:
|
||||
break
|
||||
i -= 1
|
||||
|
||||
return result
|
||||
|
||||
@ -149,7 +146,7 @@ class LayoutPostprocessor:
|
||||
|
||||
# Cluster type-specific parameters for overlap resolution
|
||||
OVERLAP_PARAMS = {
|
||||
"regular": {"area_threshold": 1.3, "conf_threshold": 0.15},
|
||||
"regular": {"area_threshold": 1.3, "conf_threshold": 0.05},
|
||||
"picture": {"area_threshold": 2.0, "conf_threshold": 0.3},
|
||||
"wrapper": {"area_threshold": 2.0, "conf_threshold": 0.2},
|
||||
}
|
||||
@ -184,17 +181,42 @@ class LayoutPostprocessor:
|
||||
|
||||
def __init__(self, cells: List[Cell], clusters: List[Cluster]):
|
||||
"""Initialize processor with cells and clusters."""
|
||||
"""Initialize processor with cells and spatial indices."""
|
||||
self.cells = cells
|
||||
self.regular_clusters = [
|
||||
c for c in clusters if c.label not in self.SPECIAL_TYPES
|
||||
]
|
||||
self.special_clusters = [c for c in clusters if c.label in self.SPECIAL_TYPES]
|
||||
|
||||
# Build spatial indices once
|
||||
self.regular_index = SpatialClusterIndex(self.regular_clusters)
|
||||
self.picture_index = SpatialClusterIndex(
|
||||
[c for c in self.special_clusters if c.label == DocItemLabel.PICTURE]
|
||||
)
|
||||
self.wrapper_index = SpatialClusterIndex(
|
||||
[c for c in self.special_clusters if c.label in self.WRAPPER_TYPES]
|
||||
)
|
||||
|
||||
def postprocess(self) -> Tuple[List[Cluster], List[Cell]]:
|
||||
"""Main processing pipeline."""
|
||||
regular_clusters = self._process_regular_clusters()
|
||||
special_clusters = self._process_special_clusters()
|
||||
final_clusters = self._sort_clusters(regular_clusters + special_clusters)
|
||||
self.regular_clusters = self._process_regular_clusters()
|
||||
self.special_clusters = self._process_special_clusters()
|
||||
|
||||
# Remove regular clusters that are included in wrappers
|
||||
contained_ids = {
|
||||
child.id
|
||||
for wrapper in self.special_clusters
|
||||
if wrapper.label in self.SPECIAL_TYPES
|
||||
for child in wrapper.children
|
||||
}
|
||||
self.regular_clusters = [
|
||||
c for c in self.regular_clusters if c.id not in contained_ids
|
||||
]
|
||||
|
||||
# Combine and sort final clusters
|
||||
final_clusters = self._sort_clusters(
|
||||
self.regular_clusters + self.special_clusters
|
||||
)
|
||||
return final_clusters, self.cells
|
||||
|
||||
def _process_regular_clusters(self) -> List[Cluster]:
|
||||
@ -241,49 +263,48 @@ class LayoutPostprocessor:
|
||||
return clusters
|
||||
|
||||
def _process_special_clusters(self) -> List[Cluster]:
|
||||
"""Process special clusters (pictures and wrappers)."""
|
||||
# Handle pictures
|
||||
picture_clusters = [
|
||||
special_clusters = [
|
||||
c
|
||||
for c in self.special_clusters
|
||||
if c.label == DocItemLabel.PICTURE
|
||||
and c.confidence >= self.CONFIDENCE_THRESHOLDS[c.label]
|
||||
if c.confidence >= self.CONFIDENCE_THRESHOLDS[c.label]
|
||||
]
|
||||
|
||||
for special in special_clusters:
|
||||
contained = []
|
||||
for cluster in self.regular_clusters:
|
||||
overlap = cluster.bbox.intersection_area_with(special.bbox)
|
||||
if overlap > 0:
|
||||
containment = overlap / cluster.bbox.area()
|
||||
if containment > 0.8:
|
||||
contained.append(cluster)
|
||||
|
||||
if contained:
|
||||
special.children = contained
|
||||
# Adjust bbox only for wrapper types
|
||||
if special.label in self.WRAPPER_TYPES:
|
||||
special.bbox = BoundingBox(
|
||||
l=min(c.bbox.l for c in contained),
|
||||
t=min(c.bbox.t for c in contained),
|
||||
r=max(c.bbox.r for c in contained),
|
||||
b=max(c.bbox.b for c in contained),
|
||||
)
|
||||
|
||||
picture_clusters = [
|
||||
c for c in special_clusters if c.label == DocItemLabel.PICTURE
|
||||
]
|
||||
picture_clusters = self._remove_overlapping_clusters(
|
||||
picture_clusters, "picture"
|
||||
)
|
||||
|
||||
# Process wrapper clusters
|
||||
wrapper_clusters = []
|
||||
for wrapper in (
|
||||
c for c in self.special_clusters if c.label in self.WRAPPER_TYPES
|
||||
):
|
||||
if wrapper.confidence < self.CONFIDENCE_THRESHOLDS[wrapper.label]:
|
||||
continue
|
||||
|
||||
# Find contained regular clusters
|
||||
contained = []
|
||||
for cluster in self.regular_clusters:
|
||||
overlap = cluster.bbox.intersection_area_with(wrapper.bbox)
|
||||
if overlap > 0:
|
||||
containment = overlap / cluster.bbox.area()
|
||||
if containment > 0.8: # High containment threshold for wrappers
|
||||
contained.append(cluster)
|
||||
|
||||
if contained:
|
||||
wrapper.children = contained
|
||||
wrapper.bbox = BoundingBox(
|
||||
l=min(c.bbox.l for c in contained),
|
||||
t=min(c.bbox.t for c in contained),
|
||||
r=max(c.bbox.r for c in contained),
|
||||
b=max(c.bbox.b for c in contained),
|
||||
)
|
||||
wrapper_clusters.append(wrapper)
|
||||
|
||||
return picture_clusters + self._remove_overlapping_clusters(
|
||||
wrapper_clusters = [
|
||||
c for c in special_clusters if c.label in self.WRAPPER_TYPES
|
||||
]
|
||||
wrapper_clusters = self._remove_overlapping_clusters(
|
||||
wrapper_clusters, "wrapper"
|
||||
)
|
||||
|
||||
return picture_clusters + wrapper_clusters
|
||||
|
||||
def _remove_overlapping_clusters(
|
||||
self,
|
||||
clusters: List[Cluster],
|
||||
@ -291,50 +312,71 @@ class LayoutPostprocessor:
|
||||
overlap_threshold: float = 0.8,
|
||||
containment_threshold: float = 0.8,
|
||||
) -> List[Cluster]:
|
||||
"""Remove overlapping clusters using efficient spatial indexing."""
|
||||
if not clusters:
|
||||
return []
|
||||
|
||||
# Initialize spatial index
|
||||
spatial_index = SpatialClusterIndex(clusters)
|
||||
uf = UnionFind(spatial_index.clusters_by_id.keys())
|
||||
spatial_index = (
|
||||
self.regular_index
|
||||
if cluster_type == "regular"
|
||||
else self.picture_index if cluster_type == "picture" else self.wrapper_index
|
||||
)
|
||||
|
||||
# Map of currently valid clusters
|
||||
valid_clusters = {c.id: c for c in clusters}
|
||||
uf = UnionFind(valid_clusters.keys())
|
||||
params = self.OVERLAP_PARAMS[cluster_type]
|
||||
|
||||
# Group overlapping clusters using spatial index
|
||||
for cluster in clusters:
|
||||
candidates = spatial_index.find_candidates(cluster.bbox)
|
||||
candidates.discard(cluster.id) # Remove self
|
||||
candidates &= valid_clusters.keys() # Only keep existing candidates
|
||||
candidates.discard(cluster.id)
|
||||
|
||||
for other_id in candidates:
|
||||
if spatial_index.check_overlap(
|
||||
cluster.bbox,
|
||||
spatial_index.clusters_by_id[other_id].bbox,
|
||||
valid_clusters[other_id].bbox,
|
||||
overlap_threshold,
|
||||
containment_threshold,
|
||||
):
|
||||
uf.union(cluster.id, other_id)
|
||||
|
||||
# Process each group using type-specific parameters
|
||||
params = self.OVERLAP_PARAMS[cluster_type]
|
||||
result = []
|
||||
|
||||
for group in uf.groups().values():
|
||||
for group in uf.get_groups().values():
|
||||
if len(group) == 1:
|
||||
result.append(spatial_index.clusters_by_id[group[0]])
|
||||
result.append(valid_clusters[group[0]])
|
||||
continue
|
||||
|
||||
# Get clusters in group
|
||||
group_clusters = [spatial_index.clusters_by_id[cid] for cid in group]
|
||||
group_clusters = [valid_clusters[cid] for cid in group]
|
||||
current_best = None
|
||||
|
||||
# Find best cluster using area and confidence
|
||||
best = self._select_best_cluster(
|
||||
group_clusters, params["area_threshold"], params["conf_threshold"]
|
||||
)
|
||||
for candidate in group_clusters:
|
||||
should_select = True
|
||||
for other in group_clusters:
|
||||
if other == candidate:
|
||||
continue
|
||||
|
||||
# Merge cells from other clusters into best
|
||||
area_ratio = candidate.bbox.area() / other.bbox.area()
|
||||
conf_diff = other.confidence - candidate.confidence
|
||||
|
||||
if (
|
||||
area_ratio <= params["area_threshold"]
|
||||
and conf_diff > params["conf_threshold"]
|
||||
):
|
||||
should_select = False
|
||||
break
|
||||
|
||||
if should_select:
|
||||
if current_best is None or (
|
||||
candidate.bbox.area() > current_best.bbox.area()
|
||||
and current_best.confidence - candidate.confidence
|
||||
<= params["conf_threshold"]
|
||||
):
|
||||
current_best = candidate
|
||||
|
||||
best = current_best if current_best else group_clusters[0]
|
||||
for cluster in group_clusters:
|
||||
if cluster != best:
|
||||
best.cells.extend(cluster.cells)
|
||||
|
||||
result.append(best)
|
||||
|
||||
return result
|
||||
|
Loading…
Reference in New Issue
Block a user