Implement hierarchical cluster layout processing

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
This commit is contained in:
Christoph Auer 2024-12-03 10:28:36 +01:00
parent e0cf80a919
commit 7245cc6080
2 changed files with 180 additions and 136 deletions

View File

@ -97,7 +97,9 @@ class LayoutModel(BasePageModel):
# Function to draw clusters on an image # Function to draw clusters on an image
def draw_clusters(image, clusters): def draw_clusters(image, clusters):
draw = ImageDraw.Draw(image, "RGBA") draw = ImageDraw.Draw(image, "RGBA")
for c in clusters: for c_tl in clusters:
all_clusters = [c_tl, *c_tl.children]
for c in all_clusters:
cell_color = (0, 0, 0, 40) # Transparent black for cells cell_color = (0, 0, 0, 40) # Transparent black for cells
for tc in c.cells: for tc in c.cells:
cx0, cy0, cx1, cy1 = tc.bbox.as_tuple() cx0, cy0, cx1, cy1 = tc.bbox.as_tuple()

View File

@ -1,5 +1,6 @@
import bisect
import logging import logging
from bisect import bisect_left from collections import defaultdict
from typing import Dict, List, Set, Tuple from typing import Dict, List, Set, Tuple
from docling_core.types.doc import DocItemLabel from docling_core.types.doc import DocItemLabel
@ -11,7 +12,7 @@ _log = logging.getLogger(__name__)
class UnionFind: class UnionFind:
"""Union-Find (Disjoint-Set) data structure.""" """Efficient Union-Find data structure for grouping elements."""
def __init__(self, elements): def __init__(self, elements):
self.parent = {elem: elem for elem in elements} self.parent = {elem: elem for elem in elements}
@ -23,9 +24,10 @@ class UnionFind:
return self.parent[x] return self.parent[x]
def union(self, x, y): def union(self, x, y):
root_x = self.find(x) root_x, root_y = self.find(x), self.find(y)
root_y = self.find(y) if root_x == root_y:
if root_x != root_y: return
if self.rank[root_x] > self.rank[root_y]: if self.rank[root_x] > self.rank[root_y]:
self.parent[root_y] = root_x self.parent[root_y] = root_x
elif self.rank[root_x] < self.rank[root_y]: elif self.rank[root_x] < self.rank[root_y]:
@ -34,62 +36,49 @@ class UnionFind:
self.parent[root_y] = root_x self.parent[root_y] = root_x
self.rank[root_x] += 1 self.rank[root_x] += 1
def groups(self): def get_groups(self) -> Dict[int, List[int]]:
"""Return groups as {root: [elements]}.""" """Returns groups as {root: [elements]}."""
from collections import defaultdict groups = defaultdict(list)
components = defaultdict(list)
for elem in self.parent: for elem in self.parent:
root = self.find(elem) groups[self.find(elem)].append(elem)
components[root].append(elem) return groups
return components
class SpatialClusterIndex: class SpatialClusterIndex:
"""Helper class to manage spatial indexes for clusters and find overlaps efficiently.""" """Efficient spatial indexing for clusters using R-tree and interval trees."""
def __init__(self, clusters: List[Cluster]): def __init__(self, clusters: List[Cluster]):
# Create spatial index
p = index.Property() p = index.Property()
p.dimension = 2 p.dimension = 2
self.spatial_index = index.Index(properties=p) self.spatial_index = index.Index(properties=p)
# Initialize interval trees
self.x_intervals = IntervalTree() self.x_intervals = IntervalTree()
self.y_intervals = IntervalTree() self.y_intervals = IntervalTree()
self.clusters_by_id: Dict[int, Cluster] = {}
# Map to store clusters by ID
self.clusters_by_id = {} # type: ignore
# Populate indexes and maps
for cluster in clusters: for cluster in clusters:
self.add_cluster(cluster) self.add_cluster(cluster)
def add_cluster(self, cluster: Cluster): def add_cluster(self, cluster: Cluster):
"""Add a cluster to all indexes.""" bbox = cluster.bbox
self.spatial_index.insert(cluster.id, cluster.bbox.as_tuple()) self.spatial_index.insert(cluster.id, bbox.as_tuple())
self.x_intervals.insert(cluster.bbox.l, cluster.bbox.r, cluster.id) self.x_intervals.insert(bbox.l, bbox.r, cluster.id)
self.y_intervals.insert(cluster.bbox.t, cluster.bbox.b, cluster.id) self.y_intervals.insert(bbox.t, bbox.b, cluster.id)
self.clusters_by_id[cluster.id] = cluster self.clusters_by_id[cluster.id] = cluster
def remove_cluster(self, cluster: Cluster): def remove_cluster(self, cluster: Cluster):
"""Remove a cluster from all indexes."""
self.spatial_index.delete(cluster.id, cluster.bbox.as_tuple()) self.spatial_index.delete(cluster.id, cluster.bbox.as_tuple())
# Note: IntervalTree doesn't support deletion, but we handle this
# by checking clusters_by_id membership
del self.clusters_by_id[cluster.id] del self.clusters_by_id[cluster.id]
def find_candidates(self, bbox: BoundingBox) -> Set[int]: def find_candidates(self, bbox: BoundingBox) -> Set[int]:
"""Find all potential overlapping cluster IDs using all indexes.""" """Find potential overlapping cluster IDs using all indexes."""
bbox_tuple = bbox.as_tuple() spatial = set(self.spatial_index.intersection(bbox.as_tuple()))
spatial_candidates = set(self.spatial_index.intersection(bbox_tuple))
x_candidates = self.x_intervals.find_containing( x_candidates = self.x_intervals.find_containing(
bbox.l bbox.l
) | self.x_intervals.find_containing(bbox.r) ) | self.x_intervals.find_containing(bbox.r)
y_candidates = self.y_intervals.find_containing( y_candidates = self.y_intervals.find_containing(
bbox.t bbox.t
) | self.y_intervals.find_containing(bbox.b) ) | self.y_intervals.find_containing(bbox.b)
return spatial_candidates | x_candidates | y_candidates return spatial | x_candidates | y_candidates
def check_overlap( def check_overlap(
self, self,
@ -99,8 +88,7 @@ class SpatialClusterIndex:
containment_threshold: float, containment_threshold: float,
) -> bool: ) -> bool:
"""Check if two bboxes overlap sufficiently.""" """Check if two bboxes overlap sufficiently."""
area1 = bbox1.area() area1, area2 = bbox1.area(), bbox2.area()
area2 = bbox2.area()
if area1 <= 0 or area2 <= 0: if area1 <= 0 or area2 <= 0:
return False return False
@ -108,38 +96,47 @@ class SpatialClusterIndex:
if overlap_area <= 0: if overlap_area <= 0:
return False return False
# Check both IoU and containment
iou = overlap_area / (area1 + area2 - overlap_area) iou = overlap_area / (area1 + area2 - overlap_area)
containment_ratio1 = overlap_area / area1 containment1 = overlap_area / area1
containment_ratio2 = overlap_area / area2 containment2 = overlap_area / area2
return ( return (
iou > overlap_threshold iou > overlap_threshold
or containment_ratio1 > containment_threshold or containment1 > containment_threshold
or containment_ratio2 > containment_threshold or containment2 > containment_threshold
) )
class IntervalTree: class IntervalTree:
def __init__(self): """Memory-efficient interval tree for 1D overlap queries."""
self.intervals = [] # List of (min, max, box_id) sorted by min
def insert(self, min_val: float, max_val: float, box_id: int): def __init__(self):
self.intervals.append((min_val, max_val, box_id)) self.intervals: List[Tuple[float, float, int]] = (
self.intervals.sort(key=lambda x: x[0]) []
) # (min, max, id) sorted by min
def insert(self, min_val: float, max_val: float, id: int):
bisect.insort(self.intervals, (min_val, max_val, id), key=lambda x: x[0])
def find_containing(self, point: float) -> Set[int]: def find_containing(self, point: float) -> Set[int]:
pos = bisect_left(self.intervals, (point, float("-inf"), -1)) """Find all intervals containing the point."""
pos = bisect.bisect_left(self.intervals, (point, float("-inf"), -1))
result = set() result = set()
i = pos - 1 # Check intervals starting before point
while i >= 0: for min_val, max_val, id in reversed(self.intervals[:pos]):
min_val, max_val, box_id = self.intervals[i]
if min_val <= point <= max_val: if min_val <= point <= max_val:
result.add(box_id) result.add(id)
else:
break
# Check intervals starting at/after point
for min_val, max_val, id in self.intervals[pos:]:
if point <= max_val:
if min_val <= point:
result.add(id)
else: else:
break break
i -= 1
return result return result
@ -149,7 +146,7 @@ class LayoutPostprocessor:
# Cluster type-specific parameters for overlap resolution # Cluster type-specific parameters for overlap resolution
OVERLAP_PARAMS = { OVERLAP_PARAMS = {
"regular": {"area_threshold": 1.3, "conf_threshold": 0.15}, "regular": {"area_threshold": 1.3, "conf_threshold": 0.05},
"picture": {"area_threshold": 2.0, "conf_threshold": 0.3}, "picture": {"area_threshold": 2.0, "conf_threshold": 0.3},
"wrapper": {"area_threshold": 2.0, "conf_threshold": 0.2}, "wrapper": {"area_threshold": 2.0, "conf_threshold": 0.2},
} }
@ -184,17 +181,42 @@ class LayoutPostprocessor:
def __init__(self, cells: List[Cell], clusters: List[Cluster]): def __init__(self, cells: List[Cell], clusters: List[Cluster]):
"""Initialize processor with cells and clusters.""" """Initialize processor with cells and clusters."""
"""Initialize processor with cells and spatial indices."""
self.cells = cells self.cells = cells
self.regular_clusters = [ self.regular_clusters = [
c for c in clusters if c.label not in self.SPECIAL_TYPES c for c in clusters if c.label not in self.SPECIAL_TYPES
] ]
self.special_clusters = [c for c in clusters if c.label in self.SPECIAL_TYPES] self.special_clusters = [c for c in clusters if c.label in self.SPECIAL_TYPES]
# Build spatial indices once
self.regular_index = SpatialClusterIndex(self.regular_clusters)
self.picture_index = SpatialClusterIndex(
[c for c in self.special_clusters if c.label == DocItemLabel.PICTURE]
)
self.wrapper_index = SpatialClusterIndex(
[c for c in self.special_clusters if c.label in self.WRAPPER_TYPES]
)
def postprocess(self) -> Tuple[List[Cluster], List[Cell]]: def postprocess(self) -> Tuple[List[Cluster], List[Cell]]:
"""Main processing pipeline.""" """Main processing pipeline."""
regular_clusters = self._process_regular_clusters() self.regular_clusters = self._process_regular_clusters()
special_clusters = self._process_special_clusters() self.special_clusters = self._process_special_clusters()
final_clusters = self._sort_clusters(regular_clusters + special_clusters)
# Remove regular clusters that are included in wrappers
contained_ids = {
child.id
for wrapper in self.special_clusters
if wrapper.label in self.SPECIAL_TYPES
for child in wrapper.children
}
self.regular_clusters = [
c for c in self.regular_clusters if c.id not in contained_ids
]
# Combine and sort final clusters
final_clusters = self._sort_clusters(
self.regular_clusters + self.special_clusters
)
return final_clusters, self.cells return final_clusters, self.cells
def _process_regular_clusters(self) -> List[Cluster]: def _process_regular_clusters(self) -> List[Cluster]:
@ -241,49 +263,48 @@ class LayoutPostprocessor:
return clusters return clusters
def _process_special_clusters(self) -> List[Cluster]: def _process_special_clusters(self) -> List[Cluster]:
"""Process special clusters (pictures and wrappers).""" special_clusters = [
# Handle pictures
picture_clusters = [
c c
for c in self.special_clusters for c in self.special_clusters
if c.label == DocItemLabel.PICTURE if c.confidence >= self.CONFIDENCE_THRESHOLDS[c.label]
and c.confidence >= self.CONFIDENCE_THRESHOLDS[c.label]
] ]
picture_clusters = self._remove_overlapping_clusters(
picture_clusters, "picture"
)
# Process wrapper clusters for special in special_clusters:
wrapper_clusters = []
for wrapper in (
c for c in self.special_clusters if c.label in self.WRAPPER_TYPES
):
if wrapper.confidence < self.CONFIDENCE_THRESHOLDS[wrapper.label]:
continue
# Find contained regular clusters
contained = [] contained = []
for cluster in self.regular_clusters: for cluster in self.regular_clusters:
overlap = cluster.bbox.intersection_area_with(wrapper.bbox) overlap = cluster.bbox.intersection_area_with(special.bbox)
if overlap > 0: if overlap > 0:
containment = overlap / cluster.bbox.area() containment = overlap / cluster.bbox.area()
if containment > 0.8: # High containment threshold for wrappers if containment > 0.8:
contained.append(cluster) contained.append(cluster)
if contained: if contained:
wrapper.children = contained special.children = contained
wrapper.bbox = BoundingBox( # Adjust bbox only for wrapper types
if special.label in self.WRAPPER_TYPES:
special.bbox = BoundingBox(
l=min(c.bbox.l for c in contained), l=min(c.bbox.l for c in contained),
t=min(c.bbox.t for c in contained), t=min(c.bbox.t for c in contained),
r=max(c.bbox.r for c in contained), r=max(c.bbox.r for c in contained),
b=max(c.bbox.b for c in contained), b=max(c.bbox.b for c in contained),
) )
wrapper_clusters.append(wrapper)
return picture_clusters + self._remove_overlapping_clusters( picture_clusters = [
c for c in special_clusters if c.label == DocItemLabel.PICTURE
]
picture_clusters = self._remove_overlapping_clusters(
picture_clusters, "picture"
)
wrapper_clusters = [
c for c in special_clusters if c.label in self.WRAPPER_TYPES
]
wrapper_clusters = self._remove_overlapping_clusters(
wrapper_clusters, "wrapper" wrapper_clusters, "wrapper"
) )
return picture_clusters + wrapper_clusters
def _remove_overlapping_clusters( def _remove_overlapping_clusters(
self, self,
clusters: List[Cluster], clusters: List[Cluster],
@ -291,50 +312,71 @@ class LayoutPostprocessor:
overlap_threshold: float = 0.8, overlap_threshold: float = 0.8,
containment_threshold: float = 0.8, containment_threshold: float = 0.8,
) -> List[Cluster]: ) -> List[Cluster]:
"""Remove overlapping clusters using efficient spatial indexing."""
if not clusters: if not clusters:
return [] return []
# Initialize spatial index spatial_index = (
spatial_index = SpatialClusterIndex(clusters) self.regular_index
uf = UnionFind(spatial_index.clusters_by_id.keys()) if cluster_type == "regular"
else self.picture_index if cluster_type == "picture" else self.wrapper_index
)
# Map of currently valid clusters
valid_clusters = {c.id: c for c in clusters}
uf = UnionFind(valid_clusters.keys())
params = self.OVERLAP_PARAMS[cluster_type]
# Group overlapping clusters using spatial index
for cluster in clusters: for cluster in clusters:
candidates = spatial_index.find_candidates(cluster.bbox) candidates = spatial_index.find_candidates(cluster.bbox)
candidates.discard(cluster.id) # Remove self candidates &= valid_clusters.keys() # Only keep existing candidates
candidates.discard(cluster.id)
for other_id in candidates: for other_id in candidates:
if spatial_index.check_overlap( if spatial_index.check_overlap(
cluster.bbox, cluster.bbox,
spatial_index.clusters_by_id[other_id].bbox, valid_clusters[other_id].bbox,
overlap_threshold, overlap_threshold,
containment_threshold, containment_threshold,
): ):
uf.union(cluster.id, other_id) uf.union(cluster.id, other_id)
# Process each group using type-specific parameters
params = self.OVERLAP_PARAMS[cluster_type]
result = [] result = []
for group in uf.get_groups().values():
for group in uf.groups().values():
if len(group) == 1: if len(group) == 1:
result.append(spatial_index.clusters_by_id[group[0]]) result.append(valid_clusters[group[0]])
continue continue
# Get clusters in group group_clusters = [valid_clusters[cid] for cid in group]
group_clusters = [spatial_index.clusters_by_id[cid] for cid in group] current_best = None
# Find best cluster using area and confidence for candidate in group_clusters:
best = self._select_best_cluster( should_select = True
group_clusters, params["area_threshold"], params["conf_threshold"] for other in group_clusters:
) if other == candidate:
continue
# Merge cells from other clusters into best area_ratio = candidate.bbox.area() / other.bbox.area()
conf_diff = other.confidence - candidate.confidence
if (
area_ratio <= params["area_threshold"]
and conf_diff > params["conf_threshold"]
):
should_select = False
break
if should_select:
if current_best is None or (
candidate.bbox.area() > current_best.bbox.area()
and current_best.confidence - candidate.confidence
<= params["conf_threshold"]
):
current_best = candidate
best = current_best if current_best else group_clusters[0]
for cluster in group_clusters: for cluster in group_clusters:
if cluster != best: if cluster != best:
best.cells.extend(cluster.cells) best.cells.extend(cluster.cells)
result.append(best) result.append(best)
return result return result