mirror of
https://github.com/DS4SD/docling.git
synced 2025-07-27 12:34:22 +00:00
Implement hierachical cluster layout processing
Signed-off-by: Christoph Auer <cau@zurich.ibm.com> Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
This commit is contained in:
parent
e0cf80a919
commit
7245cc6080
@ -97,30 +97,32 @@ class LayoutModel(BasePageModel):
|
|||||||
# Function to draw clusters on an image
|
# Function to draw clusters on an image
|
||||||
def draw_clusters(image, clusters):
|
def draw_clusters(image, clusters):
|
||||||
draw = ImageDraw.Draw(image, "RGBA")
|
draw = ImageDraw.Draw(image, "RGBA")
|
||||||
for c in clusters:
|
for c_tl in clusters:
|
||||||
cell_color = (0, 0, 0, 40) # Transparent black for cells
|
all_clusters = [c_tl, *c_tl.children]
|
||||||
for tc in c.cells:
|
for c in all_clusters:
|
||||||
cx0, cy0, cx1, cy1 = tc.bbox.as_tuple()
|
cell_color = (0, 0, 0, 40) # Transparent black for cells
|
||||||
draw.rectangle(
|
for tc in c.cells:
|
||||||
[(cx0, cy0), (cx1, cy1)],
|
cx0, cy0, cx1, cy1 = tc.bbox.as_tuple()
|
||||||
outline=None,
|
draw.rectangle(
|
||||||
fill=cell_color,
|
[(cx0, cy0), (cx1, cy1)],
|
||||||
)
|
outline=None,
|
||||||
|
fill=cell_color,
|
||||||
|
)
|
||||||
|
|
||||||
x0, y0, x1, y1 = c.bbox.as_tuple()
|
x0, y0, x1, y1 = c.bbox.as_tuple()
|
||||||
cluster_fill_color = (
|
cluster_fill_color = (
|
||||||
*list(label_to_color.get(c.label)), # type: ignore
|
*list(label_to_color.get(c.label)), # type: ignore
|
||||||
70,
|
70,
|
||||||
)
|
)
|
||||||
cluster_outline_color = (
|
cluster_outline_color = (
|
||||||
*list(label_to_color.get(c.label)), # type: ignore
|
*list(label_to_color.get(c.label)), # type: ignore
|
||||||
255,
|
255,
|
||||||
)
|
)
|
||||||
draw.rectangle(
|
draw.rectangle(
|
||||||
[(x0, y0), (x1, y1)],
|
[(x0, y0), (x1, y1)],
|
||||||
outline=cluster_outline_color,
|
outline=cluster_outline_color,
|
||||||
fill=cluster_fill_color,
|
fill=cluster_fill_color,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Draw clusters on both images
|
# Draw clusters on both images
|
||||||
draw_clusters(left_image, left_clusters)
|
draw_clusters(left_image, left_clusters)
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
|
import bisect
|
||||||
import logging
|
import logging
|
||||||
from bisect import bisect_left
|
from collections import defaultdict
|
||||||
from typing import Dict, List, Set, Tuple
|
from typing import Dict, List, Set, Tuple
|
||||||
|
|
||||||
from docling_core.types.doc import DocItemLabel
|
from docling_core.types.doc import DocItemLabel
|
||||||
@ -11,7 +12,7 @@ _log = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class UnionFind:
|
class UnionFind:
|
||||||
"""Union-Find (Disjoint-Set) data structure."""
|
"""Efficient Union-Find data structure for grouping elements."""
|
||||||
|
|
||||||
def __init__(self, elements):
|
def __init__(self, elements):
|
||||||
self.parent = {elem: elem for elem in elements}
|
self.parent = {elem: elem for elem in elements}
|
||||||
@ -23,73 +24,61 @@ class UnionFind:
|
|||||||
return self.parent[x]
|
return self.parent[x]
|
||||||
|
|
||||||
def union(self, x, y):
|
def union(self, x, y):
|
||||||
root_x = self.find(x)
|
root_x, root_y = self.find(x), self.find(y)
|
||||||
root_y = self.find(y)
|
if root_x == root_y:
|
||||||
if root_x != root_y:
|
return
|
||||||
if self.rank[root_x] > self.rank[root_y]:
|
|
||||||
self.parent[root_y] = root_x
|
|
||||||
elif self.rank[root_x] < self.rank[root_y]:
|
|
||||||
self.parent[root_x] = root_y
|
|
||||||
else:
|
|
||||||
self.parent[root_y] = root_x
|
|
||||||
self.rank[root_x] += 1
|
|
||||||
|
|
||||||
def groups(self):
|
if self.rank[root_x] > self.rank[root_y]:
|
||||||
"""Return groups as {root: [elements]}."""
|
self.parent[root_y] = root_x
|
||||||
from collections import defaultdict
|
elif self.rank[root_x] < self.rank[root_y]:
|
||||||
|
self.parent[root_x] = root_y
|
||||||
|
else:
|
||||||
|
self.parent[root_y] = root_x
|
||||||
|
self.rank[root_x] += 1
|
||||||
|
|
||||||
components = defaultdict(list)
|
def get_groups(self) -> Dict[int, List[int]]:
|
||||||
|
"""Returns groups as {root: [elements]}."""
|
||||||
|
groups = defaultdict(list)
|
||||||
for elem in self.parent:
|
for elem in self.parent:
|
||||||
root = self.find(elem)
|
groups[self.find(elem)].append(elem)
|
||||||
components[root].append(elem)
|
return groups
|
||||||
return components
|
|
||||||
|
|
||||||
|
|
||||||
class SpatialClusterIndex:
|
class SpatialClusterIndex:
|
||||||
"""Helper class to manage spatial indexes for clusters and find overlaps efficiently."""
|
"""Efficient spatial indexing for clusters using R-tree and interval trees."""
|
||||||
|
|
||||||
def __init__(self, clusters: List[Cluster]):
|
def __init__(self, clusters: List[Cluster]):
|
||||||
# Create spatial index
|
|
||||||
p = index.Property()
|
p = index.Property()
|
||||||
p.dimension = 2
|
p.dimension = 2
|
||||||
self.spatial_index = index.Index(properties=p)
|
self.spatial_index = index.Index(properties=p)
|
||||||
|
|
||||||
# Initialize interval trees
|
|
||||||
self.x_intervals = IntervalTree()
|
self.x_intervals = IntervalTree()
|
||||||
self.y_intervals = IntervalTree()
|
self.y_intervals = IntervalTree()
|
||||||
|
self.clusters_by_id: Dict[int, Cluster] = {}
|
||||||
|
|
||||||
# Map to store clusters by ID
|
|
||||||
self.clusters_by_id = {} # type: ignore
|
|
||||||
|
|
||||||
# Populate indexes and maps
|
|
||||||
for cluster in clusters:
|
for cluster in clusters:
|
||||||
self.add_cluster(cluster)
|
self.add_cluster(cluster)
|
||||||
|
|
||||||
def add_cluster(self, cluster: Cluster):
|
def add_cluster(self, cluster: Cluster):
|
||||||
"""Add a cluster to all indexes."""
|
bbox = cluster.bbox
|
||||||
self.spatial_index.insert(cluster.id, cluster.bbox.as_tuple())
|
self.spatial_index.insert(cluster.id, bbox.as_tuple())
|
||||||
self.x_intervals.insert(cluster.bbox.l, cluster.bbox.r, cluster.id)
|
self.x_intervals.insert(bbox.l, bbox.r, cluster.id)
|
||||||
self.y_intervals.insert(cluster.bbox.t, cluster.bbox.b, cluster.id)
|
self.y_intervals.insert(bbox.t, bbox.b, cluster.id)
|
||||||
self.clusters_by_id[cluster.id] = cluster
|
self.clusters_by_id[cluster.id] = cluster
|
||||||
|
|
||||||
def remove_cluster(self, cluster: Cluster):
|
def remove_cluster(self, cluster: Cluster):
|
||||||
"""Remove a cluster from all indexes."""
|
|
||||||
self.spatial_index.delete(cluster.id, cluster.bbox.as_tuple())
|
self.spatial_index.delete(cluster.id, cluster.bbox.as_tuple())
|
||||||
# Note: IntervalTree doesn't support deletion, but we handle this
|
|
||||||
# by checking clusters_by_id membership
|
|
||||||
del self.clusters_by_id[cluster.id]
|
del self.clusters_by_id[cluster.id]
|
||||||
|
|
||||||
def find_candidates(self, bbox: BoundingBox) -> Set[int]:
|
def find_candidates(self, bbox: BoundingBox) -> Set[int]:
|
||||||
"""Find all potential overlapping cluster IDs using all indexes."""
|
"""Find potential overlapping cluster IDs using all indexes."""
|
||||||
bbox_tuple = bbox.as_tuple()
|
spatial = set(self.spatial_index.intersection(bbox.as_tuple()))
|
||||||
spatial_candidates = set(self.spatial_index.intersection(bbox_tuple))
|
|
||||||
x_candidates = self.x_intervals.find_containing(
|
x_candidates = self.x_intervals.find_containing(
|
||||||
bbox.l
|
bbox.l
|
||||||
) | self.x_intervals.find_containing(bbox.r)
|
) | self.x_intervals.find_containing(bbox.r)
|
||||||
y_candidates = self.y_intervals.find_containing(
|
y_candidates = self.y_intervals.find_containing(
|
||||||
bbox.t
|
bbox.t
|
||||||
) | self.y_intervals.find_containing(bbox.b)
|
) | self.y_intervals.find_containing(bbox.b)
|
||||||
return spatial_candidates | x_candidates | y_candidates
|
return spatial | x_candidates | y_candidates
|
||||||
|
|
||||||
def check_overlap(
|
def check_overlap(
|
||||||
self,
|
self,
|
||||||
@ -99,8 +88,7 @@ class SpatialClusterIndex:
|
|||||||
containment_threshold: float,
|
containment_threshold: float,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""Check if two bboxes overlap sufficiently."""
|
"""Check if two bboxes overlap sufficiently."""
|
||||||
area1 = bbox1.area()
|
area1, area2 = bbox1.area(), bbox2.area()
|
||||||
area2 = bbox2.area()
|
|
||||||
if area1 <= 0 or area2 <= 0:
|
if area1 <= 0 or area2 <= 0:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@ -108,38 +96,47 @@ class SpatialClusterIndex:
|
|||||||
if overlap_area <= 0:
|
if overlap_area <= 0:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Check both IoU and containment
|
|
||||||
iou = overlap_area / (area1 + area2 - overlap_area)
|
iou = overlap_area / (area1 + area2 - overlap_area)
|
||||||
containment_ratio1 = overlap_area / area1
|
containment1 = overlap_area / area1
|
||||||
containment_ratio2 = overlap_area / area2
|
containment2 = overlap_area / area2
|
||||||
|
|
||||||
return (
|
return (
|
||||||
iou > overlap_threshold
|
iou > overlap_threshold
|
||||||
or containment_ratio1 > containment_threshold
|
or containment1 > containment_threshold
|
||||||
or containment_ratio2 > containment_threshold
|
or containment2 > containment_threshold
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class IntervalTree:
|
class IntervalTree:
|
||||||
def __init__(self):
|
"""Memory-efficient interval tree for 1D overlap queries."""
|
||||||
self.intervals = [] # List of (min, max, box_id) sorted by min
|
|
||||||
|
|
||||||
def insert(self, min_val: float, max_val: float, box_id: int):
|
def __init__(self):
|
||||||
self.intervals.append((min_val, max_val, box_id))
|
self.intervals: List[Tuple[float, float, int]] = (
|
||||||
self.intervals.sort(key=lambda x: x[0])
|
[]
|
||||||
|
) # (min, max, id) sorted by min
|
||||||
|
|
||||||
|
def insert(self, min_val: float, max_val: float, id: int):
|
||||||
|
bisect.insort(self.intervals, (min_val, max_val, id), key=lambda x: x[0])
|
||||||
|
|
||||||
def find_containing(self, point: float) -> Set[int]:
|
def find_containing(self, point: float) -> Set[int]:
|
||||||
pos = bisect_left(self.intervals, (point, float("-inf"), -1))
|
"""Find all intervals containing the point."""
|
||||||
|
pos = bisect.bisect_left(self.intervals, (point, float("-inf"), -1))
|
||||||
result = set()
|
result = set()
|
||||||
|
|
||||||
i = pos - 1
|
# Check intervals starting before point
|
||||||
while i >= 0:
|
for min_val, max_val, id in reversed(self.intervals[:pos]):
|
||||||
min_val, max_val, box_id = self.intervals[i]
|
|
||||||
if min_val <= point <= max_val:
|
if min_val <= point <= max_val:
|
||||||
result.add(box_id)
|
result.add(id)
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Check intervals starting at/after point
|
||||||
|
for min_val, max_val, id in self.intervals[pos:]:
|
||||||
|
if point <= max_val:
|
||||||
|
if min_val <= point:
|
||||||
|
result.add(id)
|
||||||
else:
|
else:
|
||||||
break
|
break
|
||||||
i -= 1
|
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@ -149,7 +146,7 @@ class LayoutPostprocessor:
|
|||||||
|
|
||||||
# Cluster type-specific parameters for overlap resolution
|
# Cluster type-specific parameters for overlap resolution
|
||||||
OVERLAP_PARAMS = {
|
OVERLAP_PARAMS = {
|
||||||
"regular": {"area_threshold": 1.3, "conf_threshold": 0.15},
|
"regular": {"area_threshold": 1.3, "conf_threshold": 0.05},
|
||||||
"picture": {"area_threshold": 2.0, "conf_threshold": 0.3},
|
"picture": {"area_threshold": 2.0, "conf_threshold": 0.3},
|
||||||
"wrapper": {"area_threshold": 2.0, "conf_threshold": 0.2},
|
"wrapper": {"area_threshold": 2.0, "conf_threshold": 0.2},
|
||||||
}
|
}
|
||||||
@ -184,17 +181,42 @@ class LayoutPostprocessor:
|
|||||||
|
|
||||||
def __init__(self, cells: List[Cell], clusters: List[Cluster]):
|
def __init__(self, cells: List[Cell], clusters: List[Cluster]):
|
||||||
"""Initialize processor with cells and clusters."""
|
"""Initialize processor with cells and clusters."""
|
||||||
|
"""Initialize processor with cells and spatial indices."""
|
||||||
self.cells = cells
|
self.cells = cells
|
||||||
self.regular_clusters = [
|
self.regular_clusters = [
|
||||||
c for c in clusters if c.label not in self.SPECIAL_TYPES
|
c for c in clusters if c.label not in self.SPECIAL_TYPES
|
||||||
]
|
]
|
||||||
self.special_clusters = [c for c in clusters if c.label in self.SPECIAL_TYPES]
|
self.special_clusters = [c for c in clusters if c.label in self.SPECIAL_TYPES]
|
||||||
|
|
||||||
|
# Build spatial indices once
|
||||||
|
self.regular_index = SpatialClusterIndex(self.regular_clusters)
|
||||||
|
self.picture_index = SpatialClusterIndex(
|
||||||
|
[c for c in self.special_clusters if c.label == DocItemLabel.PICTURE]
|
||||||
|
)
|
||||||
|
self.wrapper_index = SpatialClusterIndex(
|
||||||
|
[c for c in self.special_clusters if c.label in self.WRAPPER_TYPES]
|
||||||
|
)
|
||||||
|
|
||||||
def postprocess(self) -> Tuple[List[Cluster], List[Cell]]:
|
def postprocess(self) -> Tuple[List[Cluster], List[Cell]]:
|
||||||
"""Main processing pipeline."""
|
"""Main processing pipeline."""
|
||||||
regular_clusters = self._process_regular_clusters()
|
self.regular_clusters = self._process_regular_clusters()
|
||||||
special_clusters = self._process_special_clusters()
|
self.special_clusters = self._process_special_clusters()
|
||||||
final_clusters = self._sort_clusters(regular_clusters + special_clusters)
|
|
||||||
|
# Remove regular clusters that are included in wrappers
|
||||||
|
contained_ids = {
|
||||||
|
child.id
|
||||||
|
for wrapper in self.special_clusters
|
||||||
|
if wrapper.label in self.SPECIAL_TYPES
|
||||||
|
for child in wrapper.children
|
||||||
|
}
|
||||||
|
self.regular_clusters = [
|
||||||
|
c for c in self.regular_clusters if c.id not in contained_ids
|
||||||
|
]
|
||||||
|
|
||||||
|
# Combine and sort final clusters
|
||||||
|
final_clusters = self._sort_clusters(
|
||||||
|
self.regular_clusters + self.special_clusters
|
||||||
|
)
|
||||||
return final_clusters, self.cells
|
return final_clusters, self.cells
|
||||||
|
|
||||||
def _process_regular_clusters(self) -> List[Cluster]:
|
def _process_regular_clusters(self) -> List[Cluster]:
|
||||||
@ -241,49 +263,48 @@ class LayoutPostprocessor:
|
|||||||
return clusters
|
return clusters
|
||||||
|
|
||||||
def _process_special_clusters(self) -> List[Cluster]:
|
def _process_special_clusters(self) -> List[Cluster]:
|
||||||
"""Process special clusters (pictures and wrappers)."""
|
special_clusters = [
|
||||||
# Handle pictures
|
|
||||||
picture_clusters = [
|
|
||||||
c
|
c
|
||||||
for c in self.special_clusters
|
for c in self.special_clusters
|
||||||
if c.label == DocItemLabel.PICTURE
|
if c.confidence >= self.CONFIDENCE_THRESHOLDS[c.label]
|
||||||
and c.confidence >= self.CONFIDENCE_THRESHOLDS[c.label]
|
]
|
||||||
|
|
||||||
|
for special in special_clusters:
|
||||||
|
contained = []
|
||||||
|
for cluster in self.regular_clusters:
|
||||||
|
overlap = cluster.bbox.intersection_area_with(special.bbox)
|
||||||
|
if overlap > 0:
|
||||||
|
containment = overlap / cluster.bbox.area()
|
||||||
|
if containment > 0.8:
|
||||||
|
contained.append(cluster)
|
||||||
|
|
||||||
|
if contained:
|
||||||
|
special.children = contained
|
||||||
|
# Adjust bbox only for wrapper types
|
||||||
|
if special.label in self.WRAPPER_TYPES:
|
||||||
|
special.bbox = BoundingBox(
|
||||||
|
l=min(c.bbox.l for c in contained),
|
||||||
|
t=min(c.bbox.t for c in contained),
|
||||||
|
r=max(c.bbox.r for c in contained),
|
||||||
|
b=max(c.bbox.b for c in contained),
|
||||||
|
)
|
||||||
|
|
||||||
|
picture_clusters = [
|
||||||
|
c for c in special_clusters if c.label == DocItemLabel.PICTURE
|
||||||
]
|
]
|
||||||
picture_clusters = self._remove_overlapping_clusters(
|
picture_clusters = self._remove_overlapping_clusters(
|
||||||
picture_clusters, "picture"
|
picture_clusters, "picture"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Process wrapper clusters
|
wrapper_clusters = [
|
||||||
wrapper_clusters = []
|
c for c in special_clusters if c.label in self.WRAPPER_TYPES
|
||||||
for wrapper in (
|
]
|
||||||
c for c in self.special_clusters if c.label in self.WRAPPER_TYPES
|
wrapper_clusters = self._remove_overlapping_clusters(
|
||||||
):
|
|
||||||
if wrapper.confidence < self.CONFIDENCE_THRESHOLDS[wrapper.label]:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Find contained regular clusters
|
|
||||||
contained = []
|
|
||||||
for cluster in self.regular_clusters:
|
|
||||||
overlap = cluster.bbox.intersection_area_with(wrapper.bbox)
|
|
||||||
if overlap > 0:
|
|
||||||
containment = overlap / cluster.bbox.area()
|
|
||||||
if containment > 0.8: # High containment threshold for wrappers
|
|
||||||
contained.append(cluster)
|
|
||||||
|
|
||||||
if contained:
|
|
||||||
wrapper.children = contained
|
|
||||||
wrapper.bbox = BoundingBox(
|
|
||||||
l=min(c.bbox.l for c in contained),
|
|
||||||
t=min(c.bbox.t for c in contained),
|
|
||||||
r=max(c.bbox.r for c in contained),
|
|
||||||
b=max(c.bbox.b for c in contained),
|
|
||||||
)
|
|
||||||
wrapper_clusters.append(wrapper)
|
|
||||||
|
|
||||||
return picture_clusters + self._remove_overlapping_clusters(
|
|
||||||
wrapper_clusters, "wrapper"
|
wrapper_clusters, "wrapper"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return picture_clusters + wrapper_clusters
|
||||||
|
|
||||||
def _remove_overlapping_clusters(
|
def _remove_overlapping_clusters(
|
||||||
self,
|
self,
|
||||||
clusters: List[Cluster],
|
clusters: List[Cluster],
|
||||||
@ -291,50 +312,71 @@ class LayoutPostprocessor:
|
|||||||
overlap_threshold: float = 0.8,
|
overlap_threshold: float = 0.8,
|
||||||
containment_threshold: float = 0.8,
|
containment_threshold: float = 0.8,
|
||||||
) -> List[Cluster]:
|
) -> List[Cluster]:
|
||||||
"""Remove overlapping clusters using efficient spatial indexing."""
|
|
||||||
if not clusters:
|
if not clusters:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
# Initialize spatial index
|
spatial_index = (
|
||||||
spatial_index = SpatialClusterIndex(clusters)
|
self.regular_index
|
||||||
uf = UnionFind(spatial_index.clusters_by_id.keys())
|
if cluster_type == "regular"
|
||||||
|
else self.picture_index if cluster_type == "picture" else self.wrapper_index
|
||||||
|
)
|
||||||
|
|
||||||
|
# Map of currently valid clusters
|
||||||
|
valid_clusters = {c.id: c for c in clusters}
|
||||||
|
uf = UnionFind(valid_clusters.keys())
|
||||||
|
params = self.OVERLAP_PARAMS[cluster_type]
|
||||||
|
|
||||||
# Group overlapping clusters using spatial index
|
|
||||||
for cluster in clusters:
|
for cluster in clusters:
|
||||||
candidates = spatial_index.find_candidates(cluster.bbox)
|
candidates = spatial_index.find_candidates(cluster.bbox)
|
||||||
candidates.discard(cluster.id) # Remove self
|
candidates &= valid_clusters.keys() # Only keep existing candidates
|
||||||
|
candidates.discard(cluster.id)
|
||||||
|
|
||||||
for other_id in candidates:
|
for other_id in candidates:
|
||||||
if spatial_index.check_overlap(
|
if spatial_index.check_overlap(
|
||||||
cluster.bbox,
|
cluster.bbox,
|
||||||
spatial_index.clusters_by_id[other_id].bbox,
|
valid_clusters[other_id].bbox,
|
||||||
overlap_threshold,
|
overlap_threshold,
|
||||||
containment_threshold,
|
containment_threshold,
|
||||||
):
|
):
|
||||||
uf.union(cluster.id, other_id)
|
uf.union(cluster.id, other_id)
|
||||||
|
|
||||||
# Process each group using type-specific parameters
|
|
||||||
params = self.OVERLAP_PARAMS[cluster_type]
|
|
||||||
result = []
|
result = []
|
||||||
|
for group in uf.get_groups().values():
|
||||||
for group in uf.groups().values():
|
|
||||||
if len(group) == 1:
|
if len(group) == 1:
|
||||||
result.append(spatial_index.clusters_by_id[group[0]])
|
result.append(valid_clusters[group[0]])
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Get clusters in group
|
group_clusters = [valid_clusters[cid] for cid in group]
|
||||||
group_clusters = [spatial_index.clusters_by_id[cid] for cid in group]
|
current_best = None
|
||||||
|
|
||||||
# Find best cluster using area and confidence
|
for candidate in group_clusters:
|
||||||
best = self._select_best_cluster(
|
should_select = True
|
||||||
group_clusters, params["area_threshold"], params["conf_threshold"]
|
for other in group_clusters:
|
||||||
)
|
if other == candidate:
|
||||||
|
continue
|
||||||
|
|
||||||
# Merge cells from other clusters into best
|
area_ratio = candidate.bbox.area() / other.bbox.area()
|
||||||
|
conf_diff = other.confidence - candidate.confidence
|
||||||
|
|
||||||
|
if (
|
||||||
|
area_ratio <= params["area_threshold"]
|
||||||
|
and conf_diff > params["conf_threshold"]
|
||||||
|
):
|
||||||
|
should_select = False
|
||||||
|
break
|
||||||
|
|
||||||
|
if should_select:
|
||||||
|
if current_best is None or (
|
||||||
|
candidate.bbox.area() > current_best.bbox.area()
|
||||||
|
and current_best.confidence - candidate.confidence
|
||||||
|
<= params["conf_threshold"]
|
||||||
|
):
|
||||||
|
current_best = candidate
|
||||||
|
|
||||||
|
best = current_best if current_best else group_clusters[0]
|
||||||
for cluster in group_clusters:
|
for cluster in group_clusters:
|
||||||
if cluster != best:
|
if cluster != best:
|
||||||
best.cells.extend(cluster.cells)
|
best.cells.extend(cluster.cells)
|
||||||
|
|
||||||
result.append(best)
|
result.append(best)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
Loading…
Reference in New Issue
Block a user