diff --git a/docling/models/layout_model.py b/docling/models/layout_model.py index 624f8e02..b0536b6d 100644 --- a/docling/models/layout_model.py +++ b/docling/models/layout_model.py @@ -97,30 +97,32 @@ class LayoutModel(BasePageModel): # Function to draw clusters on an image def draw_clusters(image, clusters): draw = ImageDraw.Draw(image, "RGBA") - for c in clusters: - cell_color = (0, 0, 0, 40) # Transparent black for cells - for tc in c.cells: - cx0, cy0, cx1, cy1 = tc.bbox.as_tuple() - draw.rectangle( - [(cx0, cy0), (cx1, cy1)], - outline=None, - fill=cell_color, - ) + for c_tl in clusters: + all_clusters = [c_tl, *c_tl.children] + for c in all_clusters: + cell_color = (0, 0, 0, 40) # Transparent black for cells + for tc in c.cells: + cx0, cy0, cx1, cy1 = tc.bbox.as_tuple() + draw.rectangle( + [(cx0, cy0), (cx1, cy1)], + outline=None, + fill=cell_color, + ) - x0, y0, x1, y1 = c.bbox.as_tuple() - cluster_fill_color = ( - *list(label_to_color.get(c.label)), # type: ignore - 70, - ) - cluster_outline_color = ( - *list(label_to_color.get(c.label)), # type: ignore - 255, - ) - draw.rectangle( - [(x0, y0), (x1, y1)], - outline=cluster_outline_color, - fill=cluster_fill_color, - ) + x0, y0, x1, y1 = c.bbox.as_tuple() + cluster_fill_color = ( + *list(label_to_color.get(c.label)), # type: ignore + 70, + ) + cluster_outline_color = ( + *list(label_to_color.get(c.label)), # type: ignore + 255, + ) + draw.rectangle( + [(x0, y0), (x1, y1)], + outline=cluster_outline_color, + fill=cluster_fill_color, + ) # Draw clusters on both images draw_clusters(left_image, left_clusters) diff --git a/docling/utils/layout_postprocessor.py b/docling/utils/layout_postprocessor.py index 76d548c6..daa47469 100644 --- a/docling/utils/layout_postprocessor.py +++ b/docling/utils/layout_postprocessor.py @@ -1,5 +1,6 @@ +import bisect import logging -from bisect import bisect_left +from collections import defaultdict from typing import Dict, List, Set, Tuple from docling_core.types.doc import DocItemLabel @@ -11,7 +12,7 @@ _log = logging.getLogger(__name__) class UnionFind: - """Union-Find (Disjoint-Set) data structure.""" + """Efficient Union-Find data structure for grouping elements.""" def __init__(self, elements): self.parent = {elem: elem for elem in elements} @@ -23,73 +24,61 @@ class UnionFind: return self.parent[x] def union(self, x, y): - root_x = self.find(x) - root_y = self.find(y) - if root_x != root_y: - if self.rank[root_x] > self.rank[root_y]: - self.parent[root_y] = root_x - elif self.rank[root_x] < self.rank[root_y]: - self.parent[root_x] = root_y - else: - self.parent[root_y] = root_x - self.rank[root_x] += 1 + root_x, root_y = self.find(x), self.find(y) + if root_x == root_y: + return - def groups(self): - """Return groups as {root: [elements]}.""" - from collections import defaultdict + if self.rank[root_x] > self.rank[root_y]: + self.parent[root_y] = root_x + elif self.rank[root_x] < self.rank[root_y]: + self.parent[root_x] = root_y + else: + self.parent[root_y] = root_x + self.rank[root_x] += 1 - components = defaultdict(list) + def get_groups(self) -> Dict[int, List[int]]: + """Returns groups as {root: [elements]}.""" + groups = defaultdict(list) for elem in self.parent: - root = self.find(elem) - components[root].append(elem) - return components + groups[self.find(elem)].append(elem) + return groups class SpatialClusterIndex: - """Helper class to manage spatial indexes for clusters and find overlaps efficiently.""" + """Efficient spatial indexing for clusters using R-tree and interval trees.""" def __init__(self, clusters: List[Cluster]): - # Create spatial index p = index.Property() p.dimension = 2 self.spatial_index = index.Index(properties=p) - - # Initialize interval trees self.x_intervals = IntervalTree() self.y_intervals = IntervalTree() + self.clusters_by_id: Dict[int, Cluster] = {} - # Map to store clusters by ID - self.clusters_by_id = {} # type: ignore - - # Populate indexes and maps for cluster in clusters: self.add_cluster(cluster) def add_cluster(self, cluster: Cluster): - """Add a cluster to all indexes.""" - self.spatial_index.insert(cluster.id, cluster.bbox.as_tuple()) - self.x_intervals.insert(cluster.bbox.l, cluster.bbox.r, cluster.id) - self.y_intervals.insert(cluster.bbox.t, cluster.bbox.b, cluster.id) + bbox = cluster.bbox + self.spatial_index.insert(cluster.id, bbox.as_tuple()) + self.x_intervals.insert(bbox.l, bbox.r, cluster.id) + self.y_intervals.insert(bbox.t, bbox.b, cluster.id) self.clusters_by_id[cluster.id] = cluster def remove_cluster(self, cluster: Cluster): - """Remove a cluster from all indexes.""" self.spatial_index.delete(cluster.id, cluster.bbox.as_tuple()) - # Note: IntervalTree doesn't support deletion, but we handle this - # by checking clusters_by_id membership del self.clusters_by_id[cluster.id] def find_candidates(self, bbox: BoundingBox) -> Set[int]: - """Find all potential overlapping cluster IDs using all indexes.""" - bbox_tuple = bbox.as_tuple() - spatial_candidates = set(self.spatial_index.intersection(bbox_tuple)) + """Find potential overlapping cluster IDs using all indexes.""" + spatial = set(self.spatial_index.intersection(bbox.as_tuple())) x_candidates = self.x_intervals.find_containing( bbox.l ) | self.x_intervals.find_containing(bbox.r) y_candidates = self.y_intervals.find_containing( bbox.t ) | self.y_intervals.find_containing(bbox.b) - return spatial_candidates | x_candidates | y_candidates + return spatial | x_candidates | y_candidates def check_overlap( self, @@ -99,8 +88,7 @@ class SpatialClusterIndex: containment_threshold: float, ) -> bool: """Check if two bboxes overlap sufficiently.""" - area1 = bbox1.area() - area2 = bbox2.area() + area1, area2 = bbox1.area(), bbox2.area() if area1 <= 0 or area2 <= 0: return False @@ -108,38 +96,47 @@ class SpatialClusterIndex: if overlap_area <= 0: return False - # Check both IoU and containment iou = overlap_area / (area1 + area2 - overlap_area) - containment_ratio1 = overlap_area / area1 - containment_ratio2 = overlap_area / area2 + containment1 = overlap_area / area1 + containment2 = overlap_area / area2 return ( iou > overlap_threshold - or containment_ratio1 > containment_threshold - or containment_ratio2 > containment_threshold + or containment1 > containment_threshold + or containment2 > containment_threshold ) class IntervalTree: - def __init__(self): - self.intervals = [] # List of (min, max, box_id) sorted by min + """Memory-efficient interval tree for 1D overlap queries.""" - def insert(self, min_val: float, max_val: float, box_id: int): - self.intervals.append((min_val, max_val, box_id)) - self.intervals.sort(key=lambda x: x[0]) + def __init__(self): + self.intervals: List[Tuple[float, float, int]] = ( + [] + ) # (min, max, id) sorted by min + + def insert(self, min_val: float, max_val: float, id: int): + bisect.insort(self.intervals, (min_val, max_val, id), key=lambda x: x[0]) def find_containing(self, point: float) -> Set[int]: - pos = bisect_left(self.intervals, (point, float("-inf"), -1)) + """Find all intervals containing the point.""" + pos = bisect.bisect_left(self.intervals, (point, float("-inf"), -1)) result = set() - i = pos - 1 - while i >= 0: - min_val, max_val, box_id = self.intervals[i] + # Check intervals starting before point + for min_val, max_val, id in reversed(self.intervals[:pos]): if min_val <= point <= max_val: - result.add(box_id) + result.add(id) + else: + break + + # Check intervals starting at/after point + for min_val, max_val, id in self.intervals[pos:]: + if point <= max_val: + if min_val <= point: + result.add(id) else: break - i -= 1 return result @@ -149,7 +146,7 @@ class LayoutPostprocessor: # Cluster type-specific parameters for overlap resolution OVERLAP_PARAMS = { - "regular": {"area_threshold": 1.3, "conf_threshold": 0.15}, + "regular": {"area_threshold": 1.3, "conf_threshold": 0.05}, "picture": {"area_threshold": 2.0, "conf_threshold": 0.3}, "wrapper": {"area_threshold": 2.0, "conf_threshold": 0.2}, } @@ -184,17 +181,42 @@ class LayoutPostprocessor: def __init__(self, cells: List[Cell], clusters: List[Cluster]): """Initialize processor with cells and clusters.""" + """Initialize processor with cells and spatial indices.""" self.cells = cells self.regular_clusters = [ c for c in clusters if c.label not in self.SPECIAL_TYPES ] self.special_clusters = [c for c in clusters if c.label in self.SPECIAL_TYPES] + # Build spatial indices once + self.regular_index = SpatialClusterIndex(self.regular_clusters) + self.picture_index = SpatialClusterIndex( + [c for c in self.special_clusters if c.label == DocItemLabel.PICTURE] + ) + self.wrapper_index = SpatialClusterIndex( + [c for c in self.special_clusters if c.label in self.WRAPPER_TYPES] + ) + def postprocess(self) -> Tuple[List[Cluster], List[Cell]]: """Main processing pipeline.""" - regular_clusters = self._process_regular_clusters() - special_clusters = self._process_special_clusters() - final_clusters = self._sort_clusters(regular_clusters + special_clusters) + self.regular_clusters = self._process_regular_clusters() + self.special_clusters = self._process_special_clusters() + + # Remove regular clusters that are included in wrappers + contained_ids = { + child.id + for wrapper in self.special_clusters + if wrapper.label in self.SPECIAL_TYPES + for child in wrapper.children + } + self.regular_clusters = [ + c for c in self.regular_clusters if c.id not in contained_ids + ] + + # Combine and sort final clusters + final_clusters = self._sort_clusters( + self.regular_clusters + self.special_clusters + ) return final_clusters, self.cells def _process_regular_clusters(self) -> List[Cluster]: @@ -241,49 +263,48 @@ class LayoutPostprocessor: return clusters def _process_special_clusters(self) -> List[Cluster]: - """Process special clusters (pictures and wrappers).""" - # Handle pictures - picture_clusters = [ + special_clusters = [ c for c in self.special_clusters - if c.label == DocItemLabel.PICTURE - and c.confidence >= self.CONFIDENCE_THRESHOLDS[c.label] + if c.confidence >= self.CONFIDENCE_THRESHOLDS[c.label] + ] + + for special in special_clusters: + contained = [] + for cluster in self.regular_clusters: + overlap = cluster.bbox.intersection_area_with(special.bbox) + if overlap > 0: + containment = overlap / cluster.bbox.area() + if containment > 0.8: + contained.append(cluster) + + if contained: + special.children = contained + # Adjust bbox only for wrapper types + if special.label in self.WRAPPER_TYPES: + special.bbox = BoundingBox( + l=min(c.bbox.l for c in contained), + t=min(c.bbox.t for c in contained), + r=max(c.bbox.r for c in contained), + b=max(c.bbox.b for c in contained), + ) + + picture_clusters = [ + c for c in special_clusters if c.label == DocItemLabel.PICTURE ] picture_clusters = self._remove_overlapping_clusters( picture_clusters, "picture" ) - # Process wrapper clusters - wrapper_clusters = [] - for wrapper in ( - c for c in self.special_clusters if c.label in self.WRAPPER_TYPES - ): - if wrapper.confidence < self.CONFIDENCE_THRESHOLDS[wrapper.label]: - continue - - # Find contained regular clusters - contained = [] - for cluster in self.regular_clusters: - overlap = cluster.bbox.intersection_area_with(wrapper.bbox) - if overlap > 0: - containment = overlap / cluster.bbox.area() - if containment > 0.8: # High containment threshold for wrappers - contained.append(cluster) - - if contained: - wrapper.children = contained - wrapper.bbox = BoundingBox( - l=min(c.bbox.l for c in contained), - t=min(c.bbox.t for c in contained), - r=max(c.bbox.r for c in contained), - b=max(c.bbox.b for c in contained), - ) - wrapper_clusters.append(wrapper) - - return picture_clusters + self._remove_overlapping_clusters( + wrapper_clusters = [ + c for c in special_clusters if c.label in self.WRAPPER_TYPES + ] + wrapper_clusters = self._remove_overlapping_clusters( wrapper_clusters, "wrapper" ) + return picture_clusters + wrapper_clusters + def _remove_overlapping_clusters( self, clusters: List[Cluster], @@ -291,50 +312,71 @@ class LayoutPostprocessor: overlap_threshold: float = 0.8, containment_threshold: float = 0.8, ) -> List[Cluster]: - """Remove overlapping clusters using efficient spatial indexing.""" if not clusters: return [] - # Initialize spatial index - spatial_index = SpatialClusterIndex(clusters) - uf = UnionFind(spatial_index.clusters_by_id.keys()) + spatial_index = ( + self.regular_index + if cluster_type == "regular" + else self.picture_index if cluster_type == "picture" else self.wrapper_index + ) + + # Map of currently valid clusters + valid_clusters = {c.id: c for c in clusters} + uf = UnionFind(valid_clusters.keys()) + params = self.OVERLAP_PARAMS[cluster_type] - # Group overlapping clusters using spatial index for cluster in clusters: candidates = spatial_index.find_candidates(cluster.bbox) - candidates.discard(cluster.id) # Remove self + candidates &= valid_clusters.keys() # Only keep existing candidates + candidates.discard(cluster.id) for other_id in candidates: if spatial_index.check_overlap( cluster.bbox, - spatial_index.clusters_by_id[other_id].bbox, + valid_clusters[other_id].bbox, overlap_threshold, containment_threshold, ): uf.union(cluster.id, other_id) - # Process each group using type-specific parameters - params = self.OVERLAP_PARAMS[cluster_type] result = [] - - for group in uf.groups().values(): + for group in uf.get_groups().values(): if len(group) == 1: - result.append(spatial_index.clusters_by_id[group[0]]) + result.append(valid_clusters[group[0]]) continue - # Get clusters in group - group_clusters = [spatial_index.clusters_by_id[cid] for cid in group] + group_clusters = [valid_clusters[cid] for cid in group] + current_best = None - # Find best cluster using area and confidence - best = self._select_best_cluster( - group_clusters, params["area_threshold"], params["conf_threshold"] - ) + for candidate in group_clusters: + should_select = True + for other in group_clusters: + if other == candidate: + continue - # Merge cells from other clusters into best + area_ratio = candidate.bbox.area() / other.bbox.area() + conf_diff = other.confidence - candidate.confidence + + if ( + area_ratio <= params["area_threshold"] + and conf_diff > params["conf_threshold"] + ): + should_select = False + break + + if should_select: + if current_best is None or ( + candidate.bbox.area() > current_best.bbox.area() + and current_best.confidence - candidate.confidence + <= params["conf_threshold"] + ): + current_best = candidate + + best = current_best if current_best else group_clusters[0] for cluster in group_clusters: if cluster != best: best.cells.extend(cluster.cells) - result.append(best) return result