feat: Updated Layout processing with forms and key-value areas (#530)

* Upgraded Layout Postprocessing, sending old code back to ERZ

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Implement hierarchical cluster layout processing

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Pass nested cluster processing through full pipeline

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Pass nested clusters through GLM as payload

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Move to_docling_document from ds-glm to this repo

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Clean up imports again

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* feat(Accelerator): Introduce options to control the num_threads and device from API, envvars, CLI.
- Introduce AcceleratorOptions and AcceleratorDevice, and use them to set the device where the models run.
- Introduce accelerator_utils with a function that decides the device and resolves the AUTO setting.
- Refactor how the docling-ibm-models are called to match the new init signature of the models.
- Translate the accelerator options to the specific inputs for third-party models.
- Extend the docling CLI with parameters to set the num_threads and device.
- Add new unit tests.
- Add a new example showing how to use the accelerator options.

* fix: Improve the pydantic objects in the pipeline_options and imports.

Signed-off-by: Nikos Livathinos <nli@zurich.ibm.com>

* fix: TableStructureModel: Refactor the artifacts path to use the new structure for fast/accurate model

Signed-off-by: Nikos Livathinos <nli@zurich.ibm.com>

* Updated test ground-truth

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Updated test ground-truth (again), bugfix for empty layout

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* fix: Do proper check to set the device in EasyOCR, RapidOCR.

Signed-off-by: Nikos Livathinos <nli@zurich.ibm.com>

* fix: Correct the way to set GPU for EasyOCR, RapidOCR

Signed-off-by: Nikos Livathinos <nli@zurich.ibm.com>

* fix: Ocr AcceleratorDevice

Signed-off-by: Nikos Livathinos <nli@zurich.ibm.com>

* Merge pull request #556 from DS4SD/cau/layout-processing-improvement

feat: layout processing improvements and bugfixes

* Update lockfile

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Update tests

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Update HF model ref, reset test generate

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Repin to release package versions

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Many layout processing improvements, add document index type

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Update pinnings to docling-core

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Update test GT

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Fix table box snapping

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Fixes for cluster pre-ordering

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Introduce OCR confidence, propagate to orphan in post-processing

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Fix form and key value area groups

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Adjust confidence in EasyOcr

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Roll back CLI changes from main

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Update test GT

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Update docling-core pinning

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Annoying fixes for historical python versions

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Updated test GT for legacy

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Comment cleanup

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

---------

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
Signed-off-by: Nikos Livathinos <nli@zurich.ibm.com>
Co-authored-by: Nikos Livathinos <nli@zurich.ibm.com>
Christoph Auer
2024-12-17 17:32:24 +01:00
committed by GitHub
parent 00dec7a2f3
commit 60dc852f16
56 changed files with 1659 additions and 1718 deletions


@@ -169,6 +169,8 @@ def to_docling_document(doc_glm, update_name_label=False) -> DoclingDocument:
 current_list = None
 text = ""
 caption_refs = []
+item_label = DocItemLabel(pelem["name"])
 for caption in obj["captions"]:
 text += caption["text"]
@@ -254,12 +256,18 @@ def to_docling_document(doc_glm, update_name_label=False) -> DoclingDocument:
 ),
 )
-tbl = doc.add_table(data=tbl_data, prov=prov)
+tbl = doc.add_table(data=tbl_data, prov=prov, label=item_label)
 tbl.captions.extend(caption_refs)
-elif ptype in ["form", "key_value_region"]:
+elif ptype in [DocItemLabel.FORM.value, DocItemLabel.KEY_VALUE_REGION.value]:
 label = DocItemLabel(ptype)
-container_el = doc.add_group(label=GroupLabel.UNSPECIFIED, name=label)
+group_label = GroupLabel.UNSPECIFIED
+if label == DocItemLabel.FORM:
+group_label = GroupLabel.FORM_AREA
+elif label == DocItemLabel.KEY_VALUE_REGION:
+group_label = GroupLabel.KEY_VALUE_AREA
+container_el = doc.add_group(label=group_label)
 _add_child_elements(container_el, doc, obj, pelem)
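
The hunk above picks a group label for form and key-value containers with an if/elif chain. A minimal sketch of the same decision as a lookup table follows. The two enums are stand-ins defined here so the snippet runs on its own; the real ones come from docling_core, and the string values given to `GroupLabel` are assumptions. `group_label_for` is a hypothetical helper name, not part of the commit.

```python
from enum import Enum


# Stand-in enums for illustration only; the real ones live in docling_core.
class DocItemLabel(str, Enum):
    FORM = "form"
    KEY_VALUE_REGION = "key_value_region"
    TABLE = "table"


class GroupLabel(str, Enum):
    # The member names match the diff; the string values are assumed.
    UNSPECIFIED = "unspecified"
    FORM_AREA = "form_area"
    KEY_VALUE_AREA = "key_value_area"


# Lookup table equivalent to the if/elif chain in the hunk above.
_GROUP_FOR_LABEL = {
    DocItemLabel.FORM: GroupLabel.FORM_AREA,
    DocItemLabel.KEY_VALUE_REGION: GroupLabel.KEY_VALUE_AREA,
}


def group_label_for(ptype: str) -> GroupLabel:
    """Map a raw page-element type string to the label of its container group."""
    label = DocItemLabel(ptype)
    # Anything without a dedicated area falls back to UNSPECIFIED.
    return _GROUP_FOR_LABEL.get(label, GroupLabel.UNSPECIFIED)
```

A table keeps the mapping in one place if more container types are added later; the if/elif form in the commit behaves the same for the two labels it handles.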


@@ -0,0 +1,666 @@
import bisect
import logging
import sys
from collections import defaultdict
from typing import Dict, List, Set, Tuple
from docling_core.types.doc import DocItemLabel, Size
from rtree import index
from docling.datamodel.base_models import BoundingBox, Cell, Cluster, OcrCell
_log = logging.getLogger(__name__)
class UnionFind:
"""Efficient Union-Find data structure for grouping elements."""
def __init__(self, elements):
self.parent = {elem: elem for elem in elements}
self.rank = {elem: 0 for elem in elements}
def find(self, x):
if self.parent[x] != x:
self.parent[x] = self.find(self.parent[x]) # Path compression
return self.parent[x]
def union(self, x, y):
root_x, root_y = self.find(x), self.find(y)
if root_x == root_y:
return
if self.rank[root_x] > self.rank[root_y]:
self.parent[root_y] = root_x
elif self.rank[root_x] < self.rank[root_y]:
self.parent[root_x] = root_y
else:
self.parent[root_y] = root_x
self.rank[root_x] += 1
def get_groups(self) -> Dict[int, List[int]]:
"""Returns groups as {root: [elements]}."""
groups = defaultdict(list)
for elem in self.parent:
groups[self.find(elem)].append(elem)
return groups
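
Since the page view drops indentation, here is an indented, self-contained sketch of the union-find structure above together with a small usage example. It follows the same logic (path compression, union by rank); the element values and the final `dict(...)` conversion are illustrative choices, not part of the commit.

```python
from collections import defaultdict
from typing import Dict, Iterable, List


class UnionFind:
    """Union-find over hashable elements, with path compression and union by rank."""

    def __init__(self, elements: Iterable[int]):
        elements = list(elements)
        self.parent = {elem: elem for elem in elements}
        self.rank = {elem: 0 for elem in elements}

    def find(self, x: int) -> int:
        if self.parent[x] != x:
            self.parent[x] = self.find(self.parent[x])  # path compression
        return self.parent[x]

    def union(self, x: int, y: int) -> None:
        root_x, root_y = self.find(x), self.find(y)
        if root_x == root_y:
            return
        # Attach the shallower tree under the deeper one.
        if self.rank[root_x] > self.rank[root_y]:
            self.parent[root_y] = root_x
        elif self.rank[root_x] < self.rank[root_y]:
            self.parent[root_x] = root_y
        else:
            self.parent[root_y] = root_x
            self.rank[root_x] += 1

    def get_groups(self) -> Dict[int, List[int]]:
        groups = defaultdict(list)
        for elem in self.parent:
            groups[self.find(elem)].append(elem)
        return dict(groups)


# Cluster ids 1-3 overlap transitively, 4 and 5 overlap each other.
uf = UnionFind([1, 2, 3, 4, 5])
uf.union(1, 2)
uf.union(2, 3)
uf.union(4, 5)
groups = sorted(sorted(g) for g in uf.get_groups().values())
```

In the postprocessor this is what turns pairwise overlap checks into groups: overlapping clusters are unioned, and one winner is then selected per group.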
class SpatialClusterIndex:
"""Efficient spatial indexing for clusters using R-tree and interval trees."""
def __init__(self, clusters: List[Cluster]):
p = index.Property()
p.dimension = 2
self.spatial_index = index.Index(properties=p)
self.x_intervals = IntervalTree()
self.y_intervals = IntervalTree()
self.clusters_by_id: Dict[int, Cluster] = {}
for cluster in clusters:
self.add_cluster(cluster)
def add_cluster(self, cluster: Cluster):
bbox = cluster.bbox
self.spatial_index.insert(cluster.id, bbox.as_tuple())
self.x_intervals.insert(bbox.l, bbox.r, cluster.id)
self.y_intervals.insert(bbox.t, bbox.b, cluster.id)
self.clusters_by_id[cluster.id] = cluster
def remove_cluster(self, cluster: Cluster):
self.spatial_index.delete(cluster.id, cluster.bbox.as_tuple())
del self.clusters_by_id[cluster.id]
def find_candidates(self, bbox: BoundingBox) -> Set[int]:
"""Find potential overlapping cluster IDs using all indexes."""
spatial = set(self.spatial_index.intersection(bbox.as_tuple()))
x_candidates = self.x_intervals.find_containing(
bbox.l
) | self.x_intervals.find_containing(bbox.r)
y_candidates = self.y_intervals.find_containing(
bbox.t
) | self.y_intervals.find_containing(bbox.b)
return spatial.union(x_candidates).union(y_candidates)
def check_overlap(
self,
bbox1: BoundingBox,
bbox2: BoundingBox,
overlap_threshold: float,
containment_threshold: float,
) -> bool:
"""Check if two bboxes overlap sufficiently."""
area1, area2 = bbox1.area(), bbox2.area()
if area1 <= 0 or area2 <= 0:
return False
overlap_area = bbox1.intersection_area_with(bbox2)
if overlap_area <= 0:
return False
iou = overlap_area / (area1 + area2 - overlap_area)
containment1 = overlap_area / area1
containment2 = overlap_area / area2
return (
iou > overlap_threshold
or containment1 > containment_threshold
or containment2 > containment_threshold
)
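
The overlap test in `check_overlap` combines intersection-over-union with two containment ratios. The sketch below reproduces that arithmetic on plain `(l, t, r, b)` tuples so it can be run without docling; the tuple layout, the top-left origin, and the helper names are assumptions made for illustration.

```python
from typing import Tuple

Box = Tuple[float, float, float, float]  # (l, t, r, b), top-left origin assumed


def area(box: Box) -> float:
    return max(0.0, box[2] - box[0]) * max(0.0, box[3] - box[1])


def intersection_area(a: Box, b: Box) -> float:
    width = min(a[2], b[2]) - max(a[0], b[0])
    height = min(a[3], b[3]) - max(a[1], b[1])
    return width * height if width > 0 and height > 0 else 0.0


def overlaps_sufficiently(
    a: Box,
    b: Box,
    overlap_threshold: float = 0.8,
    containment_threshold: float = 0.8,
) -> bool:
    """True if IoU or either containment ratio exceeds its threshold."""
    area_a, area_b = area(a), area(b)
    if area_a <= 0 or area_b <= 0:
        return False
    inter = intersection_area(a, b)
    if inter <= 0:
        return False
    iou = inter / (area_a + area_b - inter)
    return (
        iou > overlap_threshold
        or inter / area_a > containment_threshold
        or inter / area_b > containment_threshold
    )
```

The containment terms matter for nested boxes: a small box fully inside a large one has a tiny IoU but a containment of 1.0, so the pair still counts as overlapping.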
class Interval:
"""Helper class for sortable intervals."""
def __init__(self, min_val: float, max_val: float, id: int):
self.min_val = min_val
self.max_val = max_val
self.id = id
def __lt__(self, other):
if isinstance(other, Interval):
return self.min_val < other.min_val
return self.min_val < other
class IntervalTree:
"""Memory-efficient interval tree for 1D overlap queries."""
def __init__(self):
self.intervals: List[Interval] = [] # Sorted by min_val
def insert(self, min_val: float, max_val: float, id: int):
interval = Interval(min_val, max_val, id)
bisect.insort(self.intervals, interval)
def find_containing(self, point: float) -> Set[int]:
"""Find all intervals containing the point."""
pos = bisect.bisect_left(self.intervals, point)
result = set()
# Check intervals starting before point
for interval in reversed(self.intervals[:pos]):
if interval.min_val <= point <= interval.max_val:
result.add(interval.id)
else:
break
# Check intervals starting at/after point
for interval in self.intervals[pos:]:
if point <= interval.max_val:
if interval.min_val <= point:
result.add(interval.id)
else:
break
return result
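
One property of the lookup above is worth noting: the list is ordered by `min_val` only, so an interval that starts early can still reach past the query point even when a later-starting one does not, and a backward scan that stops at the first miss can skip it. The sketch below is a simpler variant, not the commit's implementation, that checks every interval starting at or before the point. `SimpleIntervalIndex` is a hypothetical name.

```python
import bisect
from typing import List, Set, Tuple


class SimpleIntervalIndex:
    """Sorted-list interval lookup; a variant for illustration only."""

    def __init__(self) -> None:
        self._starts: List[float] = []  # min_val of each interval, kept sorted
        self._items: List[Tuple[float, float, int]] = []  # (min_val, max_val, id)

    def insert(self, min_val: float, max_val: float, id_: int) -> None:
        pos = bisect.bisect_right(self._starts, min_val)
        self._starts.insert(pos, min_val)
        self._items.insert(pos, (min_val, max_val, id_))

    def find_containing(self, point: float) -> Set[int]:
        # Every interval with min_val <= point is a candidate; none is skipped.
        end = bisect.bisect_right(self._starts, point)
        return {id_ for (_, hi, id_) in self._items[:end] if point <= hi}


idx = SimpleIntervalIndex()
idx.insert(0, 10, 1)  # wide interval that starts first
idx.insert(1, 2, 2)   # narrow interval nested inside it
idx.insert(6, 8, 3)
```

The scan is linear in the number of intervals that start at or before the point. That is usually acceptable for the number of clusters on a page, though a real interval tree would do better on large inputs.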
class LayoutPostprocessor:
"""Postprocesses layout predictions by cleaning up clusters and mapping cells."""
# Cluster type-specific parameters for overlap resolution
OVERLAP_PARAMS = {
"regular": {"area_threshold": 1.3, "conf_threshold": 0.05},
"picture": {"area_threshold": 2.0, "conf_threshold": 0.3},
"wrapper": {"area_threshold": 2.0, "conf_threshold": 0.2},
}
WRAPPER_TYPES = {
DocItemLabel.FORM,
DocItemLabel.KEY_VALUE_REGION,
DocItemLabel.TABLE,
DocItemLabel.DOCUMENT_INDEX,
}
SPECIAL_TYPES = WRAPPER_TYPES.union({DocItemLabel.PICTURE})
CONFIDENCE_THRESHOLDS = {
DocItemLabel.CAPTION: 0.5,
DocItemLabel.FOOTNOTE: 0.5,
DocItemLabel.FORMULA: 0.5,
DocItemLabel.LIST_ITEM: 0.5,
DocItemLabel.PAGE_FOOTER: 0.5,
DocItemLabel.PAGE_HEADER: 0.5,
DocItemLabel.PICTURE: 0.5,
DocItemLabel.SECTION_HEADER: 0.45,
DocItemLabel.TABLE: 0.5,
DocItemLabel.TEXT: 0.5, # 0.45,
DocItemLabel.TITLE: 0.45,
DocItemLabel.CODE: 0.45,
DocItemLabel.CHECKBOX_SELECTED: 0.45,
DocItemLabel.CHECKBOX_UNSELECTED: 0.45,
DocItemLabel.FORM: 0.45,
DocItemLabel.KEY_VALUE_REGION: 0.45,
DocItemLabel.DOCUMENT_INDEX: 0.45,
}
LABEL_REMAPPING = {
# DocItemLabel.DOCUMENT_INDEX: DocItemLabel.TABLE,
DocItemLabel.TITLE: DocItemLabel.SECTION_HEADER,
}
def __init__(self, cells: List[Cell], clusters: List[Cluster], page_size: Size):
"""Initialize processor with cells, clusters, and spatial indices."""
self.cells = cells
self.page_size = page_size
self.regular_clusters = [
c for c in clusters if c.label not in self.SPECIAL_TYPES
]
self.special_clusters = [c for c in clusters if c.label in self.SPECIAL_TYPES]
# Build spatial indices once
self.regular_index = SpatialClusterIndex(self.regular_clusters)
self.picture_index = SpatialClusterIndex(
[c for c in self.special_clusters if c.label == DocItemLabel.PICTURE]
)
self.wrapper_index = SpatialClusterIndex(
[c for c in self.special_clusters if c.label in self.WRAPPER_TYPES]
)
def postprocess(self) -> Tuple[List[Cluster], List[Cell]]:
"""Main processing pipeline."""
self.regular_clusters = self._process_regular_clusters()
self.special_clusters = self._process_special_clusters()
# Remove regular clusters that are included in wrappers
contained_ids = {
child.id
for wrapper in self.special_clusters
if wrapper.label in self.SPECIAL_TYPES
for child in wrapper.children
}
self.regular_clusters = [
c for c in self.regular_clusters if c.id not in contained_ids
]
# Combine and sort final clusters
final_clusters = self._sort_clusters(
self.regular_clusters + self.special_clusters, mode="id"
)
for cluster in final_clusters:
cluster.cells = self._sort_cells(cluster.cells)
# Also sort cells in children if any
for child in cluster.children:
child.cells = self._sort_cells(child.cells)
return final_clusters, self.cells
def _process_regular_clusters(self) -> List[Cluster]:
"""Process regular clusters with iterative refinement."""
clusters = [
c
for c in self.regular_clusters
if c.confidence >= self.CONFIDENCE_THRESHOLDS[c.label]
]
# Apply label remapping
for cluster in clusters:
if cluster.label in self.LABEL_REMAPPING:
cluster.label = self.LABEL_REMAPPING[cluster.label]
# Initial cell assignment
clusters = self._assign_cells_to_clusters(clusters)
# Remove clusters with no cells
clusters = [cluster for cluster in clusters if cluster.cells]
# Handle orphaned cells
unassigned = self._find_unassigned_cells(clusters)
if unassigned:
next_id = max((c.id for c in clusters), default=0) + 1
orphan_clusters = []
for i, cell in enumerate(unassigned):
conf = 1.0
if isinstance(cell, OcrCell):
conf = cell.confidence
orphan_clusters.append(
Cluster(
id=next_id + i,
label=DocItemLabel.TEXT,
bbox=cell.bbox,
confidence=conf,
cells=[cell],
)
)
clusters.extend(orphan_clusters)
# Iterative refinement
prev_count = len(clusters) + 1
for _ in range(3): # Maximum 3 iterations
if prev_count == len(clusters):
break
prev_count = len(clusters)
clusters = self._adjust_cluster_bboxes(clusters)
clusters = self._remove_overlapping_clusters(clusters, "regular")
return clusters
def _process_special_clusters(self) -> List[Cluster]:
special_clusters = [
c
for c in self.special_clusters
if c.confidence >= self.CONFIDENCE_THRESHOLDS[c.label]
]
special_clusters = self._handle_cross_type_overlaps(special_clusters)
# Calculate page area from known page size
page_area = self.page_size.width * self.page_size.height
if page_area > 0:
# Filter out full-page pictures
special_clusters = [
cluster
for cluster in special_clusters
if not (
cluster.label == DocItemLabel.PICTURE
and cluster.bbox.area() / page_area > 0.90
)
]
for special in special_clusters:
contained = []
for cluster in self.regular_clusters:
overlap = cluster.bbox.intersection_area_with(special.bbox)
if overlap > 0:
containment = overlap / cluster.bbox.area()
if containment > 0.8:
contained.append(cluster)
if contained:
# Sort contained clusters by minimum cell ID:
contained = self._sort_clusters(contained, mode="id")
special.children = contained
# Adjust bbox only for Form and Key-Value-Region, not Table or Picture
if special.label in [DocItemLabel.FORM, DocItemLabel.KEY_VALUE_REGION]:
special.bbox = BoundingBox(
l=min(c.bbox.l for c in contained),
t=min(c.bbox.t for c in contained),
r=max(c.bbox.r for c in contained),
b=max(c.bbox.b for c in contained),
)
# Collect all cells from children
all_cells = []
for child in contained:
all_cells.extend(child.cells)
special.cells = self._deduplicate_cells(all_cells)
special.cells = self._sort_cells(special.cells)
picture_clusters = [
c for c in special_clusters if c.label == DocItemLabel.PICTURE
]
picture_clusters = self._remove_overlapping_clusters(
picture_clusters, "picture"
)
wrapper_clusters = [
c for c in special_clusters if c.label in self.WRAPPER_TYPES
]
wrapper_clusters = self._remove_overlapping_clusters(
wrapper_clusters, "wrapper"
)
return picture_clusters + wrapper_clusters
def _handle_cross_type_overlaps(self, special_clusters) -> List[Cluster]:
"""Handle overlaps between regular and wrapper clusters before child assignment.
In particular, KEY_VALUE_REGION proposals that are almost identical to a TABLE
should be removed.
"""
wrappers_to_remove = set()
for wrapper in special_clusters:
if wrapper.label not in self.WRAPPER_TYPES:
continue # only wrapper-type clusters are candidates for removal
for regular in self.regular_clusters:
if regular.label == DocItemLabel.TABLE:
# Calculate overlap
overlap = regular.bbox.intersection_area_with(wrapper.bbox)
wrapper_area = wrapper.bbox.area()
overlap_ratio = overlap / wrapper_area
conf_diff = wrapper.confidence - regular.confidence
# If wrapper is mostly overlapping with a TABLE, remove the wrapper
if (
overlap_ratio > 0.9 and conf_diff < 0.1
): # over 90% of the wrapper lies inside the TABLE, and the wrapper is not clearly more confident
wrappers_to_remove.add(wrapper.id)
break
# Filter out the identified wrappers
special_clusters = [
cluster
for cluster in special_clusters
if cluster.id not in wrappers_to_remove
]
return special_clusters
def _should_prefer_cluster(
self, candidate: Cluster, other: Cluster, params: dict
) -> bool:
"""Determine if candidate cluster should be preferred over other cluster based on rules.
Returns True if candidate should be preferred, False if not."""
# Rule 1: LIST_ITEM vs TEXT
if (
candidate.label == DocItemLabel.LIST_ITEM
and other.label == DocItemLabel.TEXT
):
# Check if areas are similar (within 20% of each other)
area_ratio = candidate.bbox.area() / other.bbox.area()
area_similarity = abs(1 - area_ratio) < 0.2
if area_similarity:
return True
# Rule 2: CODE vs others
if candidate.label == DocItemLabel.CODE:
# Calculate how much of the other cluster is contained within the CODE cluster
overlap = other.bbox.intersection_area_with(candidate.bbox)
containment = overlap / other.bbox.area()
if containment > 0.8: # other is 80% contained within CODE
return True
# If no label-based rules matched, fall back to area/confidence thresholds
area_ratio = candidate.bbox.area() / other.bbox.area()
conf_diff = other.confidence - candidate.confidence
if (
area_ratio <= params["area_threshold"]
and conf_diff > params["conf_threshold"]
):
return False
return True # Default to keeping candidate if no rules triggered rejection
def _select_best_cluster_from_group(
self,
group_clusters: List[Cluster],
params: dict,
) -> Cluster:
"""Select best cluster from a group of overlapping clusters based on all rules."""
current_best = None
for candidate in group_clusters:
should_select = True
for other in group_clusters:
if other == candidate:
continue
if not self._should_prefer_cluster(candidate, other, params):
should_select = False
break
if should_select:
if current_best is None:
current_best = candidate
else:
# If both clusters pass rules, prefer the larger one unless confidence differs significantly
if (
candidate.bbox.area() > current_best.bbox.area()
and current_best.confidence - candidate.confidence
<= params["conf_threshold"]
):
current_best = candidate
return current_best if current_best else group_clusters[0]
def _remove_overlapping_clusters(
self,
clusters: List[Cluster],
cluster_type: str,
overlap_threshold: float = 0.8,
containment_threshold: float = 0.8,
) -> List[Cluster]:
if not clusters:
return []
spatial_index = (
self.regular_index
if cluster_type == "regular"
else self.picture_index if cluster_type == "picture" else self.wrapper_index
)
# Map of currently valid clusters
valid_clusters = {c.id: c for c in clusters}
uf = UnionFind(valid_clusters.keys())
params = self.OVERLAP_PARAMS[cluster_type]
for cluster in clusters:
candidates = spatial_index.find_candidates(cluster.bbox)
candidates &= valid_clusters.keys() # Only keep existing candidates
candidates.discard(cluster.id)
for other_id in candidates:
if spatial_index.check_overlap(
cluster.bbox,
valid_clusters[other_id].bbox,
overlap_threshold,
containment_threshold,
):
uf.union(cluster.id, other_id)
result = []
for group in uf.get_groups().values():
if len(group) == 1:
result.append(valid_clusters[group[0]])
continue
group_clusters = [valid_clusters[cid] for cid in group]
best = self._select_best_cluster_from_group(group_clusters, params)
# Simple cell merging - no special cases
for cluster in group_clusters:
if cluster != best:
best.cells.extend(cluster.cells)
best.cells = self._deduplicate_cells(best.cells)
best.cells = self._sort_cells(best.cells)
result.append(best)
return result
def _select_best_cluster(
self,
clusters: List[Cluster],
area_threshold: float,
conf_threshold: float,
) -> Cluster:
"""Iteratively select best cluster based on area and confidence thresholds."""
current_best = None
for candidate in clusters:
should_select = True
for other in clusters:
if other == candidate:
continue
area_ratio = candidate.bbox.area() / other.bbox.area()
conf_diff = other.confidence - candidate.confidence
if area_ratio <= area_threshold and conf_diff > conf_threshold:
should_select = False
break
if should_select:
if current_best is None or (
candidate.bbox.area() > current_best.bbox.area()
and current_best.confidence - candidate.confidence <= conf_threshold
):
current_best = candidate
return current_best if current_best else clusters[0]
def _deduplicate_cells(self, cells: List[Cell]) -> List[Cell]:
"""Ensure each cell appears only once, maintaining order of first appearance."""
seen_ids = set()
unique_cells = []
for cell in cells:
if cell.id not in seen_ids:
seen_ids.add(cell.id)
unique_cells.append(cell)
return unique_cells
def _assign_cells_to_clusters(
self, clusters: List[Cluster], min_overlap: float = 0.2
) -> List[Cluster]:
"""Assign cells to best overlapping cluster."""
for cluster in clusters:
cluster.cells = []
for cell in self.cells:
if not cell.text.strip():
continue
best_overlap = min_overlap
best_cluster = None
for cluster in clusters:
if cell.bbox.area() <= 0:
continue
overlap = cell.bbox.intersection_area_with(cluster.bbox)
overlap_ratio = overlap / cell.bbox.area()
if overlap_ratio > best_overlap:
best_overlap = overlap_ratio
best_cluster = cluster
if best_cluster is not None:
best_cluster.cells.append(cell)
# Deduplicate cells in each cluster after assignment
for cluster in clusters:
cluster.cells = self._deduplicate_cells(cluster.cells)
return clusters
def _find_unassigned_cells(self, clusters: List[Cluster]) -> List[Cell]:
"""Find cells not assigned to any cluster."""
assigned = {cell.id for cluster in clusters for cell in cluster.cells}
return [
cell for cell in self.cells if cell.id not in assigned and cell.text.strip()
]
def _adjust_cluster_bboxes(self, clusters: List[Cluster]) -> List[Cluster]:
"""Adjust cluster bounding boxes to contain their cells."""
for cluster in clusters:
if not cluster.cells:
continue
cells_bbox = BoundingBox(
l=min(cell.bbox.l for cell in cluster.cells),
t=min(cell.bbox.t for cell in cluster.cells),
r=max(cell.bbox.r for cell in cluster.cells),
b=max(cell.bbox.b for cell in cluster.cells),
)
if cluster.label == DocItemLabel.TABLE:
# For tables, take union of current bbox and cells bbox
cluster.bbox = BoundingBox(
l=min(cluster.bbox.l, cells_bbox.l),
t=min(cluster.bbox.t, cells_bbox.t),
r=max(cluster.bbox.r, cells_bbox.r),
b=max(cluster.bbox.b, cells_bbox.b),
)
else:
cluster.bbox = cells_bbox
return clusters
def _sort_cells(self, cells: List[Cell]) -> List[Cell]:
"""Sort cells in native reading order."""
return sorted(cells, key=lambda c: (c.id))
def _sort_clusters(
self, clusters: List[Cluster], mode: str = "id"
) -> List[Cluster]:
"""Sort clusters in reading order (top-to-bottom, left-to-right)."""
if mode == "id": # sort in the order the cells are printed in the PDF.
return sorted(
clusters,
key=lambda cluster: (
(
min(cell.id for cell in cluster.cells)
if cluster.cells
else sys.maxsize
),
cluster.bbox.t,
cluster.bbox.l,
),
)
elif mode == "tblr": # Sort top-to-bottom, then left-to-right ("row first")
return sorted(
clusters, key=lambda cluster: (cluster.bbox.t, cluster.bbox.l)
)
elif mode == "lrtb": # Sort left-to-right, then top-to-bottom ("column first")
return sorted(
clusters, key=lambda cluster: (cluster.bbox.l, cluster.bbox.t)
)
else:
return clusters
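
The default `"id"` mode of `_sort_clusters` orders clusters by the smallest cell id they contain, with clusters that have no cells pushed to the end. A minimal sketch of that sort key follows, using a hypothetical `MiniCluster` dataclass in place of the real `Cluster` model (which carries a `bbox` and full `Cell` objects).

```python
import sys
from dataclasses import dataclass, field
from typing import List


@dataclass
class MiniCluster:
    """Hypothetical stand-in for Cluster, reduced to what the sort key needs."""

    id: int
    top: float
    left: float
    cell_ids: List[int] = field(default_factory=list)


def sort_by_first_cell(clusters: List[MiniCluster]) -> List[MiniCluster]:
    """Order clusters by their lowest cell id; empty clusters sort last."""
    return sorted(
        clusters,
        key=lambda c: (
            min(c.cell_ids) if c.cell_ids else sys.maxsize,
            c.top,   # tie-breakers: top edge, then left edge
            c.left,
        ),
    )


clusters = [
    MiniCluster(id=1, top=50.0, left=0.0, cell_ids=[7, 8]),
    MiniCluster(id=2, top=10.0, left=0.0, cell_ids=[2, 9]),
    MiniCluster(id=3, top=0.0, left=0.0, cell_ids=[]),  # e.g. a picture without text
]
ordered_ids = [c.id for c in sort_by_first_cell(clusters)]
```

Because cell ids follow the order in which text appears in the PDF, this key approximates the native reading order without any geometric analysis; position is only used to break ties.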


@@ -1,812 +0,0 @@
import copy
import logging
import networkx as nx
from docling_core.types.doc import DocItemLabel
logger = logging.getLogger("layout_utils")
## -------------------------------
## Geometric helper functions
## The coordinates grow left to right, and bottom to top.
## The bounding box list elements 0 to 3 are x_left, y_bottom, x_right, y_top.
def area(bbox):
return (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
def contains(bbox_i, bbox_j):
## Returns True if bbox_i contains bbox_j, else False
return (
bbox_i[0] <= bbox_j[0]
and bbox_i[1] <= bbox_j[1]
and bbox_i[2] >= bbox_j[2]
and bbox_i[3] >= bbox_j[3]
)
def is_intersecting(bbox_i, bbox_j):
return not (
bbox_i[2] < bbox_j[0]
or bbox_i[0] > bbox_j[2]
or bbox_i[3] < bbox_j[1]
or bbox_i[1] > bbox_j[3]
)
def bb_iou(boxA, boxB):
# determine the (x, y)-coordinates of the intersection rectangle
xA = max(boxA[0], boxB[0])
yA = max(boxA[1], boxB[1])
xB = min(boxA[2], boxB[2])
yB = min(boxA[3], boxB[3])
# compute the area of intersection rectangle
interArea = max(0, xB - xA + 1) * max(0, yB - yA + 1)
# compute the area of both the prediction and ground-truth
# rectangles
boxAArea = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1)
boxBArea = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1)
# compute the intersection over union by taking the intersection
# area and dividing it by the sum of prediction + ground-truth
# areas - the intersection area
iou = interArea / float(boxAArea + boxBArea - interArea)
# return the intersection over union value
return iou
def compute_intersection(bbox_i, bbox_j):
## Returns the size of the intersection area of the two boxes
if not is_intersecting(bbox_i, bbox_j):
return 0
## Determine the (x, y)-coordinates of the intersection rectangle:
xA = max(bbox_i[0], bbox_j[0])
yA = max(bbox_i[1], bbox_j[1])
xB = min(bbox_i[2], bbox_j[2])
yB = min(bbox_i[3], bbox_j[3])
## Compute the area of intersection rectangle:
interArea = (xB - xA) * (yB - yA)
if interArea < 0:
logger.debug("Warning: Negative intersection detected!")
return 0
return interArea
def surrounding(bbox_i, bbox_j):
## Computes minimal box that contains both input boxes
sbox = []
sbox.append(min(bbox_i[0], bbox_j[0]))
sbox.append(min(bbox_i[1], bbox_j[1]))
sbox.append(max(bbox_i[2], bbox_j[2]))
sbox.append(max(bbox_i[3], bbox_j[3]))
return sbox
def surrounding_list(bbox_list):
## Computes minimal box that contains all boxes in the input list
## The list should be non-empty, but just in case it's not:
if len(bbox_list) == 0:
sbox = [0, 0, 0, 0]
else:
sbox = []
sbox.append(min([bbox[0] for bbox in bbox_list]))
sbox.append(min([bbox[1] for bbox in bbox_list]))
sbox.append(max([bbox[2] for bbox in bbox_list]))
sbox.append(max([bbox[3] for bbox in bbox_list]))
return sbox
def vertical_overlap(bboxA, bboxB):
## bbox[1] is the lower bound, bbox[3] the upper bound (larger number)
if bboxB[3] < bboxA[1]: ## B below A
return False
elif bboxA[3] < bboxB[1]: ## A below B
return False
else:
return True
def vertical_overlap_fraction(bboxA, bboxB):
## Returns the vertical overlap as fraction of the lower bbox height.
## bbox[1] is the lower bound, bbox[3] the upper bound (larger number)
## Height 0 is permitted in the input.
heightA = bboxA[3] - bboxA[1]
heightB = bboxB[3] - bboxB[1]
min_height = min(heightA, heightB)
if bboxA[3] >= bboxB[3]: ## A starts higher or equal
if (
bboxA[1] <= bboxB[1]
): ## B is completely in A; this can include height of B = 0:
fraction = 1
else:
overlap = max(bboxB[3] - bboxA[1], 0)
fraction = overlap / max(min_height, 0.001)
else:
if (
bboxB[1] <= bboxA[1]
): ## A is completely in B; this can include height of A = 0:
fraction = 1
else:
overlap = max(bboxA[3] - bboxB[1], 0)
fraction = overlap / max(min_height, 0.001)
return fraction
## -------------------------------
## Cluster-and-cell relations
def compute_enclosed_cells(
cluster_bbox, raw_cells, min_cell_intersection_with_cluster=0.2
):
cells_in_cluster = []
cells_in_cluster_int = []
for ix, cell in enumerate(raw_cells):
cell_bbox = cell["bbox"]
intersection = compute_intersection(cell_bbox, cluster_bbox)
frac_area = area(cell_bbox) * min_cell_intersection_with_cluster
if (
intersection > frac_area and frac_area > 0
): # intersect > certain fraction of cell
cells_in_cluster.append(ix)
cells_in_cluster_int.append(intersection)
elif contains(
cluster_bbox,
[cell_bbox[0] + 3, cell_bbox[1] + 3, cell_bbox[2] - 3, cell_bbox[3] - 3],
):
cells_in_cluster.append(ix)
return cells_in_cluster, cells_in_cluster_int
def find_clusters_around_cells(cell_count, clusters):
## Per raw cell, find to which clusters it belongs.
## Return list of these indices in the raw-cell order.
clusters_around_cells = [[] for _ in range(cell_count)]
for cl_ix, cluster in enumerate(clusters):
for ix in cluster["cell_ids"]:
clusters_around_cells[ix].append(cl_ix)
return clusters_around_cells
def find_cell_index(raw_ix, cell_array):
## "raw_ix" is a rawcell_id.
## "cell_array" has the structure of an (annotation) cells array.
## Returns index of cell in cell_array that has this rawcell_id.
for ix, cell in enumerate(cell_array):
if cell["rawcell_id"] == raw_ix:
return ix
def find_cell_indices(cluster, cell_array):
## "cluster" must have the structure as in a clusters array in a prediction,
## "cell_array" that of a cells array.
## Returns list of indices of cells in cell_array that have the rawcell_ids as in the cluster,
## in the order of the rawcell_ids.
result = []
for raw_ix in sorted(cluster["cell_ids"]):
## Find the cell with this rawcell_id (if any)
for ix, cell in enumerate(cell_array):
if cell["rawcell_id"] == raw_ix:
result.append(ix)
return result
def find_first_cell_index(cluster, cell_array):
## "cluster" must be a dict with key "cell_ids"; it can also be a line.
## "cell_array" has the structure of a cells array in an annotation.
## Returns index of cell in cell_array that has the lowest rawcell_id from the cluster.
result = [] ## We keep it a list as it can be empty (picture without text cells)
if len(cluster["cell_ids"]) == 0:
return result
raw_ix = min(cluster["cell_ids"])
## Find the cell with this rawcell_id (if any)
for ix, cell in enumerate(cell_array):
if cell["rawcell_id"] == raw_ix:
result.append(ix)
break ## One is enough; should be only one anyway.
if result == []:
logger.debug(
" Warning: Raw cell " + str(raw_ix) + " not found in annotation cells"
)
return result
## -------------------------------
## Cluster labels and text
def relabel_cluster(cluster, cl_ix, new_label, target_pred):
## "cluster" must have the structure as in a clusters array in a prediction,
## "cl_ix" is its index in target_pred,
## "new_label" is the intended new label,
## "target_pred" is the entire current target prediction.
## Sets label on the cluster itself, and on the cells in the target_pred.
## Returns new_label so that also the cl_label variable in the main code is easily set.
target_pred["clusters"][cl_ix]["type"] = new_label
cluster_target_cells = find_cell_indices(cluster, target_pred["cells"])
for ix in cluster_target_cells:
target_pred["cells"][ix]["label"] = new_label
return new_label
def find_cluster_text(cluster, raw_cells):
## "cluster" must be a dict with "cell_ids"; it can also be a line.
## "raw_cells" must have the format of item["raw"]["cells"]
## Returns the text of the cluster, with blanks between the cell contents
## (which seem to be words or phrases without starting or trailing blanks).
## Note that in formulas, this may give a lot more blanks than originally
cluster_text = ""
for raw_ix in sorted(cluster["cell_ids"]):
cluster_text = cluster_text + raw_cells[raw_ix]["text"] + " "
return cluster_text.rstrip()
def find_cluster_text_without_blanks(cluster, raw_cells):
## "cluster" must be a dict with "cell_ids"; it can also be a line.
## "raw_cells" must have the format of item["raw"]["cells"]
## Returns the text of the cluster, without blanks between the cell contents
## Interesting in formula analysis.
cluster_text = ""
for raw_ix in sorted(cluster["cell_ids"]):
cluster_text = cluster_text + raw_cells[raw_ix]["text"]
return cluster_text.rstrip()
## -------------------------------
## Clusters and lines
## (Most line-oriented functions are only needed in TextAnalysisGivenClusters,
## but this one also in FormulaAnalysis)
def build_cluster_from_lines(lines, label, id):
    ## Lines must be a non-empty list of dicts (lines) with elements "cell_ids" and "bbox"
    ## (and "confidence" for every line after the first).
    ## (There is no condition that they are really geometrically lines)
    ## A cluster in standard format is returned with given label and id
    local_lines = copy.deepcopy(
        lines
    )  ## without this, it changes "lines" also outside this function
    first_line = local_lines.pop(0)
    cluster = {
        "id": id,
        "type": label,
        "cell_ids": first_line["cell_ids"],
        "bbox": first_line["bbox"],
        "confidence": 0,
        "created_by": "merged_cells",
    }
    confidence = 0
    counter = 0
    for line in local_lines:
        new_cell_ids = cluster["cell_ids"] + line["cell_ids"]
        cluster["cell_ids"] = new_cell_ids
        cluster["bbox"] = surrounding(cluster["bbox"], line["bbox"])
        counter += 1
        confidence += line["confidence"]
    ## The confidence is averaged over the lines after the first one.
    ## With a single input line there is nothing to average, so it stays 0
    ## (this guard avoids a division by zero).
    if counter > 0:
        confidence = confidence / counter
    cluster["confidence"] = confidence
    return cluster
## -------------------------------
## Reading order
def produce_reading_order(clusters, cluster_sort_type, cell_sort_type, sort_ids):
    ## In:
    ##   Clusters: list as in predictions.
    ##   cluster_sort_type: string, currently only "raw_cell_ids".
    ##   cell_sort_type: string, currently only "raw_cell_ids".
    ##   sort_ids: Boolean, whether the cluster ids should be adapted to their new position
    ## Out: Another clusters list, sorted according to the type.
    ##   With an unknown cluster_sort_type, the clusters keep their given order.
    logger.debug("---- Start cluster sorting ------")
    if cell_sort_type == "raw_cell_ids":
        for cl in clusters:
            sorted_cell_ids = sorted(cl["cell_ids"])
            cl["cell_ids"] = sorted_cell_ids
    else:
        logger.debug(
            "Unknown cell_sort_type `"
            + cell_sort_type
            + "`, no cell sorting will happen."
        )
    if cluster_sort_type == "raw_cell_ids":
        clusters_with_cells = [cl for cl in clusters if cl["cell_ids"] != []]
        clusters_without_cells = [cl for cl in clusters if cl["cell_ids"] == []]
        logger.debug(
            "Clusters with cells: " + str([cl["id"] for cl in clusters_with_cells])
        )
        logger.debug(
            " Their first cell ids: "
            + str([cl["cell_ids"][0] for cl in clusters_with_cells])
        )
        logger.debug(
            "Clusters without cells: "
            + str([cl["id"] for cl in clusters_without_cells])
        )
        clusters_with_cells_sorted = sorted(
            clusters_with_cells, key=lambda cluster: cluster["cell_ids"][0]
        )
        logger.debug(
            " First cell ids after sorting: "
            + str([cl["cell_ids"][0] for cl in clusters_with_cells_sorted])
        )
        sorted_clusters = clusters_with_cells_sorted + clusters_without_cells
    else:
        logger.debug(
            "Unknown cluster_sort_type: `"
            + cluster_sort_type
            + "`, no cluster sorting will happen."
        )
        ## Keep the given order; without this, sorted_clusters would be undefined below.
        sorted_clusters = clusters
    if sort_ids:
        for i, cl in enumerate(sorted_clusters):
            cl["id"] = i
    return sorted_clusters
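## Illustrative sketch (hypothetical helper and data, not used by the pipeline):
## the "raw_cell_ids" ordering in produce_reading_order places each cluster by
## its smallest raw cell id and appends clusters without cells at the end.
def _example_reading_order(clusters):
    with_cells = [cl for cl in clusters if cl["cell_ids"]]
    without_cells = [cl for cl in clusters if not cl["cell_ids"]]
    ## min() of the cell ids equals the first element after sorting them.
    return sorted(with_cells, key=lambda cl: min(cl["cell_ids"])) + without_cells


def _example_reading_order_ids():
    clusters = [
        {"id": 0, "cell_ids": [5, 4]},
        {"id": 1, "cell_ids": []},
        {"id": 2, "cell_ids": [1, 9]},
    ]
    return [cl["id"] for cl in _example_reading_order(clusters)]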
## -------------------------------
## Line Splitting
def sort_cells_horizontal(line_cell_ids, raw_cells):
    ## "line_cell_ids" should be a non-empty list of (raw) cell_ids
    ## "raw_cells" has the structure of item["raw"]["cells"].
    ## Sorts the cells in the line by x0 (left start).
    new_line_cell_ids = sorted(
        line_cell_ids, key=lambda cell_id: raw_cells[cell_id]["bbox"][0]
    )
    return new_line_cell_ids
def adapt_bboxes(raw_cells, clusters, orphan_cell_indices):
    new_clusters = []
    for ix, cluster in enumerate(clusters):
        new_cluster = copy.deepcopy(cluster)
        logger.debug(
            "Treating cluster " + str(ix) + ", type " + str(new_cluster["type"])
        )
        logger.debug(" with cells: " + str(new_cluster["cell_ids"]))
        if len(cluster["cell_ids"]) == 0 and cluster["type"] != DocItemLabel.PICTURE:
            logger.debug(" Empty non-picture, removed")
            continue  ## Skip this former cluster, now without cells.
        new_bbox = adapt_bbox(raw_cells, new_cluster, orphan_cell_indices)
        new_cluster["bbox"] = new_bbox
        new_clusters.append(new_cluster)
    return new_clusters
def adapt_bbox(raw_cells, cluster, orphan_cell_indices):
    if cluster["type"] not in [DocItemLabel.TABLE, DocItemLabel.PICTURE]:
        ## A text-like cluster. The bbox only needs to be around the text cells:
        logger.debug(" Initial bbox: " + str(cluster["bbox"]))
        new_bbox = surrounding_list(
            [raw_cells[cid]["bbox"] for cid in cluster["cell_ids"]]
        )
        logger.debug(" New bounding box:" + str(new_bbox))
    ## "elif" keeps text-like clusters out of the table branch below,
    ## which would otherwise overwrite the bbox computed above.
    elif cluster["type"] == DocItemLabel.PICTURE:
        ## We only make the bbox completely comprise included text cells:
        logger.debug(" Picture")
        if len(cluster["cell_ids"]) != 0:
            min_bbox = surrounding_list(
                [raw_cells[cid]["bbox"] for cid in cluster["cell_ids"]]
            )
            logger.debug(" Minimum bbox: " + str(min_bbox))
            logger.debug(" Initial bbox: " + str(cluster["bbox"]))
            new_bbox = surrounding(min_bbox, cluster["bbox"])
            logger.debug(" New bbox (initial and text cells): " + str(new_bbox))
        else:
            logger.debug(" without text cells, no change.")
            new_bbox = cluster["bbox"]
    else:  ## A table
        ## At least we have to keep the included text cells, and we make the bbox completely comprise them
        min_bbox = surrounding_list(
            [raw_cells[cid]["bbox"] for cid in cluster["cell_ids"]]
        )
        logger.debug(" Minimum bbox: " + str(min_bbox))
        logger.debug(" Initial bbox: " + str(cluster["bbox"]))
        new_bbox = surrounding(min_bbox, cluster["bbox"])
        logger.debug(" Possibly increased bbox: " + str(new_bbox))
        ## Now we look which non-belonging cells are covered.
        ## (To decrease dependencies, we don't make use of which cells we actually removed.)
        ## We don't worry about orphan cells, those could still be added to the table.
        enclosed_cells = compute_enclosed_cells(
            new_bbox, raw_cells, min_cell_intersection_with_cluster=0.3
        )[0]
        additional_cells = set(enclosed_cells) - set(cluster["cell_ids"])
        logger.debug(
            " Additional cells enclosed by Table bbox: " + str(additional_cells)
        )
        spurious_cells = additional_cells - set(orphan_cell_indices)
        logger.debug(
            " Spurious cells enclosed by Table bbox (additional minus orphans): "
            + str(spurious_cells)
        )
        if len(spurious_cells) == 0:
            return new_bbox
        ## Else we want to keep as much as possible, e.g., grid lines, but not the spurious cells if we can.
        ## We initialize possible cuts with the current bbox.
        left_cut = new_bbox[0]
        right_cut = new_bbox[2]
        upper_cut = new_bbox[3]
        lower_cut = new_bbox[1]
        for cell_ix in spurious_cells:
            cell = raw_cells[cell_ix]
            # logger.debug(" Spurious cell bbox: " + str(cell["bbox"]))
            is_left = cell["bbox"][2] < min_bbox[0]
            is_right = cell["bbox"][0] > min_bbox[2]
            is_above = cell["bbox"][1] > min_bbox[3]
            is_below = cell["bbox"][3] < min_bbox[1]
            # logger.debug(" Left, right, above, below? " + str([is_left, is_right, is_above, is_below]))
            if is_left:
                if cell["bbox"][2] > left_cut:
                    ## We move the left cut to exclude this cell:
                    left_cut = cell["bbox"][2]
            if is_right:
                if cell["bbox"][0] < right_cut:
                    ## We move the right cut to exclude this cell:
                    right_cut = cell["bbox"][0]
            if is_above:
                if cell["bbox"][1] < upper_cut:
                    ## We move the upper cut to exclude this cell:
                    upper_cut = cell["bbox"][1]
            if is_below:
                if cell["bbox"][3] > lower_cut:
                    ## We move the lower cut to exclude this cell:
                    lower_cut = cell["bbox"][3]
            # logger.debug(" Current bbox: " + str([left_cut, lower_cut, right_cut, upper_cut]))
        new_bbox = [left_cut, lower_cut, right_cut, upper_cut]
    logger.debug(" Final bbox: " + str(new_bbox))
    return new_bbox
def remove_cluster_duplicates_by_conf(cluster_predictions, threshold=0.5):
    ## Removes every cluster that overlaps a higher-confidence cluster with
    ## IoU > threshold, or that lies inside it after being shrunk by 3 units per side.
    DuplicateDeletedClusterIDs = set()
    for cluster_1 in cluster_predictions:
        for cluster_2 in cluster_predictions:
            if cluster_1["id"] == cluster_2["id"]:
                continue
            if cluster_1["confidence"] <= cluster_2["confidence"]:
                continue
            if bb_iou(cluster_1["bbox"], cluster_2["bbox"]) > threshold:
                DuplicateDeletedClusterIDs.add(cluster_2["id"])
            elif contains(
                cluster_1["bbox"],
                [
                    cluster_2["bbox"][0] + 3,
                    cluster_2["bbox"][1] + 3,
                    cluster_2["bbox"][2] - 3,
                    cluster_2["bbox"][3] - 3,
                ],
            ):
                DuplicateDeletedClusterIDs.add(cluster_2["id"])
    ## Filter in place (the caller's list is still updated), instead of
    ## removing elements from the list while iterating over it.
    cluster_predictions[:] = [
        cluster
        for cluster in cluster_predictions
        if cluster["id"] not in DuplicateDeletedClusterIDs
    ]
    return cluster_predictions
# Assign orphan cells by a low confidence prediction that is below the assigned confidence
def assign_orphans_with_low_conf_pred(
    cluster_predictions, cluster_predictions_low, raw_cells, orphan_cell_indices
):
    ## Iterate over a copy: assigned cells are removed from orphan_cell_indices,
    ## and removing from a list while iterating over it skips elements.
    for orph_id in list(orphan_cell_indices):
        cluster_chosen = {}
        iou_thresh = 0.05
        confidence = 0.05
        # Loop over all predictions, and find the one with the highest IOU, and confidence
        for cluster in cluster_predictions_low:
            calc_iou = bb_iou(cluster["bbox"], raw_cells[orph_id]["bbox"])
            cluster_area = (cluster["bbox"][3] - cluster["bbox"][1]) * (
                cluster["bbox"][2] - cluster["bbox"][0]
            )
            cell_area = (
                raw_cells[orph_id]["bbox"][3] - raw_cells[orph_id]["bbox"][1]
            ) * (raw_cells[orph_id]["bbox"][2] - raw_cells[orph_id]["bbox"][0])
            if (
                (iou_thresh < calc_iou)
                and (cluster["confidence"] > confidence)
                and (cell_area * 3 > cluster_area)
            ):
                cluster_chosen = cluster
                iou_thresh = calc_iou
                confidence = cluster["confidence"]
        # If a candidate is found, assign to it the PDF cell ids, and tag that it was created by this function for tracking
        if cluster_chosen:
            cluster_chosen["cell_ids"].append(orph_id)
            cluster_chosen["created_by"] = "orph_low_conf"
            cluster_predictions.append(cluster_chosen)
            orphan_cell_indices.remove(orph_id)
    return cluster_predictions, orphan_cell_indices
def remove_ambigous_pdf_cell_by_conf(cluster_predictions, raw_cells, amb_cell_idxs):
    ## Iterate over a copy, because resolved cells are removed from amb_cell_idxs.
    for amb_cell_id in list(amb_cell_idxs):
        highest_conf = 0
        highest_bbox_iou = 0
        cluster_chosen = None
        problematic_clusters = []
        # Find clusters in question
        for cluster in cluster_predictions:
            if amb_cell_id in cluster["cell_ids"]:
                problematic_clusters.append(cluster)
                # Prefer the cluster that has both a higher confidence and a higher IoU with the cell
                bbox_iou_val = bb_iou(cluster["bbox"], raw_cells[amb_cell_id]["bbox"])
                if (
                    cluster["confidence"] > highest_conf
                    and bbox_iou_val > highest_bbox_iou
                ):
                    cluster_chosen = cluster
                    highest_conf = cluster["confidence"]
                    highest_bbox_iou = bbox_iou_val
        # Now remove the cell id from every cluster except the chosen one.
        # (If no cluster qualified, the cell is removed from all of them.)
        for cluster in problematic_clusters:
            if cluster is not cluster_chosen:
                cluster["cell_ids"].remove(amb_cell_id)
        amb_cell_idxs.remove(amb_cell_id)
    return cluster_predictions, amb_cell_idxs
def ranges(nums):
    # Find if consecutive numbers exist within pdf cells
    # Used to remove line numbers for review manuscripts
    nums = sorted(set(nums))
    gaps = [[s, e] for s, e in zip(nums, nums[1:]) if s + 1 < e]
    edges = iter(nums[:1] + sum(gaps, []) + nums[-1:])
    return list(zip(edges, edges))
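## Illustrative sketch (hypothetical helper, not used by the pipeline): a more
## explicit formulation of the run detection that ranges() above performs.
## For [1, 2, 3, 7, 9, 10] it yields [(1, 3), (7, 7), (9, 10)], i.e. one
## (start, end) pair per run of consecutive integers, which should match
## what ranges() returns for the same input.
def _example_consecutive_runs(nums):
    runs = []
    for n in sorted(set(nums)):
        if runs and n == runs[-1][1] + 1:
            ## n continues the current run: move its end forward.
            runs[-1] = (runs[-1][0], n)
        else:
            ## n starts a new run.
            runs.append((n, n))
    return runs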
def set_orphan_as_text(
    cluster_predictions, cluster_predictions_low, raw_cells, orphan_cell_indices
):
    max_id = -1
    figures = []
    for cluster in cluster_predictions:
        if cluster["type"] == DocItemLabel.PICTURE:
            figures.append(cluster)
        if cluster["id"] > max_id:
            max_id = cluster["id"]
    max_id += 1
    lines_detector = False
    content_of_orphans = []
    for orph_id in orphan_cell_indices:
        content_of_orphans.append(raw_cells[orph_id]["text"])
    fil_content_of_orphans = []
    for cell_content in content_of_orphans:
        if cell_content.isnumeric():
            try:
                num = int(cell_content)
                fil_content_of_orphans.append(num)
            except ValueError:  # ignore the cell
                pass
    # line_orphans = []
    # If more than 2 orphan cells are numeric, check whether they form runs of
    # consecutive numbers (using the ranges function). Such runs indicate the
    # line numbers printed in review manuscripts.
    if len(fil_content_of_orphans) > 2:
        out_ranges = ranges(fil_content_of_orphans)
        if len(out_ranges) > 1:
            for ranges_ in out_ranges:
                if ranges_[0] != ranges_[1]:
                    # If a run spans more than 75 numbers (half the total line number of a
                    # review manuscript page), decide that there are line numbers on the page
                    # to be ignored.
                    if ranges_[1] - ranges_[0] > 75:
                        lines_detector = True
                        # line_orphans = line_orphans + list(range(ranges_[0], ranges_[1]))
    for orph_id in orphan_cell_indices:
        orph_cell = raw_cells[orph_id]
        if bool(orph_cell["text"] and not orph_cell["text"].isspace()):
            fig_flag = False
            # Do not assign orphan cells if they are inside a figure
            for fig in figures:
                if contains(fig["bbox"], orph_cell["bbox"]):
                    fig_flag = True
            # if fig_flag == False and raw_cells[orph_id]["text"] not in line_orphans:
            if not fig_flag and not lines_detector:
                # get class from low confidence detections if not set as text:
                class_type = DocItemLabel.TEXT
                for cluster in cluster_predictions_low:
                    intersection = compute_intersection(
                        orph_cell["bbox"], cluster["bbox"]
                    )
                    # Note: the label is reset for every candidate, so the last
                    # low-confidence cluster in the list decides the final class.
                    class_type = DocItemLabel.TEXT
                    if (
                        cluster["confidence"] > 0.1
                        and bb_iou(cluster["bbox"], orph_cell["bbox"]) > 0.4
                    ):
                        class_type = cluster["type"]
                    elif contains(
                        cluster["bbox"],
                        [
                            orph_cell["bbox"][0] + 3,
                            orph_cell["bbox"][1] + 3,
                            orph_cell["bbox"][2] - 3,
                            orph_cell["bbox"][3] - 3,
                        ],
                    ):
                        class_type = cluster["type"]
                    elif intersection > area(orph_cell["bbox"]) * 0.2:
                        class_type = cluster["type"]
                new_cluster = {
                    "id": max_id,
                    "bbox": orph_cell["bbox"],
                    "type": class_type,
                    "cell_ids": [orph_id],
                    "confidence": -1,
                    "created_by": "orphan_default",
                }
                max_id += 1
                cluster_predictions.append(new_cluster)
    return cluster_predictions, orphan_cell_indices
def merge_cells(cluster_predictions):
    # Merges clusters created from orphan cells that touch or nearly touch:
    # each bbox is padded by 2 units per side, intersecting pairs become graph edges,
    # and every connected component with more than one node becomes one text cluster.
    G = nx.Graph()
    for cluster in cluster_predictions:
        if cluster["created_by"] == "orphan_default":
            G.add_node(cluster["id"])
    for cluster_1 in cluster_predictions:
        for cluster_2 in cluster_predictions:
            if (
                cluster_1["id"] != cluster_2["id"]
                and cluster_2["created_by"] == "orphan_default"
                and cluster_1["created_by"] == "orphan_default"
            ):
                cl1 = copy.deepcopy(cluster_1["bbox"])
                cl2 = copy.deepcopy(cluster_2["bbox"])
                cl1[0] = cl1[0] - 2
                cl1[1] = cl1[1] - 2
                cl1[2] = cl1[2] + 2
                cl1[3] = cl1[3] + 2
                cl2[0] = cl2[0] - 2
                cl2[1] = cl2[1] - 2
                cl2[2] = cl2[2] + 2
                cl2[3] = cl2[3] + 2
                if is_intersecting(cl1, cl2):
                    G.add_edge(cluster_1["id"], cluster_2["id"])
    component = sorted(map(sorted, nx.k_edge_components(G, k=1)))
    max_id = -1
    for cluster_1 in cluster_predictions:
        if cluster_1["id"] > max_id:
            max_id = cluster_1["id"]
    for nodes in component:
        if len(nodes) > 1:
            max_id += 1
            lines = []
            for node in nodes:
                ## Iterate over a copy, since the matching cluster is removed from the list.
                for cluster in list(cluster_predictions):
                    if cluster["id"] == node:
                        lines.append(cluster)
                        cluster_predictions.remove(cluster)
            new_merged_cluster = build_cluster_from_lines(
                lines, DocItemLabel.TEXT, max_id
            )
            cluster_predictions.append(new_merged_cluster)
    return cluster_predictions
def clean_up_clusters(
    cluster_predictions,
    raw_cells,
    merge_cells=False,
    img_table=False,
    one_cell_table=False,
):
    DuplicateDeletedClusterIDs = set()
    for cluster_1 in cluster_predictions:
        for cluster_2 in cluster_predictions:
            if cluster_1["id"] != cluster_2["id"]:
                # remove any artifacts created by merging clusters
                if merge_cells:
                    if contains(
                        cluster_1["bbox"],
                        [
                            cluster_2["bbox"][0] + 3,
                            cluster_2["bbox"][1] + 3,
                            cluster_2["bbox"][2] - 3,
                            cluster_2["bbox"][3] - 3,
                        ],
                    ):
                        cluster_1["cell_ids"] = (
                            cluster_1["cell_ids"] + cluster_2["cell_ids"]
                        )
                        DuplicateDeletedClusterIDs.add(cluster_2["id"])
                # remove clusters that might appear inside tables, or images (such as pdf cells in graphs)
                elif img_table:
                    # Note: `and` binds tighter than `or`. The parentheses make the
                    # existing evaluation explicit: any cluster_1 is tested against a
                    # table, but only text clusters are tested against a picture.
                    if (
                        cluster_1["type"] == DocItemLabel.TEXT
                        and cluster_2["type"] == DocItemLabel.PICTURE
                    ) or cluster_2["type"] == DocItemLabel.TABLE:
                        if bb_iou(cluster_1["bbox"], cluster_2["bbox"]) > 0.5:
                            DuplicateDeletedClusterIDs.add(cluster_1["id"])
                        elif contains(
                            [
                                cluster_2["bbox"][0] - 3,
                                cluster_2["bbox"][1] - 3,
                                cluster_2["bbox"][2] + 3,
                                cluster_2["bbox"][3] + 3,
                            ],
                            cluster_1["bbox"],
                        ):
                            DuplicateDeletedClusterIDs.add(cluster_1["id"])
                # remove tables that have one pdf cell
                if one_cell_table:
                    if (
                        cluster_1["type"] == DocItemLabel.TABLE
                        and len(cluster_1["cell_ids"]) < 2
                    ):
                        DuplicateDeletedClusterIDs.add(cluster_1["id"])
    # Filter in place (the caller's list is still updated), instead of
    # removing elements from the list while iterating over it.
    cluster_predictions[:] = [
        cluster
        for cluster in cluster_predictions
        if cluster["id"] not in DuplicateDeletedClusterIDs
    ]
    return cluster_predictions
def assigning_cell_ids_to_clusters(clusters, raw_cells, threshold):
    for cluster in clusters:
        cells_in_cluster, _ = compute_enclosed_cells(
            cluster["bbox"], raw_cells, min_cell_intersection_with_cluster=threshold
        )
        cluster["cell_ids"] = cells_in_cluster
        ## These cell_ids are ids of the raw cells.
        ## They are often, but not always, the same as the "id" or the index of the "cells" list in a prediction.
    return clusters
# Creates a map of cell_id->cluster_id
def cell_id_state_map(clusters, cell_count):
    clusters_around_cells = find_clusters_around_cells(cell_count, clusters)
    orphan_cell_indices = [
        ix for ix in range(cell_count) if len(clusters_around_cells[ix]) == 0
    ]  # which cells are assigned no cluster?
    ambiguous_cell_indices = [
        ix for ix in range(cell_count) if len(clusters_around_cells[ix]) > 1
    ]  # which cells are assigned > 1 clusters?
    return clusters_around_cells, orphan_cell_indices, ambiguous_cell_indices
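## Illustrative sketch (hypothetical and self-contained): how orphan and ambiguous
## cells follow from the cell_id -> cluster ids map built above. It ASSUMES that
## find_clusters_around_cells (defined elsewhere) returns, for every cell index,
## the list of ids of the clusters whose "cell_ids" contain that cell. That
## assumed behaviour is re-created here as _example_clusters_around_cells.
def _example_clusters_around_cells(cell_count, clusters):
    around = [[] for _ in range(cell_count)]
    for cluster in clusters:
        for cell_id in cluster["cell_ids"]:
            around[cell_id].append(cluster["id"])
    return around


def _example_cell_states():
    ## Cell 1 belongs to two clusters (ambiguous), cell 3 to none (orphan).
    clusters = [
        {"id": 0, "cell_ids": [0, 1]},
        {"id": 1, "cell_ids": [1, 2]},
    ]
    cell_count = 4
    around = _example_clusters_around_cells(cell_count, clusters)
    orphans = [ix for ix in range(cell_count) if len(around[ix]) == 0]
    ambiguous = [ix for ix in range(cell_count) if len(around[ix]) > 1]
    return around, orphans, ambiguous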