diff --git a/docling/datamodel/base_models.py b/docling/datamodel/base_models.py index 311d6d01..281d6735 100644 --- a/docling/datamodel/base_models.py +++ b/docling/datamodel/base_models.py @@ -117,6 +117,7 @@ class Cluster(BaseModel): bbox: BoundingBox confidence: float = 1.0 cells: List[Cell] = [] + children: List["Cluster"] = [] # Add child cluster support class BasePageElement(BaseModel): diff --git a/docling/datamodel/settings.py b/docling/datamodel/settings.py index b1c47305..46bab75c 100644 --- a/docling/datamodel/settings.py +++ b/docling/datamodel/settings.py @@ -31,6 +31,7 @@ class DebugSettings(BaseModel): visualize_cells: bool = False visualize_ocr: bool = False visualize_layout: bool = False + visualize_raw_layout: bool = False visualize_tables: bool = False profile_pipeline_timings: bool = False diff --git a/docling/models/layout_model.py b/docling/models/layout_model.py index 91897df4..624f8e02 100644 --- a/docling/models/layout_model.py +++ b/docling/models/layout_model.py @@ -7,7 +7,7 @@ from typing import Iterable, List from docling_core.types.doc import CoordOrigin, DocItemLabel from docling_ibm_models.layoutmodel.layout_predictor import LayoutPredictor -from PIL import ImageDraw +from PIL import Image, ImageDraw from docling.datamodel.base_models import ( BoundingBox, @@ -19,7 +19,7 @@ from docling.datamodel.base_models import ( from docling.datamodel.document import ConversionResult from docling.datamodel.settings import settings from docling.models.base_model import BasePageModel -from docling.utils import layout_utils as lu +from docling.utils.layout_postprocessor import LayoutPostprocessor from docling.utils.profiling import TimeRecorder _log = logging.getLogger(__name__) @@ -49,230 +49,101 @@ class LayoutModel(BasePageModel): def __init__(self, artifacts_path: Path): self.layout_predictor = LayoutPredictor(artifacts_path) # TODO temporary - def postprocess(self, clusters_in: List[Cluster], cells: List[Cell], page_height): - MIN_INTERSECTION = 0.2 - CLASS_THRESHOLDS = { - DocItemLabel.CAPTION: 0.35, - DocItemLabel.FOOTNOTE: 0.35, - DocItemLabel.FORMULA: 0.35, - DocItemLabel.LIST_ITEM: 0.35, - DocItemLabel.PAGE_FOOTER: 0.35, - DocItemLabel.PAGE_HEADER: 0.35, - DocItemLabel.PICTURE: 0.2, # low threshold adjust to capture chemical structures for examples. - DocItemLabel.SECTION_HEADER: 0.45, - DocItemLabel.TABLE: 0.35, - DocItemLabel.TEXT: 0.45, - DocItemLabel.TITLE: 0.45, - DocItemLabel.DOCUMENT_INDEX: 0.45, - DocItemLabel.CODE: 0.45, - DocItemLabel.CHECKBOX_SELECTED: 0.45, - DocItemLabel.CHECKBOX_UNSELECTED: 0.45, - DocItemLabel.FORM: 0.45, - DocItemLabel.KEY_VALUE_REGION: 0.45, + def draw_clusters_and_cells_side_by_side( + self, conv_res, page, clusters, mode_prefix: str, show: bool = False + ): + """ + Draws a page image side by side with clusters filtered into two categories: + - Left: Clusters excluding FORM, KEY_VALUE_REGION, and PICTURE. + - Right: Clusters including FORM, KEY_VALUE_REGION, and PICTURE. + """ + label_to_color = { + DocItemLabel.TEXT: (255, 255, 153), # Light Yellow + DocItemLabel.CAPTION: (255, 204, 153), # Light Orange + DocItemLabel.LIST_ITEM: (153, 153, 255), # Light Purple + DocItemLabel.FORMULA: (192, 192, 192), # Gray + DocItemLabel.TABLE: (255, 204, 204), # Light Pink + DocItemLabel.PICTURE: (255, 204, 164), # Light Beige + DocItemLabel.SECTION_HEADER: (255, 153, 153), # Light Red + DocItemLabel.PAGE_HEADER: (204, 255, 204), # Light Green + DocItemLabel.PAGE_FOOTER: ( + 204, + 255, + 204, + ), # Light Green (same as Page-Header) + DocItemLabel.TITLE: (255, 153, 153), # Light Red (same as Section-Header) + DocItemLabel.FOOTNOTE: (200, 200, 255), # Light Blue + DocItemLabel.DOCUMENT_INDEX: (220, 220, 220), # Light Gray + DocItemLabel.CODE: (255, 223, 186), # Peach + DocItemLabel.CHECKBOX_SELECTED: (255, 182, 193), # Pale Green + DocItemLabel.CHECKBOX_UNSELECTED: (255, 182, 193), # Light Pink + DocItemLabel.FORM: (200, 255, 255), # Light Cyan + DocItemLabel.KEY_VALUE_REGION: (183, 65, 14), # Rusty orange } - CLASS_REMAPPINGS = { - DocItemLabel.DOCUMENT_INDEX: DocItemLabel.TABLE, - DocItemLabel.TITLE: DocItemLabel.SECTION_HEADER, + # Filter clusters for left and right images + exclude_labels = { + DocItemLabel.FORM, + DocItemLabel.KEY_VALUE_REGION, + DocItemLabel.PICTURE, } + left_clusters = [c for c in clusters if c.label not in exclude_labels] + right_clusters = [c for c in clusters if c.label in exclude_labels] - _log.debug("================= Start postprocess function ====================") - start_time = time.time() - # Apply Confidence Threshold to cluster predictions - # confidence = self.conf_threshold - clusters_mod = [] + # Create a deep copy of the original image for both sides + left_image = copy.deepcopy(page.image) + right_image = copy.deepcopy(page.image) - for cluster in clusters_in: - confidence = CLASS_THRESHOLDS[cluster.label] - if cluster.confidence >= confidence: - # annotation["created_by"] = "high_conf_pred" + # Function to draw clusters on an image + def draw_clusters(image, clusters): + draw = ImageDraw.Draw(image, "RGBA") + for c in clusters: + cell_color = (0, 0, 0, 40) # Transparent black for cells + for tc in c.cells: + cx0, cy0, cx1, cy1 = tc.bbox.as_tuple() + draw.rectangle( + [(cx0, cy0), (cx1, cy1)], + outline=None, + fill=cell_color, + ) - # Remap class labels where needed. - if cluster.label in CLASS_REMAPPINGS.keys(): - cluster.label = CLASS_REMAPPINGS[cluster.label] - clusters_mod.append(cluster) + x0, y0, x1, y1 = c.bbox.as_tuple() + cluster_fill_color = ( + *list(label_to_color.get(c.label)), # type: ignore + 70, + ) + cluster_outline_color = ( + *list(label_to_color.get(c.label)), # type: ignore + 255, + ) + draw.rectangle( + [(x0, y0), (x1, y1)], + outline=cluster_outline_color, + fill=cluster_fill_color, + ) - # map to dictionary clusters and cells, with bottom left origin - clusters_orig = [ - { - "id": c.id, - "bbox": list( - c.bbox.to_bottom_left_origin(page_height).as_tuple() - ), # TODO - "confidence": c.confidence, - "cell_ids": [], - "type": c.label, - } - for c in clusters_in - ] + # Draw clusters on both images + draw_clusters(left_image, left_clusters) + draw_clusters(right_image, right_clusters) - clusters_out = [ - { - "id": c.id, - "bbox": list( - c.bbox.to_bottom_left_origin(page_height).as_tuple() - ), # TODO - "confidence": c.confidence, - "created_by": "high_conf_pred", - "cell_ids": [], - "type": c.label, - } - for c in clusters_mod - ] + # Combine the images side by side + combined_width = left_image.width * 2 + combined_height = left_image.height + combined_image = Image.new("RGB", (combined_width, combined_height)) + combined_image.paste(left_image, (0, 0)) + combined_image.paste(right_image, (left_image.width, 0)) - del clusters_mod - - raw_cells = [ - { - "id": c.id, - "bbox": list( - c.bbox.to_bottom_left_origin(page_height).as_tuple() - ), # TODO - "text": c.text, - } - for c in cells - ] - cell_count = len(raw_cells) - - _log.debug("---- 0. Treat cluster overlaps ------") - clusters_out = lu.remove_cluster_duplicates_by_conf(clusters_out, 0.8) - - _log.debug( - "---- 1. Initially assign cells to clusters based on minimum intersection ------" - ) - ## Check for cells included in or touched by clusters: - clusters_out = lu.assigning_cell_ids_to_clusters( - clusters_out, raw_cells, MIN_INTERSECTION - ) - - _log.debug("---- 2. Assign Orphans with Low Confidence Detections") - # Creates a map of cell_id->cluster_id - ( - clusters_around_cells, - orphan_cell_indices, - ambiguous_cell_indices, - ) = lu.cell_id_state_map(clusters_out, cell_count) - - # Assign orphan cells with lower confidence predictions - clusters_out, orphan_cell_indices = lu.assign_orphans_with_low_conf_pred( - clusters_out, clusters_orig, raw_cells, orphan_cell_indices - ) - - # Refresh the cell_ids assignment, after creating new clusters using low conf predictions - clusters_out = lu.assigning_cell_ids_to_clusters( - clusters_out, raw_cells, MIN_INTERSECTION - ) - - _log.debug("---- 3. Settle Ambigous Cells") - # Creates an update map after assignment of cell_id->cluster_id - ( - clusters_around_cells, - orphan_cell_indices, - ambiguous_cell_indices, - ) = lu.cell_id_state_map(clusters_out, cell_count) - - # Settle pdf cells that belong to multiple clusters - clusters_out, ambiguous_cell_indices = lu.remove_ambigous_pdf_cell_by_conf( - clusters_out, raw_cells, ambiguous_cell_indices - ) - - _log.debug("---- 4. Set Orphans as Text") - ( - clusters_around_cells, - orphan_cell_indices, - ambiguous_cell_indices, - ) = lu.cell_id_state_map(clusters_out, cell_count) - - clusters_out, orphan_cell_indices = lu.set_orphan_as_text( - clusters_out, clusters_orig, raw_cells, orphan_cell_indices - ) - - _log.debug("---- 5. Merge Cells & and adapt the bounding boxes") - # Merge cells orphan cells - clusters_out = lu.merge_cells(clusters_out) - - # Clean up clusters that remain from merged and unreasonable clusters - clusters_out = lu.clean_up_clusters( - clusters_out, - raw_cells, - merge_cells=True, - img_table=True, - one_cell_table=True, - ) - - new_clusters = lu.adapt_bboxes(raw_cells, clusters_out, orphan_cell_indices) - clusters_out = new_clusters - - ## We first rebuild where every cell is now: - ## Now we write into a prediction cells list, not into the raw cells list. - ## As we don't need previous labels, we best overwrite any old list, because that might - ## have been sorted differently. - ( - clusters_around_cells, - orphan_cell_indices, - ambiguous_cell_indices, - ) = lu.cell_id_state_map(clusters_out, cell_count) - - target_cells = [] - for ix, cell in enumerate(raw_cells): - new_cell = { - "id": ix, - "rawcell_id": ix, - "label": "None", - "bbox": cell["bbox"], - "text": cell["text"], - } - for cluster_index in clusters_around_cells[ - ix - ]: # By previous analysis, this is always 1 cluster. - new_cell["label"] = clusters_out[cluster_index]["type"] - target_cells.append(new_cell) - # _log.debug("New label of cell " + str(ix) + " is " + str(new_cell["label"])) - cells_out = target_cells - - ## ------------------------------- - ## Sort clusters into reasonable reading order, and sort the cells inside each cluster - _log.debug("---- 5. Sort clusters in reading order ------") - sorted_clusters = lu.produce_reading_order( - clusters_out, "raw_cell_ids", "raw_cell_ids", True - ) - clusters_out = sorted_clusters - - # end_time = timer() - _log.debug("---- End of postprocessing function ------") - end_time = time.time() - start_time - _log.debug(f"Finished post processing in seconds={end_time:.3f}") - - cells_out_new = [ - Cell( - id=c["id"], # type: ignore - bbox=BoundingBox.from_tuple( - coord=c["bbox"], origin=CoordOrigin.BOTTOMLEFT # type: ignore - ).to_top_left_origin(page_height), - text=c["text"], # type: ignore + if show: + combined_image.show() + else: + out_path: Path = ( + Path(settings.debug.debug_output_path) + / f"debug_{conv_res.input.file.stem}" ) - for c in cells_out - ] + out_path.mkdir(parents=True, exist_ok=True) - del cells_out - - clusters_out_new = [] - for c in clusters_out: - cluster_cells = [ - ccell for ccell in cells_out_new if ccell.id in c["cell_ids"] # type: ignore - ] - c_new = Cluster( - id=c["id"], # type: ignore - bbox=BoundingBox.from_tuple( - coord=c["bbox"], origin=CoordOrigin.BOTTOMLEFT # type: ignore - ).to_top_left_origin(page_height), - confidence=c["confidence"], # type: ignore - label=DocItemLabel(c["type"]), - cells=cluster_cells, - ) - clusters_out_new.append(c_new) - - return clusters_out_new, cells_out_new + out_file = out_path / f"{mode_prefix}_layout_page_{page.page_no:05}.png" + combined_image.save(str(out_file), format="png") def __call__( self, conv_res: ConversionResult, page_batch: Iterable[Page] @@ -304,44 +175,97 @@ class LayoutModel(BasePageModel): cells=[], ) clusters.append(cluster) - - # Map cells to clusters - # TODO: Remove, postprocess should take care of it anyway. - for cell in page.cells: - for cluster in clusters: - if not cell.bbox.area() > 0: - overlap_frac = 0.0 - else: - overlap_frac = ( - cell.bbox.intersection_area_with(cluster.bbox) - / cell.bbox.area() - ) - - if overlap_frac > 0.5: - cluster.cells.append(cell) + # + # # Map cells to clusters + # # TODO: Remove, postprocess should take care of it anyway. + # for cell in page.cells: + # for cluster in clusters: + # if not cell.bbox.area() > 0: + # overlap_frac = 0.0 + # else: + # overlap_frac = ( + # cell.bbox.intersection_area_with(cluster.bbox) + # / cell.bbox.area() + # ) + # + # if overlap_frac > 0.5: + # cluster.cells.append(cell) # Pre-sort clusters # clusters = self.sort_clusters_by_cell_order(clusters) # DEBUG code: - def draw_clusters_and_cells(show: bool = False): + def draw_clusters_and_cells( + clusters, mode_prefix: str, show: bool = False + ): + label_to_color = { + DocItemLabel.TEXT: (255, 255, 153), # Light Yellow + DocItemLabel.CAPTION: (255, 204, 153), # Light Orange + DocItemLabel.LIST_ITEM: (153, 153, 255), # Light Purple + DocItemLabel.FORMULA: (192, 192, 192), # Gray + DocItemLabel.TABLE: (255, 204, 204), # Light Pink + DocItemLabel.PICTURE: (255, 255, 204), # Light Beige + DocItemLabel.SECTION_HEADER: (255, 153, 153), # Light Red + DocItemLabel.PAGE_HEADER: (204, 255, 204), # Light Green + DocItemLabel.PAGE_FOOTER: ( + 204, + 255, + 204, + ), # Light Green (same as Page-Header) + DocItemLabel.TITLE: ( + 255, + 153, + 153, + ), # Light Red (same as Section-Header) + DocItemLabel.FOOTNOTE: (200, 200, 255), # Light Blue + DocItemLabel.DOCUMENT_INDEX: (220, 220, 220), # Light Gray + DocItemLabel.CODE: (255, 223, 186), # Peach + DocItemLabel.CHECKBOX_SELECTED: ( + 255, + 182, + 193, + ), # Pale Green + DocItemLabel.CHECKBOX_UNSELECTED: ( + 255, + 182, + 193, + ), # Light Pink + DocItemLabel.FORM: (200, 255, 255), # Light Cyan + DocItemLabel.KEY_VALUE_REGION: ( + 183, + 65, + 14, + ), # Rusty orange + } + image = copy.deepcopy(page.image) if image is not None: - draw = ImageDraw.Draw(image) + draw = ImageDraw.Draw(image, "RGBA") for c in clusters: - x0, y0, x1, y1 = c.bbox.as_tuple() - draw.rectangle([(x0, y0), (x1, y1)], outline="green") - - cell_color = ( - random.randint(30, 140), - random.randint(30, 140), - random.randint(30, 140), - ) + cell_color = (0, 0, 0, 40) for tc in c.cells: # [:1]: - x0, y0, x1, y1 = tc.bbox.as_tuple() + cx0, cy0, cx1, cy1 = tc.bbox.as_tuple() draw.rectangle( - [(x0, y0), (x1, y1)], outline=cell_color + [(cx0, cy0), (cx1, cy1)], + outline=None, + fill=cell_color, ) + + x0, y0, x1, y1 = c.bbox.as_tuple() + cluster_fill_color = ( + *list(label_to_color.get(c.label)), # type: ignore + 70, + ) + cluster_outline_color = ( + *list(label_to_color.get(c.label)), # type: ignore + 255, + ) + draw.rectangle( + [(x0, y0), (x1, y1)], + outline=cluster_outline_color, + fill=cluster_fill_color, + ) + if show: image.show() else: @@ -352,19 +276,42 @@ class LayoutModel(BasePageModel): out_path.mkdir(parents=True, exist_ok=True) out_file = ( - out_path / f"layout_page_{page.page_no:05}.png" + out_path + / f"{mode_prefix}_layout_page_{page.page_no:05}.png" ) image.save(str(out_file), format="png") - # draw_clusters_and_cells() + if settings.debug.visualize_raw_layout: + self.draw_clusters_and_cells_side_by_side( + conv_res, page, clusters, mode_prefix="raw" + ) - clusters, page.cells = self.postprocess( - clusters, page.cells, page.size.height + # Apply postprocessing + processed_clusters, processed_cells = LayoutPostprocessor( + page.cells, clusters + ).postprocess() + # processed_clusters, processed_cells = clusters, page.cells + + page.cells = processed_cells + page.predictions.layout = LayoutPrediction( + clusters=processed_clusters ) - page.predictions.layout = LayoutPrediction(clusters=clusters) + # def check_for_overlaps(clusters): + # for i, cluster in enumerate(clusters): + # for j, other_cluster in enumerate(clusters): + # if i >= j or cluster.label in [DocItemLabel.FORM, DocItemLabel.KEY_VALUE_REGION, DocItemLabel.PICTURE]: + # continue + # + # overlap_area = cluster.bbox.intersection_area_with(other_cluster.bbox) + # if overlap_area > 0: + # print(f"Overlap detected between cluster {i} and cluster {j}") + # print(f"Cluster {i} bbox: {cluster.bbox}, Cluster {j} bbox: {other_cluster.bbox}") + # check_for_overlaps(processed_clusters) if settings.debug.visualize_layout: - draw_clusters_and_cells() + self.draw_clusters_and_cells_side_by_side( + conv_res, page, processed_clusters, mode_prefix="postprocessed" + ) yield page diff --git a/docling/utils/layout_postprocessor.py b/docling/utils/layout_postprocessor.py new file mode 100644 index 00000000..76d548c6 --- /dev/null +++ b/docling/utils/layout_postprocessor.py @@ -0,0 +1,444 @@ +import logging +from bisect import bisect_left +from typing import Dict, List, Set, Tuple + +from docling_core.types.doc import DocItemLabel +from rtree import index + +from docling.datamodel.base_models import BoundingBox, Cell, Cluster + +_log = logging.getLogger(__name__) + + +class UnionFind: + """Union-Find (Disjoint-Set) data structure.""" + + def __init__(self, elements): + self.parent = {elem: elem for elem in elements} + self.rank = {elem: 0 for elem in elements} + + def find(self, x): + if self.parent[x] != x: + self.parent[x] = self.find(self.parent[x]) # Path compression + return self.parent[x] + + def union(self, x, y): + root_x = self.find(x) + root_y = self.find(y) + if root_x != root_y: + if self.rank[root_x] > self.rank[root_y]: + self.parent[root_y] = root_x + elif self.rank[root_x] < self.rank[root_y]: + self.parent[root_x] = root_y + else: + self.parent[root_y] = root_x + self.rank[root_x] += 1 + + def groups(self): + """Return groups as {root: [elements]}.""" + from collections import defaultdict + + components = defaultdict(list) + for elem in self.parent: + root = self.find(elem) + components[root].append(elem) + return components + + +class SpatialClusterIndex: + """Helper class to manage spatial indexes for clusters and find overlaps efficiently.""" + + def __init__(self, clusters: List[Cluster]): + # Create spatial index + p = index.Property() + p.dimension = 2 + self.spatial_index = index.Index(properties=p) + + # Initialize interval trees + self.x_intervals = IntervalTree() + self.y_intervals = IntervalTree() + + # Map to store clusters by ID + self.clusters_by_id = {} # type: ignore + + # Populate indexes and maps + for cluster in clusters: + self.add_cluster(cluster) + + def add_cluster(self, cluster: Cluster): + """Add a cluster to all indexes.""" + self.spatial_index.insert(cluster.id, cluster.bbox.as_tuple()) + self.x_intervals.insert(cluster.bbox.l, cluster.bbox.r, cluster.id) + self.y_intervals.insert(cluster.bbox.t, cluster.bbox.b, cluster.id) + self.clusters_by_id[cluster.id] = cluster + + def remove_cluster(self, cluster: Cluster): + """Remove a cluster from all indexes.""" + self.spatial_index.delete(cluster.id, cluster.bbox.as_tuple()) + # Note: IntervalTree doesn't support deletion, but we handle this + # by checking clusters_by_id membership + del self.clusters_by_id[cluster.id] + + def find_candidates(self, bbox: BoundingBox) -> Set[int]: + """Find all potential overlapping cluster IDs using all indexes.""" + bbox_tuple = bbox.as_tuple() + spatial_candidates = set(self.spatial_index.intersection(bbox_tuple)) + x_candidates = self.x_intervals.find_containing( + bbox.l + ) | self.x_intervals.find_containing(bbox.r) + y_candidates = self.y_intervals.find_containing( + bbox.t + ) | self.y_intervals.find_containing(bbox.b) + return spatial_candidates | x_candidates | y_candidates + + def check_overlap( + self, + bbox1: BoundingBox, + bbox2: BoundingBox, + overlap_threshold: float, + containment_threshold: float, + ) -> bool: + """Check if two bboxes overlap sufficiently.""" + area1 = bbox1.area() + area2 = bbox2.area() + if area1 <= 0 or area2 <= 0: + return False + + overlap_area = bbox1.intersection_area_with(bbox2) + if overlap_area <= 0: + return False + + # Check both IoU and containment + iou = overlap_area / (area1 + area2 - overlap_area) + containment_ratio1 = overlap_area / area1 + containment_ratio2 = overlap_area / area2 + + return ( + iou > overlap_threshold + or containment_ratio1 > containment_threshold + or containment_ratio2 > containment_threshold + ) + + +class IntervalTree: + def __init__(self): + self.intervals = [] # List of (min, max, box_id) sorted by min + + def insert(self, min_val: float, max_val: float, box_id: int): + self.intervals.append((min_val, max_val, box_id)) + self.intervals.sort(key=lambda x: x[0]) + + def find_containing(self, point: float) -> Set[int]: + pos = bisect_left(self.intervals, (point, float("-inf"), -1)) + result = set() + + i = pos - 1 + while i >= 0: + min_val, max_val, box_id = self.intervals[i] + if min_val <= point <= max_val: + result.add(box_id) + else: + break + i -= 1 + + return result + + +class LayoutPostprocessor: + """Postprocesses layout predictions by cleaning up clusters and mapping cells.""" + + # Cluster type-specific parameters for overlap resolution + OVERLAP_PARAMS = { + "regular": {"area_threshold": 1.3, "conf_threshold": 0.15}, + "picture": {"area_threshold": 2.0, "conf_threshold": 0.3}, + "wrapper": {"area_threshold": 2.0, "conf_threshold": 0.2}, + } + + WRAPPER_TYPES = {DocItemLabel.FORM, DocItemLabel.KEY_VALUE_REGION} + SPECIAL_TYPES = WRAPPER_TYPES | {DocItemLabel.PICTURE} + + CONFIDENCE_THRESHOLDS = { + DocItemLabel.CAPTION: 0.35, + DocItemLabel.FOOTNOTE: 0.35, + DocItemLabel.FORMULA: 0.35, + DocItemLabel.LIST_ITEM: 0.35, + DocItemLabel.PAGE_FOOTER: 0.35, + DocItemLabel.PAGE_HEADER: 0.35, + DocItemLabel.PICTURE: 0.1, + DocItemLabel.SECTION_HEADER: 0.45, + DocItemLabel.TABLE: 0.35, + DocItemLabel.TEXT: 0.45, + DocItemLabel.TITLE: 0.45, + DocItemLabel.CODE: 0.45, + DocItemLabel.CHECKBOX_SELECTED: 0.45, + DocItemLabel.CHECKBOX_UNSELECTED: 0.45, + DocItemLabel.FORM: 0.45, + DocItemLabel.KEY_VALUE_REGION: 0.45, + DocItemLabel.DOCUMENT_INDEX: 0.45, + } + + LABEL_REMAPPING = { + DocItemLabel.DOCUMENT_INDEX: DocItemLabel.TABLE, + DocItemLabel.TITLE: DocItemLabel.SECTION_HEADER, + } + + def __init__(self, cells: List[Cell], clusters: List[Cluster]): + """Initialize processor with cells and clusters.""" + self.cells = cells + self.regular_clusters = [ + c for c in clusters if c.label not in self.SPECIAL_TYPES + ] + self.special_clusters = [c for c in clusters if c.label in self.SPECIAL_TYPES] + + def postprocess(self) -> Tuple[List[Cluster], List[Cell]]: + """Main processing pipeline.""" + regular_clusters = self._process_regular_clusters() + special_clusters = self._process_special_clusters() + final_clusters = self._sort_clusters(regular_clusters + special_clusters) + return final_clusters, self.cells + + def _process_regular_clusters(self) -> List[Cluster]: + """Process regular clusters with iterative refinement.""" + clusters = [ + c + for c in self.regular_clusters + if c.confidence >= self.CONFIDENCE_THRESHOLDS[c.label] + ] + + # Apply label remapping + for cluster in clusters: + if cluster.label in self.LABEL_REMAPPING: + cluster.label = self.LABEL_REMAPPING[cluster.label] + + # Initial cell assignment + clusters = self._assign_cells_to_clusters(clusters) + + # Handle orphaned cells + unassigned = self._find_unassigned_cells(clusters) + if unassigned: + next_id = max((c.id for c in clusters), default=0) + 1 + orphan_clusters = [ + Cluster( + id=next_id + i, + label=DocItemLabel.TEXT, + bbox=cell.bbox, + confidence=0.0, + cells=[cell], + ) + for i, cell in enumerate(unassigned) + ] + clusters.extend(orphan_clusters) + + # Iterative refinement + prev_count = len(clusters) + 1 + for _ in range(3): # Maximum 3 iterations + if prev_count == len(clusters): + break + prev_count = len(clusters) + clusters = self._adjust_cluster_bboxes(clusters) + clusters = self._remove_overlapping_clusters(clusters, "regular") + + return clusters + + def _process_special_clusters(self) -> List[Cluster]: + """Process special clusters (pictures and wrappers).""" + # Handle pictures + picture_clusters = [ + c + for c in self.special_clusters + if c.label == DocItemLabel.PICTURE + and c.confidence >= self.CONFIDENCE_THRESHOLDS[c.label] + ] + picture_clusters = self._remove_overlapping_clusters( + picture_clusters, "picture" + ) + + # Process wrapper clusters + wrapper_clusters = [] + for wrapper in ( + c for c in self.special_clusters if c.label in self.WRAPPER_TYPES + ): + if wrapper.confidence < self.CONFIDENCE_THRESHOLDS[wrapper.label]: + continue + + # Find contained regular clusters + contained = [] + for cluster in self.regular_clusters: + overlap = cluster.bbox.intersection_area_with(wrapper.bbox) + if overlap > 0: + containment = overlap / cluster.bbox.area() + if containment > 0.8: # High containment threshold for wrappers + contained.append(cluster) + + if contained: + wrapper.children = contained + wrapper.bbox = BoundingBox( + l=min(c.bbox.l for c in contained), + t=min(c.bbox.t for c in contained), + r=max(c.bbox.r for c in contained), + b=max(c.bbox.b for c in contained), + ) + wrapper_clusters.append(wrapper) + + return picture_clusters + self._remove_overlapping_clusters( + wrapper_clusters, "wrapper" + ) + + def _remove_overlapping_clusters( + self, + clusters: List[Cluster], + cluster_type: str, + overlap_threshold: float = 0.8, + containment_threshold: float = 0.8, + ) -> List[Cluster]: + """Remove overlapping clusters using efficient spatial indexing.""" + if not clusters: + return [] + + # Initialize spatial index + spatial_index = SpatialClusterIndex(clusters) + uf = UnionFind(spatial_index.clusters_by_id.keys()) + + # Group overlapping clusters using spatial index + for cluster in clusters: + candidates = spatial_index.find_candidates(cluster.bbox) + candidates.discard(cluster.id) # Remove self + + for other_id in candidates: + if spatial_index.check_overlap( + cluster.bbox, + spatial_index.clusters_by_id[other_id].bbox, + overlap_threshold, + containment_threshold, + ): + uf.union(cluster.id, other_id) + + # Process each group using type-specific parameters + params = self.OVERLAP_PARAMS[cluster_type] + result = [] + + for group in uf.groups().values(): + if len(group) == 1: + result.append(spatial_index.clusters_by_id[group[0]]) + continue + + # Get clusters in group + group_clusters = [spatial_index.clusters_by_id[cid] for cid in group] + + # Find best cluster using area and confidence + best = self._select_best_cluster( + group_clusters, params["area_threshold"], params["conf_threshold"] + ) + + # Merge cells from other clusters into best + for cluster in group_clusters: + if cluster != best: + best.cells.extend(cluster.cells) + + result.append(best) + + return result + + def _select_best_cluster( + self, + clusters: List[Cluster], + area_threshold: float, + conf_threshold: float, + ) -> Cluster: + """Iteratively select best cluster based on area and confidence thresholds.""" + current_best = None + for candidate in clusters: + should_select = True + for other in clusters: + if other == candidate: + continue + + area_ratio = candidate.bbox.area() / other.bbox.area() + conf_diff = other.confidence - candidate.confidence + + if area_ratio <= area_threshold and conf_diff > conf_threshold: + should_select = False + break + + if should_select: + if current_best is None or ( + candidate.bbox.area() > current_best.bbox.area() + and current_best.confidence - candidate.confidence <= conf_threshold + ): + current_best = candidate + + return current_best if current_best else clusters[0] + + def _assign_cells_to_clusters( + self, clusters: List[Cluster], min_overlap: float = 0.2 + ) -> List[Cluster]: + """Assign cells to best overlapping cluster.""" + for cluster in clusters: + cluster.cells = [] + + for cell in self.cells: + if not cell.text.strip(): + continue + + best_overlap = min_overlap + best_cluster = None + + for cluster in clusters: + if cell.bbox.area() <= 0: + continue + + overlap = cell.bbox.intersection_area_with(cluster.bbox) + overlap_ratio = overlap / cell.bbox.area() + + if overlap_ratio > best_overlap: + best_overlap = overlap_ratio + best_cluster = cluster + + if best_cluster is not None: + best_cluster.cells.append(cell) + + return clusters + + def _find_unassigned_cells(self, clusters: List[Cluster]) -> List[Cell]: + """Find cells not assigned to any cluster.""" + assigned = {cell.id for cluster in clusters for cell in cluster.cells} + return [ + cell for cell in self.cells if cell.id not in assigned and cell.text.strip() + ] + + def _adjust_cluster_bboxes(self, clusters: List[Cluster]) -> List[Cluster]: + """Adjust cluster bounding boxes to contain their cells.""" + for cluster in clusters: + if not cluster.cells: + continue + + cells_bbox = BoundingBox( + l=min(cell.bbox.l for cell in cluster.cells), + t=min(cell.bbox.t for cell in cluster.cells), + r=max(cell.bbox.r for cell in cluster.cells), + b=max(cell.bbox.b for cell in cluster.cells), + ) + + if cluster.label == DocItemLabel.TABLE: + # For tables, take union of current bbox and cells bbox + cluster.bbox = BoundingBox( + l=min(cluster.bbox.l, cells_bbox.l), + t=min(cluster.bbox.t, cells_bbox.t), + r=max(cluster.bbox.r, cells_bbox.r), + b=max(cluster.bbox.b, cells_bbox.b), + ) + else: + cluster.bbox = cells_bbox + + return clusters + + def _sort_clusters(self, clusters: List[Cluster]) -> List[Cluster]: + """Sort clusters in reading order (top-to-bottom, left-to-right).""" + + def reading_order_key(cluster: Cluster) -> Tuple[float, float]: + if cluster.cells and cluster.label != DocItemLabel.PICTURE: + first_cell = min(cluster.cells, key=lambda c: (c.bbox.t, c.bbox.l)) + return (first_cell.bbox.t, first_cell.bbox.l) + return (cluster.bbox.t, cluster.bbox.l) + + return sorted(clusters, key=reading_order_key) diff --git a/docling/utils/layout_utils.py b/docling/utils/layout_utils.py deleted file mode 100644 index ceb18047..00000000 --- a/docling/utils/layout_utils.py +++ /dev/null @@ -1,812 +0,0 @@ -import copy -import logging - -import networkx as nx -from docling_core.types.doc import DocItemLabel - -logger = logging.getLogger("layout_utils") - - -## ------------------------------- -## Geometric helper functions -## The coordinates grow left to right, and bottom to top. -## The bounding box list elements 0 to 3 are x_left, y_bottom, x_right, y_top. - - -def area(bbox): - return (bbox[2] - bbox[0]) * (bbox[3] - bbox[1]) - - -def contains(bbox_i, bbox_j): - ## Returns True if bbox_i contains bbox_j, else False - return ( - bbox_i[0] <= bbox_j[0] - and bbox_i[1] <= bbox_j[1] - and bbox_i[2] >= bbox_j[2] - and bbox_i[3] >= bbox_j[3] - ) - - -def is_intersecting(bbox_i, bbox_j): - return not ( - bbox_i[2] < bbox_j[0] - or bbox_i[0] > bbox_j[2] - or bbox_i[3] < bbox_j[1] - or bbox_i[1] > bbox_j[3] - ) - - -def bb_iou(boxA, boxB): - # determine the (x, y)-coordinates of the intersection rectangle - xA = max(boxA[0], boxB[0]) - yA = max(boxA[1], boxB[1]) - xB = min(boxA[2], boxB[2]) - yB = min(boxA[3], boxB[3]) - # compute the area of intersection rectangle - interArea = max(0, xB - xA + 1) * max(0, yB - yA + 1) - # compute the area of both the prediction and ground-truth - # rectangles - boxAArea = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1) - boxBArea = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1) - # compute the intersection over union by taking the intersection - # area and dividing it by the sum of prediction + ground-truth - # areas - the interesection area - iou = interArea / float(boxAArea + boxBArea - interArea) - # return the intersection over union value - return iou - - -def compute_intersection(bbox_i, bbox_j): - ## Returns the size of the intersection area of the two boxes - if not is_intersecting(bbox_i, bbox_j): - return 0 - ## Determine the (x, y)-coordinates of the intersection rectangle: - xA = max(bbox_i[0], bbox_j[0]) - yA = max(bbox_i[1], bbox_j[1]) - xB = min(bbox_i[2], bbox_j[2]) - yB = min(bbox_i[3], bbox_j[3]) - ## Compute the area of intersection rectangle: - interArea = (xB - xA) * (yB - yA) - if interArea < 0: - logger.debug("Warning: Negative intersection detected!") - return 0 - return interArea - - -def surrounding(bbox_i, bbox_j): - ## Computes minimal box that contains both input boxes - sbox = [] - sbox.append(min(bbox_i[0], bbox_j[0])) - sbox.append(min(bbox_i[1], bbox_j[1])) - sbox.append(max(bbox_i[2], bbox_j[2])) - sbox.append(max(bbox_i[3], bbox_j[3])) - return sbox - - -def surrounding_list(bbox_list): - ## Computes minimal box that contains all boxes in the input list - ## The list should be non-empty, but just in case it's not: - if len(bbox_list) == 0: - sbox = [0, 0, 0, 0] - else: - sbox = [] - sbox.append(min([bbox[0] for bbox in bbox_list])) - sbox.append(min([bbox[1] for bbox in bbox_list])) - sbox.append(max([bbox[2] for bbox in bbox_list])) - sbox.append(max([bbox[3] for bbox in bbox_list])) - return sbox - - -def vertical_overlap(bboxA, bboxB): - ## bbox[1] is the lower bound, bbox[3] the upper bound (larger number) - if bboxB[3] < bboxA[1]: ## B below A - return False - elif bboxA[3] < bboxB[1]: ## A below B - return False - else: - return True - - -def vertical_overlap_fraction(bboxA, bboxB): - ## Returns the vertical overlap as fraction of the lower bbox height. - ## bbox[1] is the lower bound, bbox[3] the upper bound (larger number) - ## Height 0 is permitted in the input. - heightA = bboxA[3] - bboxA[1] - heightB = bboxB[3] - bboxB[1] - min_height = min(heightA, heightB) - if bboxA[3] >= bboxB[3]: ## A starts higher or equal - if ( - bboxA[1] <= bboxB[1] - ): ## B is completely in A; this can include height of B = 0: - fraction = 1 - else: - overlap = max(bboxB[3] - bboxA[1], 0) - fraction = overlap / max(min_height, 0.001) - else: - if ( - bboxB[1] <= bboxA[1] - ): ## A is completely in B; this can include height of A = 0: - fraction = 1 - else: - overlap = max(bboxA[3] - bboxB[1], 0) - fraction = overlap / max(min_height, 0.001) - return fraction - - -## ------------------------------- -## Cluster-and-cell relations - - -def compute_enclosed_cells( - cluster_bbox, raw_cells, min_cell_intersection_with_cluster=0.2 -): - cells_in_cluster = [] - cells_in_cluster_int = [] - for ix, cell in enumerate(raw_cells): - cell_bbox = cell["bbox"] - intersection = compute_intersection(cell_bbox, cluster_bbox) - frac_area = area(cell_bbox) * min_cell_intersection_with_cluster - - if ( - intersection > frac_area and frac_area > 0 - ): # intersect > certain fraction of cell - cells_in_cluster.append(ix) - cells_in_cluster_int.append(intersection) - elif contains( - cluster_bbox, - [cell_bbox[0] + 3, cell_bbox[1] + 3, cell_bbox[2] - 3, cell_bbox[3] - 3], - ): - cells_in_cluster.append(ix) - return cells_in_cluster, cells_in_cluster_int - - -def find_clusters_around_cells(cell_count, clusters): - ## Per raw cell, find to which clusters it belongs. - ## Return list of these indices in the raw-cell order. - clusters_around_cells = [[] for _ in range(cell_count)] - for cl_ix, cluster in enumerate(clusters): - for ix in cluster["cell_ids"]: - clusters_around_cells[ix].append(cl_ix) - return clusters_around_cells - - -def find_cell_index(raw_ix, cell_array): - ## "raw_ix" is a rawcell_id. - ## "cell_array" has the structure of an (annotation) cells array. - ## Returns index of cell in cell_array that has this rawcell_id. - for ix, cell in enumerate(cell_array): - if cell["rawcell_id"] == raw_ix: - return ix - - -def find_cell_indices(cluster, cell_array): - ## "cluster" must have the structure as in a clusters array in a prediction, - ## "cell_array" that of a cells array. - ## Returns list of indices of cells in cell_array that have the rawcell_ids as in the cluster, - ## in the order of the rawcell_ids. - result = [] - for raw_ix in sorted(cluster["cell_ids"]): - ## Find the cell with this rawcell_id (if any) - for ix, cell in enumerate(cell_array): - if cell["rawcell_id"] == raw_ix: - result.append(ix) - return result - - -def find_first_cell_index(cluster, cell_array): - ## "cluster" must be a dict with key "cell_ids"; it can also be a line. - ## "cell_array" has the structure of a cells array in an annotation. - ## Returns index of cell in cell_array that has the lowest rawcell_id from the cluster. - result = [] ## We keep it a list as it can be empty (picture without text cells) - if len(cluster["cell_ids"]) == 0: - return result - raw_ix = min(cluster["cell_ids"]) - ## Find the cell with this rawcell_id (if any) - for ix, cell in enumerate(cell_array): - if cell["rawcell_id"] == raw_ix: - result.append(ix) - break ## One is enough; should be only one anyway. - if result == []: - logger.debug( - " Warning: Raw cell " + str(raw_ix) + " not found in annotation cells" - ) - return result - - -## ------------------------------- -## Cluster labels and text - - -def relabel_cluster(cluster, cl_ix, new_label, target_pred): - ## "cluster" must have the structure as in a clusters array in a prediction, - ## "cl_ix" is its index in target_pred, - ## "new_label" is the intended new label, - ## "target_pred" is the entire current target prediction. - ## Sets label on the cluster itself, and on the cells in the target_pred. - ## Returns new_label so that also the cl_label variable in the main code is easily set. - target_pred["clusters"][cl_ix]["type"] = new_label - cluster_target_cells = find_cell_indices(cluster, target_pred["cells"]) - for ix in cluster_target_cells: - target_pred["cells"][ix]["label"] = new_label - return new_label - - -def find_cluster_text(cluster, raw_cells): - ## "cluster" must be a dict with "cell_ids"; it can also be a line. - ## "raw_cells" must have the format of item["raw"]["cells"] - ## Returns the text of the cluster, with blanks between the cell contents - ## (which seem to be words or phrases without starting or trailing blanks). - ## Note that in formulas, this may give a lot more blanks than originally - cluster_text = "" - for raw_ix in sorted(cluster["cell_ids"]): - cluster_text = cluster_text + raw_cells[raw_ix]["text"] + " " - return cluster_text.rstrip() - - -def find_cluster_text_without_blanks(cluster, raw_cells): - ## "cluster" must be a dict with "cell_ids"; it can also be a line. - ## "raw_cells" must have the format of item["raw"]["cells"] - ## Returns the text of the cluster, without blanks between the cell contents - ## Interesting in formula analysis. - cluster_text = "" - for raw_ix in sorted(cluster["cell_ids"]): - cluster_text = cluster_text + raw_cells[raw_ix]["text"] - return cluster_text.rstrip() - - -## ------------------------------- -## Clusters and lines -## (Most line-oriented functions are only needed in TextAnalysisGivenClusters, -## but this one also in FormulaAnalysis) - - -def build_cluster_from_lines(lines, label, id): - ## Lines must be a non-empty list of dicts (lines) with elements "cell_ids" and "bbox" - ## (There is no condition that they are really geometrically lines) - ## A cluster in standard format is returned with given label and id - local_lines = copy.deepcopy( - lines - ) ## without this, it changes "lines" also outside this function - first_line = local_lines.pop(0) - cluster = { - "id": id, - "type": label, - "cell_ids": first_line["cell_ids"], - "bbox": first_line["bbox"], - "confidence": 0, - "created_by": "merged_cells", - } - confidence = 0 - counter = 0 - for line in local_lines: - new_cell_ids = cluster["cell_ids"] + line["cell_ids"] - cluster["cell_ids"] = new_cell_ids - cluster["bbox"] = surrounding(cluster["bbox"], line["bbox"]) - counter += 1 - confidence += line["confidence"] - confidence = confidence / counter - cluster["confidence"] = confidence - return cluster - - -## ------------------------------- -## Reading order - - -def produce_reading_order(clusters, cluster_sort_type, cell_sort_type, sort_ids): - ## In: - ## Clusters: list as in predictions. - ## cluster_sort_type: string, currently only "raw_cells". - ## cell_sort_type: string, currently only "raw_cells". - ## sort_ids: Boolean, whether the cluster ids should be adapted to their new position - ## Out: Another clusters list, sorted according to the type. - - logger.debug("---- Start cluster sorting ------") - - if cell_sort_type == "raw_cell_ids": - for cl in clusters: - sorted_cell_ids = sorted(cl["cell_ids"]) - cl["cell_ids"] = sorted_cell_ids - else: - logger.debug( - "Unknown cell_sort_type `" - + cell_sort_type - + "`, no cell sorting will happen." - ) - - if cluster_sort_type == "raw_cell_ids": - clusters_with_cells = [cl for cl in clusters if cl["cell_ids"] != []] - clusters_without_cells = [cl for cl in clusters if cl["cell_ids"] == []] - logger.debug( - "Clusters with cells: " + str([cl["id"] for cl in clusters_with_cells]) - ) - logger.debug( - " Their first cell ids: " - + str([cl["cell_ids"][0] for cl in clusters_with_cells]) - ) - logger.debug( - "Clusters without cells: " - + str([cl["id"] for cl in clusters_without_cells]) - ) - clusters_with_cells_sorted = sorted( - clusters_with_cells, key=lambda cluster: cluster["cell_ids"][0] - ) - logger.debug( - " First cell ids after sorting: " - + str([cl["cell_ids"][0] for cl in clusters_with_cells_sorted]) - ) - sorted_clusters = clusters_with_cells_sorted + clusters_without_cells - else: - logger.debug( - "Unknown cluster_sort_type: `" - + cluster_sort_type - + "`, no cluster sorting will happen." - ) - - if sort_ids: - for i, cl in enumerate(sorted_clusters): - cl["id"] = i - return sorted_clusters - - -## ------------------------------- -## Line Splitting - - -def sort_cells_horizontal(line_cell_ids, raw_cells): - ## "line_cells" should be a non-empty list of (raw) cell_ids - ## "raw_cells" has the structure of item["raw"]["cells"]. - ## Sorts the cells in the line by x0 (left start). - new_line_cell_ids = sorted( - line_cell_ids, key=lambda cell_id: raw_cells[cell_id]["bbox"][0] - ) - return new_line_cell_ids - - -def adapt_bboxes(raw_cells, clusters, orphan_cell_indices): - new_clusters = [] - for ix, cluster in enumerate(clusters): - new_cluster = copy.deepcopy(cluster) - logger.debug( - "Treating cluster " + str(ix) + ", type " + str(new_cluster["type"]) - ) - logger.debug(" with cells: " + str(new_cluster["cell_ids"])) - if len(cluster["cell_ids"]) == 0 and cluster["type"] != DocItemLabel.PICTURE: - logger.debug(" Empty non-picture, removed") - continue ## Skip this former cluster, now without cells. - new_bbox = adapt_bbox(raw_cells, new_cluster, orphan_cell_indices) - new_cluster["bbox"] = new_bbox - new_clusters.append(new_cluster) - return new_clusters - - -def adapt_bbox(raw_cells, cluster, orphan_cell_indices): - if not (cluster["type"] in [DocItemLabel.TABLE, DocItemLabel.PICTURE]): - ## A text-like cluster. The bbox only needs to be around the text cells: - logger.debug(" Initial bbox: " + str(cluster["bbox"])) - new_bbox = surrounding_list( - [raw_cells[cid]["bbox"] for cid in cluster["cell_ids"]] - ) - logger.debug(" New bounding box:" + str(new_bbox)) - if cluster["type"] == DocItemLabel.PICTURE: - ## We only make the bbox completely comprise included text cells: - logger.debug(" Picture") - if len(cluster["cell_ids"]) != 0: - min_bbox = surrounding_list( - [raw_cells[cid]["bbox"] for cid in cluster["cell_ids"]] - ) - logger.debug(" Minimum bbox: " + str(min_bbox)) - logger.debug(" Initial bbox: " + str(cluster["bbox"])) - new_bbox = surrounding(min_bbox, cluster["bbox"]) - logger.debug(" New bbox (initial and text cells): " + str(new_bbox)) - else: - logger.debug(" without text cells, no change.") - new_bbox = cluster["bbox"] - else: ## A table - ## At least we have to keep the included text cells, and we make the bbox completely comprise them - min_bbox = surrounding_list( - [raw_cells[cid]["bbox"] for cid in cluster["cell_ids"]] - ) - logger.debug(" Minimum bbox: " + str(min_bbox)) - logger.debug(" Initial bbox: " + str(cluster["bbox"])) - new_bbox = surrounding(min_bbox, cluster["bbox"]) - logger.debug(" Possibly increased bbox: " + str(new_bbox)) - - ## Now we look which non-belonging cells are covered. - ## (To decrease dependencies, we don't make use of which cells we actually removed.) - ## We don't worry about orphan cells, those could still be added to the table. - enclosed_cells = compute_enclosed_cells( - new_bbox, raw_cells, min_cell_intersection_with_cluster=0.3 - )[0] - additional_cells = set(enclosed_cells) - set(cluster["cell_ids"]) - logger.debug( - " Additional cells enclosed by Table bbox: " + str(additional_cells) - ) - spurious_cells = additional_cells - set(orphan_cell_indices) - logger.debug( - " Spurious cells enclosed by Table bbox (additional minus orphans): " - + str(spurious_cells) - ) - if len(spurious_cells) == 0: - return new_bbox - - ## Else we want to keep as much as possible, e.g., grid lines, but not the spurious cells if we can. - ## We initialize possible cuts with the current bbox. - left_cut = new_bbox[0] - right_cut = new_bbox[2] - upper_cut = new_bbox[3] - lower_cut = new_bbox[1] - - for cell_ix in spurious_cells: - cell = raw_cells[cell_ix] - # logger.debug(" Spurious cell bbox: " + str(cell["bbox"])) - is_left = cell["bbox"][2] < min_bbox[0] - is_right = cell["bbox"][0] > min_bbox[2] - is_above = cell["bbox"][1] > min_bbox[3] - is_below = cell["bbox"][3] < min_bbox[1] - # logger.debug(" Left, right, above, below? " + str([is_left, is_right, is_above, is_below])) - - if is_left: - if cell["bbox"][2] > left_cut: - ## We move the left cut to exclude this cell: - left_cut = cell["bbox"][2] - if is_right: - if cell["bbox"][0] < right_cut: - ## We move the right cut to exclude this cell: - right_cut = cell["bbox"][0] - if is_above: - if cell["bbox"][1] < upper_cut: - ## We move the upper cut to exclude this cell: - upper_cut = cell["bbox"][1] - if is_below: - if cell["bbox"][3] > lower_cut: - ## We move the left cut to exclude this cell: - lower_cut = cell["bbox"][3] - # logger.debug(" Current bbox: " + str([left_cut, lower_cut, right_cut, upper_cut])) - - new_bbox = [left_cut, lower_cut, right_cut, upper_cut] - - logger.debug(" Final bbox: " + str(new_bbox)) - return new_bbox - - -def remove_cluster_duplicates_by_conf(cluster_predictions, threshold=0.5): - DuplicateDeletedClusterIDs = [] - for cluster_1 in cluster_predictions: - for cluster_2 in cluster_predictions: - if cluster_1["id"] != cluster_2["id"]: - if_conf = False - if cluster_1["confidence"] > cluster_2["confidence"]: - if_conf = True - if if_conf == True: - if bb_iou(cluster_1["bbox"], cluster_2["bbox"]) > threshold: - DuplicateDeletedClusterIDs.append(cluster_2["id"]) - elif contains( - cluster_1["bbox"], - [ - cluster_2["bbox"][0] + 3, - cluster_2["bbox"][1] + 3, - cluster_2["bbox"][2] - 3, - cluster_2["bbox"][3] - 3, - ], - ): - DuplicateDeletedClusterIDs.append(cluster_2["id"]) - - DuplicateDeletedClusterIDs = list(set(DuplicateDeletedClusterIDs)) - - for cl_id in DuplicateDeletedClusterIDs: - for cluster in cluster_predictions: - if cl_id == cluster["id"]: - cluster_predictions.remove(cluster) - return cluster_predictions - - -# Assign orphan cells by a low confidence prediction that is below the assigned confidence -def assign_orphans_with_low_conf_pred( - cluster_predictions, cluster_predictions_low, raw_cells, orphan_cell_indices -): - for orph_id in orphan_cell_indices: - cluster_chosen = {} - iou_thresh = 0.05 - confidence = 0.05 - - # Loop over all predictions, and find the one with the highest IOU, and confidence - for cluster in cluster_predictions_low: - calc_iou = bb_iou(cluster["bbox"], raw_cells[orph_id]["bbox"]) - cluster_area = (cluster["bbox"][3] - cluster["bbox"][1]) * ( - cluster["bbox"][2] - cluster["bbox"][0] - ) - cell_area = ( - raw_cells[orph_id]["bbox"][3] - raw_cells[orph_id]["bbox"][1] - ) * (raw_cells[orph_id]["bbox"][2] - raw_cells[orph_id]["bbox"][0]) - - if ( - (iou_thresh < calc_iou) - and (cluster["confidence"] > confidence) - and (cell_area * 3 > cluster_area) - ): - cluster_chosen = cluster - iou_thresh = calc_iou - confidence = cluster["confidence"] - # If a candidate is found, assign to it the PDF cell ids, and tag that it was created by this function for tracking - if iou_thresh != 0.05 and confidence != 0.05: - cluster_chosen["cell_ids"].append(orph_id) - cluster_chosen["created_by"] = "orph_low_conf" - cluster_predictions.append(cluster_chosen) - orphan_cell_indices.remove(orph_id) - return cluster_predictions, orphan_cell_indices - - -def remove_ambigous_pdf_cell_by_conf(cluster_predictions, raw_cells, amb_cell_idxs): - for amb_cell_id in amb_cell_idxs: - highest_conf = 0 - highest_bbox_iou = 0 - cluster_chosen = None - problamatic_clusters = [] - - # Find clusters in question - for cluster in cluster_predictions: - - if amb_cell_id in cluster["cell_ids"]: - problamatic_clusters.append(amb_cell_id) - - # If the cell_id is in a cluster of high conf, and highest iou score, and smaller in area - bbox_iou_val = bb_iou(cluster["bbox"], raw_cells[amb_cell_id]["bbox"]) - - if ( - cluster["confidence"] > highest_conf - and bbox_iou_val > highest_bbox_iou - ): - cluster_chosen = cluster - highest_conf = cluster["confidence"] - highest_bbox_iou = bbox_iou_val - if cluster["id"] in problamatic_clusters: - problamatic_clusters.remove(cluster["id"]) - - # now remove the assigning of cell id from lower confidence, and threshold - for cluster in cluster_predictions: - for prob_amb_id in problamatic_clusters: - if prob_amb_id in cluster["cell_ids"]: - cluster["cell_ids"].remove(prob_amb_id) - amb_cell_idxs.remove(amb_cell_id) - - return cluster_predictions, amb_cell_idxs - - -def ranges(nums): - # Find if consecutive numbers exist within pdf cells - # Used to remove line numbers for review manuscripts - nums = sorted(set(nums)) - gaps = [[s, e] for s, e in zip(nums, nums[1:]) if s + 1 < e] - edges = iter(nums[:1] + sum(gaps, []) + nums[-1:]) - return list(zip(edges, edges)) - - -def set_orphan_as_text( - cluster_predictions, cluster_predictions_low, raw_cells, orphan_cell_indices -): - max_id = -1 - figures = [] - for cluster in cluster_predictions: - if cluster["type"] == DocItemLabel.PICTURE: - figures.append(cluster) - - if cluster["id"] > max_id: - max_id = cluster["id"] - max_id += 1 - - lines_detector = False - content_of_orphans = [] - for orph_id in orphan_cell_indices: - orph_cell = raw_cells[orph_id] - content_of_orphans.append(raw_cells[orph_id]["text"]) - - fil_content_of_orphans = [] - for cell_content in content_of_orphans: - if cell_content.isnumeric(): - try: - num = int(cell_content) - fil_content_of_orphans.append(num) - except ValueError: # ignore the cell - pass - - # line_orphans = [] - # Check if there are more than 2 pdf orphan cells, if there are more than 2, - # then check between the orphan cells if they are numeric - # and if they are a consecutive series of numbers (using ranges function) to decide - - if len(fil_content_of_orphans) > 2: - out_ranges = ranges(fil_content_of_orphans) - if len(out_ranges) > 1: - cnt_range = 0 - for ranges_ in out_ranges: - if ranges_[0] != ranges_[1]: - # If there are more than 75 (half the total line number of a review manuscript page) - # decide that there are line numbers on page to be ignored. - if len(list(range(ranges_[0], ranges_[1]))) > 75: - lines_detector = True - # line_orphans = line_orphans + list(range(ranges_[0], ranges_[1])) - - for orph_id in orphan_cell_indices: - orph_cell = raw_cells[orph_id] - if bool(orph_cell["text"] and not orph_cell["text"].isspace()): - fig_flag = False - # Do not assign orphan cells if they are inside a figure - for fig in figures: - if contains(fig["bbox"], orph_cell["bbox"]): - fig_flag = True - - # if fig_flag == False and raw_cells[orph_id]["text"] not in line_orphans: - if fig_flag == False and lines_detector == False: - # get class from low confidence detections if not set as text: - class_type = DocItemLabel.TEXT - - for cluster in cluster_predictions_low: - intersection = compute_intersection( - orph_cell["bbox"], cluster["bbox"] - ) - class_type = DocItemLabel.TEXT - if ( - cluster["confidence"] > 0.1 - and bb_iou(cluster["bbox"], orph_cell["bbox"]) > 0.4 - ): - class_type = cluster["type"] - elif contains( - cluster["bbox"], - [ - orph_cell["bbox"][0] + 3, - orph_cell["bbox"][1] + 3, - orph_cell["bbox"][2] - 3, - orph_cell["bbox"][3] - 3, - ], - ): - class_type = cluster["type"] - elif intersection > area(orph_cell["bbox"]) * 0.2: - class_type = cluster["type"] - - new_cluster = { - "id": max_id, - "bbox": orph_cell["bbox"], - "type": class_type, - "cell_ids": [orph_id], - "confidence": -1, - "created_by": "orphan_default", - } - max_id += 1 - cluster_predictions.append(new_cluster) - return cluster_predictions, orphan_cell_indices - - -def merge_cells(cluster_predictions): - # Using graph component creates clusters if orphan cells are touching or too close. - G = nx.Graph() - for cluster in cluster_predictions: - if cluster["created_by"] == "orphan_default": - G.add_node(cluster["id"]) - - for cluster_1 in cluster_predictions: - for cluster_2 in cluster_predictions: - if ( - cluster_1["id"] != cluster_2["id"] - and cluster_2["created_by"] == "orphan_default" - and cluster_1["created_by"] == "orphan_default" - ): - cl1 = copy.deepcopy(cluster_1["bbox"]) - cl2 = copy.deepcopy(cluster_2["bbox"]) - cl1[0] = cl1[0] - 2 - cl1[1] = cl1[1] - 2 - cl1[2] = cl1[2] + 2 - cl1[3] = cl1[3] + 2 - cl2[0] = cl2[0] - 2 - cl2[1] = cl2[1] - 2 - cl2[2] = cl2[2] + 2 - cl2[3] = cl2[3] + 2 - if is_intersecting(cl1, cl2): - G.add_edge(cluster_1["id"], cluster_2["id"]) - - component = sorted(map(sorted, nx.k_edge_components(G, k=1))) - max_id = -1 - for cluster_1 in cluster_predictions: - if cluster_1["id"] > max_id: - max_id = cluster_1["id"] - - for nodes in component: - if len(nodes) > 1: - max_id += 1 - lines = [] - for node in nodes: - for cluster in cluster_predictions: - if cluster["id"] == node: - lines.append(cluster) - cluster_predictions.remove(cluster) - new_merged_cluster = build_cluster_from_lines( - lines, DocItemLabel.TEXT, max_id - ) - cluster_predictions.append(new_merged_cluster) - return cluster_predictions - - -def clean_up_clusters( - cluster_predictions, - raw_cells, - merge_cells=False, - img_table=False, - one_cell_table=False, -): - DuplicateDeletedClusterIDs = [] - - for cluster_1 in cluster_predictions: - for cluster_2 in cluster_predictions: - if cluster_1["id"] != cluster_2["id"]: - # remove any artifcats created by merging clusters - if merge_cells == True: - if contains( - cluster_1["bbox"], - [ - cluster_2["bbox"][0] + 3, - cluster_2["bbox"][1] + 3, - cluster_2["bbox"][2] - 3, - cluster_2["bbox"][3] - 3, - ], - ): - cluster_1["cell_ids"] = ( - cluster_1["cell_ids"] + cluster_2["cell_ids"] - ) - DuplicateDeletedClusterIDs.append(cluster_2["id"]) - # remove clusters that might appear inside tables, or images (such as pdf cells in graphs) - elif img_table == True: - if ( - cluster_1["type"] == DocItemLabel.TEXT - and cluster_2["type"] == DocItemLabel.PICTURE - or cluster_2["type"] == DocItemLabel.TABLE - ): - if bb_iou(cluster_1["bbox"], cluster_2["bbox"]) > 0.5: - DuplicateDeletedClusterIDs.append(cluster_1["id"]) - elif contains( - [ - cluster_2["bbox"][0] - 3, - cluster_2["bbox"][1] - 3, - cluster_2["bbox"][2] + 3, - cluster_2["bbox"][3] + 3, - ], - cluster_1["bbox"], - ): - DuplicateDeletedClusterIDs.append(cluster_1["id"]) - # remove tables that have one pdf cell - if one_cell_table == True: - if ( - cluster_1["type"] == DocItemLabel.TABLE - and len(cluster_1["cell_ids"]) < 2 - ): - DuplicateDeletedClusterIDs.append(cluster_1["id"]) - - DuplicateDeletedClusterIDs = list(set(DuplicateDeletedClusterIDs)) - - for cl_id in DuplicateDeletedClusterIDs: - for cluster in cluster_predictions: - if cl_id == cluster["id"]: - cluster_predictions.remove(cluster) - return cluster_predictions - - -def assigning_cell_ids_to_clusters(clusters, raw_cells, threshold): - for cluster in clusters: - cells_in_cluster, _ = compute_enclosed_cells( - cluster["bbox"], raw_cells, min_cell_intersection_with_cluster=threshold - ) - cluster["cell_ids"] = cells_in_cluster - ## These cell_ids are ids of the raw cells. - ## They are often, but not always, the same as the "id" or the index of the "cells" list in a prediction. - return clusters - - -# Creates a map of cell_id->cluster_id -def cell_id_state_map(clusters, cell_count): - clusters_around_cells = find_clusters_around_cells(cell_count, clusters) - orphan_cell_indices = [ - ix for ix in range(cell_count) if len(clusters_around_cells[ix]) == 0 - ] # which cells are assigned no cluster? - ambiguous_cell_indices = [ - ix for ix in range(cell_count) if len(clusters_around_cells[ix]) > 1 - ] # which cells are assigned > 1 clusters? - return clusters_around_cells, orphan_cell_indices, ambiguous_cell_indices diff --git a/docs/examples/custom_convert.py b/docs/examples/custom_convert.py index 2d300904..12893e22 100644 --- a/docs/examples/custom_convert.py +++ b/docs/examples/custom_convert.py @@ -74,6 +74,10 @@ def main(): pipeline_options.do_ocr = True pipeline_options.do_table_structure = True pipeline_options.table_structure_options.do_cell_matching = True + pipeline_options.ocr_options.lang = "es" + pipeline_options.accelerator_options = AcceleratorOptions( + num_threads=4, device=Device.AUTO + ) doc_converter = DocumentConverter( format_options={