diff --git a/docling/datamodel/base_models.py b/docling/datamodel/base_models.py index e6b3607f..dd6291ab 100644 --- a/docling/datamodel/base_models.py +++ b/docling/datamodel/base_models.py @@ -122,7 +122,6 @@ class Cluster(BaseModel): bbox: BoundingBox confidence: float = 1.0 cells: List[Cell] = [] - children: List["Cluster"] = [] # Add child cluster support class BasePageElement(BaseModel): @@ -137,12 +136,6 @@ class LayoutPrediction(BaseModel): clusters: List[Cluster] = [] -class ContainerElement( - BasePageElement -): # Used for Form and Key-Value-Regions, only for typing. - pass - - class Table(BasePageElement): otsl_seq: List[str] num_rows: int = 0 @@ -182,7 +175,7 @@ class PagePredictions(BaseModel): equations_prediction: Optional[EquationPrediction] = None -PageElement = Union[TextElement, Table, FigureElement, ContainerElement] +PageElement = Union[TextElement, Table, FigureElement] class AssembledUnit(BaseModel): diff --git a/docling/datamodel/document.py b/docling/datamodel/document.py index bae4ab2a..f8dec5cb 100644 --- a/docling/datamodel/document.py +++ b/docling/datamodel/document.py @@ -78,8 +78,6 @@ layout_label_to_ds_type = { DocItemLabel.PICTURE: "figure", DocItemLabel.TEXT: "paragraph", DocItemLabel.PARAGRAPH: "paragraph", - DocItemLabel.FORM: DocItemLabel.FORM.value, - DocItemLabel.KEY_VALUE_REGION: DocItemLabel.KEY_VALUE_REGION.value, } _EMPTY_DOCLING_DOC = DoclingDocument(name="dummy") diff --git a/docling/models/ds_glm_model.py b/docling/models/ds_glm_model.py index 6f7de07a..be0cf487 100644 --- a/docling/models/ds_glm_model.py +++ b/docling/models/ds_glm_model.py @@ -22,15 +22,9 @@ from docling_core.types.legacy_doc.document import ( from docling_core.types.legacy_doc.document import CCSFileInfoObject as DsFileInfoObject from docling_core.types.legacy_doc.document import ExportedCCSDocument as DsDocument from PIL import ImageDraw -from pydantic import BaseModel, ConfigDict, TypeAdapter +from pydantic import BaseModel, ConfigDict -from docling.datamodel.base_models import ( - Cluster, - ContainerElement, - FigureElement, - Table, - TextElement, -) +from docling.datamodel.base_models import Cluster, FigureElement, Table, TextElement from docling.datamodel.document import ConversionResult, layout_label_to_ds_type from docling.datamodel.settings import settings from docling.utils.glm_utils import to_docling_document @@ -210,31 +204,7 @@ class GlmModel: ) ], obj_type=layout_label_to_ds_type.get(element.label), - payload={ - "children": TypeAdapter(List[Cluster]).dump_python( - element.cluster.children - ) - }, # hack to channel child clusters through GLM - ) - ) - elif isinstance(element, ContainerElement): - main_text.append( - BaseText( - text="", - payload={ - "children": TypeAdapter(List[Cluster]).dump_python( - element.cluster.children - ) - }, # hack to channel child clusters through GLM - obj_type=layout_label_to_ds_type.get(element.label), - name=element.label, - prov=[ - Prov( - bbox=target_bbox, - page=element.page_no + 1, - span=[0, 0], - ) - ], + # data=[[]], ) ) diff --git a/docling/models/layout_model.py b/docling/models/layout_model.py index eb96fd0d..4f7c1bb1 100644 --- a/docling/models/layout_model.py +++ b/docling/models/layout_model.py @@ -7,8 +7,9 @@ from typing import Iterable, List from docling_core.types.doc import CoordOrigin, DocItemLabel from docling_ibm_models.layoutmodel.layout_predictor import LayoutPredictor -from PIL import Image, ImageDraw +from PIL import ImageDraw +import docling.utils.layout_utils as lu from docling.datamodel.base_models import ( BoundingBox, Cell, @@ -21,7 +22,6 @@ from docling.datamodel.pipeline_options import AcceleratorOptions from docling.datamodel.settings import settings from docling.models.base_model import BasePageModel from docling.utils.accelerator_utils import decide_device -from docling.utils.layout_postprocessor import LayoutPostprocessor from docling.utils.profiling import TimeRecorder _log = logging.getLogger(__name__) @@ -47,111 +47,241 @@ class LayoutModel(BasePageModel): TABLE_LABEL = DocItemLabel.TABLE FIGURE_LABEL = DocItemLabel.PICTURE FORMULA_LABEL = DocItemLabel.FORMULA - CONTAINER_LABELS = [DocItemLabel.FORM, DocItemLabel.KEY_VALUE_REGION] def __init__(self, artifacts_path: Path, accelerator_options: AcceleratorOptions): device = decide_device(accelerator_options.device) self.layout_predictor = LayoutPredictor( - artifacts_path, device, accelerator_options.num_threads + artifact_path=artifacts_path, + device=device, + num_threads=accelerator_options.num_threads, + base_threshold=0.6, + blacklist_classes={"Form", "Key-Value Region"}, ) - def draw_clusters_and_cells_side_by_side( - self, conv_res, page, clusters, mode_prefix: str, show: bool = False - ): - """ - Draws a page image side by side with clusters filtered into two categories: - - Left: Clusters excluding FORM, KEY_VALUE_REGION, and PICTURE. - - Right: Clusters including FORM, KEY_VALUE_REGION, and PICTURE. - """ - label_to_color = { - DocItemLabel.TEXT: (255, 255, 153), # Light Yellow - DocItemLabel.CAPTION: (255, 204, 153), # Light Orange - DocItemLabel.LIST_ITEM: (153, 153, 255), # Light Purple - DocItemLabel.FORMULA: (192, 192, 192), # Gray - DocItemLabel.TABLE: (255, 204, 204), # Light Pink - DocItemLabel.PICTURE: (255, 204, 164), # Light Beige - DocItemLabel.SECTION_HEADER: (255, 153, 153), # Light Red - DocItemLabel.PAGE_HEADER: (204, 255, 204), # Light Green - DocItemLabel.PAGE_FOOTER: ( - 204, - 255, - 204, - ), # Light Green (same as Page-Header) - DocItemLabel.TITLE: (255, 153, 153), # Light Red (same as Section-Header) - DocItemLabel.FOOTNOTE: (200, 200, 255), # Light Blue - DocItemLabel.DOCUMENT_INDEX: (220, 220, 220), # Light Gray - DocItemLabel.CODE: (255, 223, 186), # Peach - DocItemLabel.CHECKBOX_SELECTED: (255, 182, 193), # Pale Green - DocItemLabel.CHECKBOX_UNSELECTED: (255, 182, 193), # Light Pink - DocItemLabel.FORM: (200, 255, 255), # Light Cyan - DocItemLabel.KEY_VALUE_REGION: (183, 65, 14), # Rusty orange + def postprocess(self, clusters_in: List[Cluster], cells: List[Cell], page_height): + MIN_INTERSECTION = 0.2 + CLASS_THRESHOLDS = { + DocItemLabel.CAPTION: 0.35, + DocItemLabel.FOOTNOTE: 0.35, + DocItemLabel.FORMULA: 0.35, + DocItemLabel.LIST_ITEM: 0.35, + DocItemLabel.PAGE_FOOTER: 0.35, + DocItemLabel.PAGE_HEADER: 0.35, + DocItemLabel.PICTURE: 0.2, # low threshold adjust to capture chemical structures for examples. + DocItemLabel.SECTION_HEADER: 0.45, + DocItemLabel.TABLE: 0.35, + DocItemLabel.TEXT: 0.45, + DocItemLabel.TITLE: 0.45, + DocItemLabel.DOCUMENT_INDEX: 0.45, + DocItemLabel.CODE: 0.45, + DocItemLabel.CHECKBOX_SELECTED: 0.45, + DocItemLabel.CHECKBOX_UNSELECTED: 0.45, + DocItemLabel.FORM: 0.45, + DocItemLabel.KEY_VALUE_REGION: 0.45, } - # Filter clusters for left and right images - exclude_labels = { - DocItemLabel.FORM, - DocItemLabel.KEY_VALUE_REGION, - DocItemLabel.PICTURE, + CLASS_REMAPPINGS = { + DocItemLabel.DOCUMENT_INDEX: DocItemLabel.TABLE, + DocItemLabel.TITLE: DocItemLabel.SECTION_HEADER, } - left_clusters = [c for c in clusters if c.label not in exclude_labels] - right_clusters = [c for c in clusters if c.label in exclude_labels] - # Create a deep copy of the original image for both sides - left_image = copy.deepcopy(page.image) - right_image = copy.deepcopy(page.image) + _log.debug("================= Start postprocess function ====================") + start_time = time.time() + # Apply Confidence Threshold to cluster predictions + # confidence = self.conf_threshold + clusters_mod = [] - # Function to draw clusters on an image - def draw_clusters(image, clusters): - draw = ImageDraw.Draw(image, "RGBA") - for c_tl in clusters: - all_clusters = [c_tl, *c_tl.children] - for c in all_clusters: - cell_color = (0, 0, 0, 40) # Transparent black for cells - for tc in c.cells: - cx0, cy0, cx1, cy1 = tc.bbox.as_tuple() - draw.rectangle( - [(cx0, cy0), (cx1, cy1)], - outline=None, - fill=cell_color, - ) + for cluster in clusters_in: + confidence = CLASS_THRESHOLDS[cluster.label] + if cluster.confidence >= confidence: + # annotation["created_by"] = "high_conf_pred" - x0, y0, x1, y1 = c.bbox.as_tuple() - cluster_fill_color = ( - *list(label_to_color.get(c.label)), # type: ignore - 70, - ) - cluster_outline_color = ( - *list(label_to_color.get(c.label)), # type: ignore - 255, - ) - draw.rectangle( - [(x0, y0), (x1, y1)], - outline=cluster_outline_color, - fill=cluster_fill_color, - ) + # Remap class labels where needed. + if cluster.label in CLASS_REMAPPINGS.keys(): + cluster.label = CLASS_REMAPPINGS[cluster.label] + clusters_mod.append(cluster) - # Draw clusters on both images - draw_clusters(left_image, left_clusters) - draw_clusters(right_image, right_clusters) + # map to dictionary clusters and cells, with bottom left origin + clusters_orig = [ + { + "id": c.id, + "bbox": list( + c.bbox.to_bottom_left_origin(page_height).as_tuple() + ), # TODO + "confidence": c.confidence, + "cell_ids": [], + "type": c.label, + } + for c in clusters_in + ] - # Combine the images side by side - combined_width = left_image.width * 2 - combined_height = left_image.height - combined_image = Image.new("RGB", (combined_width, combined_height)) - combined_image.paste(left_image, (0, 0)) - combined_image.paste(right_image, (left_image.width, 0)) + clusters_out = [ + { + "id": c.id, + "bbox": list( + c.bbox.to_bottom_left_origin(page_height).as_tuple() + ), # TODO + "confidence": c.confidence, + "created_by": "high_conf_pred", + "cell_ids": [], + "type": c.label, + } + for c in clusters_mod + ] - if show: - combined_image.show() - else: - out_path: Path = ( - Path(settings.debug.debug_output_path) - / f"debug_{conv_res.input.file.stem}" + del clusters_mod + + raw_cells = [ + { + "id": c.id, + "bbox": list( + c.bbox.to_bottom_left_origin(page_height).as_tuple() + ), # TODO + "text": c.text, + } + for c in cells + ] + cell_count = len(raw_cells) + + _log.debug("---- 0. Treat cluster overlaps ------") + clusters_out = lu.remove_cluster_duplicates_by_conf(clusters_out, 0.8) + + _log.debug( + "---- 1. Initially assign cells to clusters based on minimum intersection ------" + ) + ## Check for cells included in or touched by clusters: + clusters_out = lu.assigning_cell_ids_to_clusters( + clusters_out, raw_cells, MIN_INTERSECTION + ) + + _log.debug("---- 2. Assign Orphans with Low Confidence Detections") + # Creates a map of cell_id->cluster_id + ( + clusters_around_cells, + orphan_cell_indices, + ambiguous_cell_indices, + ) = lu.cell_id_state_map(clusters_out, cell_count) + + # Assign orphan cells with lower confidence predictions + clusters_out, orphan_cell_indices = lu.assign_orphans_with_low_conf_pred( + clusters_out, clusters_orig, raw_cells, orphan_cell_indices + ) + + # Refresh the cell_ids assignment, after creating new clusters using low conf predictions + clusters_out = lu.assigning_cell_ids_to_clusters( + clusters_out, raw_cells, MIN_INTERSECTION + ) + + _log.debug("---- 3. Settle Ambigous Cells") + # Creates an update map after assignment of cell_id->cluster_id + ( + clusters_around_cells, + orphan_cell_indices, + ambiguous_cell_indices, + ) = lu.cell_id_state_map(clusters_out, cell_count) + + # Settle pdf cells that belong to multiple clusters + clusters_out, ambiguous_cell_indices = lu.remove_ambigous_pdf_cell_by_conf( + clusters_out, raw_cells, ambiguous_cell_indices + ) + + _log.debug("---- 4. Set Orphans as Text") + ( + clusters_around_cells, + orphan_cell_indices, + ambiguous_cell_indices, + ) = lu.cell_id_state_map(clusters_out, cell_count) + + clusters_out, orphan_cell_indices = lu.set_orphan_as_text( + clusters_out, clusters_orig, raw_cells, orphan_cell_indices + ) + + _log.debug("---- 5. Merge Cells & and adapt the bounding boxes") + # Merge cells orphan cells + clusters_out = lu.merge_cells(clusters_out) + + # Clean up clusters that remain from merged and unreasonable clusters + clusters_out = lu.clean_up_clusters( + clusters_out, + raw_cells, + merge_cells=True, + img_table=True, + one_cell_table=True, + ) + + new_clusters = lu.adapt_bboxes(raw_cells, clusters_out, orphan_cell_indices) + clusters_out = new_clusters + + ## We first rebuild where every cell is now: + ## Now we write into a prediction cells list, not into the raw cells list. + ## As we don't need previous labels, we best overwrite any old list, because that might + ## have been sorted differently. + ( + clusters_around_cells, + orphan_cell_indices, + ambiguous_cell_indices, + ) = lu.cell_id_state_map(clusters_out, cell_count) + + target_cells = [] + for ix, cell in enumerate(raw_cells): + new_cell = { + "id": ix, + "rawcell_id": ix, + "label": "None", + "bbox": cell["bbox"], + "text": cell["text"], + } + for cluster_index in clusters_around_cells[ + ix + ]: # By previous analysis, this is always 1 cluster. + new_cell["label"] = clusters_out[cluster_index]["type"] + target_cells.append(new_cell) + # _log.debug("New label of cell " + str(ix) + " is " + str(new_cell["label"])) + cells_out = target_cells + + ## ------------------------------- + ## Sort clusters into reasonable reading order, and sort the cells inside each cluster + _log.debug("---- 5. Sort clusters in reading order ------") + sorted_clusters = lu.produce_reading_order( + clusters_out, "raw_cell_ids", "raw_cell_ids", True + ) + clusters_out = sorted_clusters + + # end_time = timer() + _log.debug("---- End of postprocessing function ------") + end_time = time.time() - start_time + _log.debug(f"Finished post processing in seconds={end_time:.3f}") + + cells_out_new = [ + Cell( + id=c["id"], # type: ignore + bbox=BoundingBox.from_tuple( + coord=c["bbox"], origin=CoordOrigin.BOTTOMLEFT # type: ignore + ).to_top_left_origin(page_height), + text=c["text"], # type: ignore ) - out_path.mkdir(parents=True, exist_ok=True) + for c in cells_out + ] - out_file = out_path / f"{mode_prefix}_layout_page_{page.page_no:05}.png" - combined_image.save(str(out_file), format="png") + del cells_out + + clusters_out_new = [] + for c in clusters_out: + cluster_cells = [ + ccell for ccell in cells_out_new if ccell.id in c["cell_ids"] # type: ignore + ] + c_new = Cluster( + id=c["id"], # type: ignore + bbox=BoundingBox.from_tuple( + coord=c["bbox"], origin=CoordOrigin.BOTTOMLEFT # type: ignore + ).to_top_left_origin(page_height), + confidence=c["confidence"], # type: ignore + label=DocItemLabel(c["type"]), + cells=cluster_cells, + ) + clusters_out_new.append(c_new) + + return clusters_out_new, cells_out_new def __call__( self, conv_res: ConversionResult, page_batch: Iterable[Page] @@ -184,78 +314,43 @@ class LayoutModel(BasePageModel): ) clusters.append(cluster) - # DEBUG code: - def draw_clusters_and_cells( - clusters, mode_prefix: str, show: bool = False - ): - label_to_color = { - DocItemLabel.TEXT: (255, 255, 153), # Light Yellow - DocItemLabel.CAPTION: (255, 204, 153), # Light Orange - DocItemLabel.LIST_ITEM: (153, 153, 255), # Light Purple - DocItemLabel.FORMULA: (192, 192, 192), # Gray - DocItemLabel.TABLE: (255, 204, 204), # Light Pink - DocItemLabel.PICTURE: (255, 255, 204), # Light Beige - DocItemLabel.SECTION_HEADER: (255, 153, 153), # Light Red - DocItemLabel.PAGE_HEADER: (204, 255, 204), # Light Green - DocItemLabel.PAGE_FOOTER: ( - 204, - 255, - 204, - ), # Light Green (same as Page-Header) - DocItemLabel.TITLE: ( - 255, - 153, - 153, - ), # Light Red (same as Section-Header) - DocItemLabel.FOOTNOTE: (200, 200, 255), # Light Blue - DocItemLabel.DOCUMENT_INDEX: (220, 220, 220), # Light Gray - DocItemLabel.CODE: (255, 223, 186), # Peach - DocItemLabel.CHECKBOX_SELECTED: ( - 255, - 182, - 193, - ), # Pale Green - DocItemLabel.CHECKBOX_UNSELECTED: ( - 255, - 182, - 193, - ), # Light Pink - DocItemLabel.FORM: (200, 255, 255), # Light Cyan - DocItemLabel.KEY_VALUE_REGION: ( - 183, - 65, - 14, - ), # Rusty orange - } + # Map cells to clusters + # TODO: Remove, postprocess should take care of it anyway. + for cell in page.cells: + for cluster in clusters: + if not cell.bbox.area() > 0: + overlap_frac = 0.0 + else: + overlap_frac = ( + cell.bbox.intersection_area_with(cluster.bbox) + / cell.bbox.area() + ) + if overlap_frac > 0.5: + cluster.cells.append(cell) + + # Pre-sort clusters + # clusters = self.sort_clusters_by_cell_order(clusters) + + # DEBUG code: + def draw_clusters_and_cells(show: bool = False): image = copy.deepcopy(page.image) if image is not None: - draw = ImageDraw.Draw(image, "RGBA") + draw = ImageDraw.Draw(image) for c in clusters: - cell_color = (0, 0, 0, 40) - for tc in c.cells: # [:1]: - cx0, cy0, cx1, cy1 = tc.bbox.as_tuple() - draw.rectangle( - [(cx0, cy0), (cx1, cy1)], - outline=None, - fill=cell_color, - ) - x0, y0, x1, y1 = c.bbox.as_tuple() - cluster_fill_color = ( - *list(label_to_color.get(c.label)), # type: ignore - 70, - ) - cluster_outline_color = ( - *list(label_to_color.get(c.label)), # type: ignore - 255, - ) - draw.rectangle( - [(x0, y0), (x1, y1)], - outline=cluster_outline_color, - fill=cluster_fill_color, - ) + draw.rectangle([(x0, y0), (x1, y1)], outline="green") + cell_color = ( + random.randint(30, 140), + random.randint(30, 140), + random.randint(30, 140), + ) + for tc in c.cells: # [:1]: + x0, y0, x1, y1 = tc.bbox.as_tuple() + draw.rectangle( + [(x0, y0), (x1, y1)], outline=cell_color + ) if show: image.show() else: @@ -266,30 +361,19 @@ class LayoutModel(BasePageModel): out_path.mkdir(parents=True, exist_ok=True) out_file = ( - out_path - / f"{mode_prefix}_layout_page_{page.page_no:05}.png" + out_path / f"layout_page_{page.page_no:05}.png" ) image.save(str(out_file), format="png") - if settings.debug.visualize_raw_layout: - self.draw_clusters_and_cells_side_by_side( - conv_res, page, clusters, mode_prefix="raw" - ) + # draw_clusters_and_cells() - # Apply postprocessing - processed_clusters, processed_cells = LayoutPostprocessor( - page.cells, clusters - ).postprocess() - # processed_clusters, processed_cells = clusters, page.cells - - page.cells = processed_cells - page.predictions.layout = LayoutPrediction( - clusters=processed_clusters + clusters, page.cells = self.postprocess( + clusters, page.cells, page.size.height ) + page.predictions.layout = LayoutPrediction(clusters=clusters) + if settings.debug.visualize_layout: - self.draw_clusters_and_cells_side_by_side( - conv_res, page, processed_clusters, mode_prefix="postprocessed" - ) + draw_clusters_and_cells() yield page diff --git a/docling/models/page_assemble_model.py b/docling/models/page_assemble_model.py index 4c27400f..9b064ead 100644 --- a/docling/models/page_assemble_model.py +++ b/docling/models/page_assemble_model.py @@ -6,7 +6,6 @@ from pydantic import BaseModel from docling.datamodel.base_models import ( AssembledUnit, - ContainerElement, FigureElement, Page, PageElement, @@ -160,15 +159,6 @@ class PageAssembleModel(BasePageModel): ) elements.append(equation) body.append(equation) - elif cluster.label in LayoutModel.CONTAINER_LABELS: - container_el = ContainerElement( - label=cluster.label, - id=cluster.id, - page_no=page.page_no, - cluster=cluster, - ) - elements.append(container_el) - body.append(container_el) page.assembled = AssembledUnit( elements=elements, headers=headers, body=body diff --git a/docling/pipeline/standard_pdf_pipeline.py b/docling/pipeline/standard_pdf_pipeline.py index 68a4dcf2..2f8c1421 100644 --- a/docling/pipeline/standard_pdf_pipeline.py +++ b/docling/pipeline/standard_pdf_pipeline.py @@ -106,7 +106,7 @@ class StandardPdfPipeline(PaginatedPipeline): repo_id="ds4sd/docling-models", force_download=force, local_dir=local_dir, - revision="refs/pr/2", + revision="v2.1.0", ) return Path(download_path) diff --git a/docling/utils/layout_postprocessor.py b/docling/utils/layout_postprocessor.py deleted file mode 100644 index ec6b9e77..00000000 --- a/docling/utils/layout_postprocessor.py +++ /dev/null @@ -1,499 +0,0 @@ -import bisect -import logging -import sys -from collections import defaultdict -from typing import Dict, List, Set, Tuple - -from docling_core.types.doc import DocItemLabel -from rtree import index - -from docling.datamodel.base_models import BoundingBox, Cell, Cluster - -_log = logging.getLogger(__name__) - - -class UnionFind: - """Efficient Union-Find data structure for grouping elements.""" - - def __init__(self, elements): - self.parent = {elem: elem for elem in elements} - self.rank = {elem: 0 for elem in elements} - - def find(self, x): - if self.parent[x] != x: - self.parent[x] = self.find(self.parent[x]) # Path compression - return self.parent[x] - - def union(self, x, y): - root_x, root_y = self.find(x), self.find(y) - if root_x == root_y: - return - - if self.rank[root_x] > self.rank[root_y]: - self.parent[root_y] = root_x - elif self.rank[root_x] < self.rank[root_y]: - self.parent[root_x] = root_y - else: - self.parent[root_y] = root_x - self.rank[root_x] += 1 - - def get_groups(self) -> Dict[int, List[int]]: - """Returns groups as {root: [elements]}.""" - groups = defaultdict(list) - for elem in self.parent: - groups[self.find(elem)].append(elem) - return groups - - -class SpatialClusterIndex: - """Efficient spatial indexing for clusters using R-tree and interval trees.""" - - def __init__(self, clusters: List[Cluster]): - p = index.Property() - p.dimension = 2 - self.spatial_index = index.Index(properties=p) - self.x_intervals = IntervalTree() - self.y_intervals = IntervalTree() - self.clusters_by_id: Dict[int, Cluster] = {} - - for cluster in clusters: - self.add_cluster(cluster) - - def add_cluster(self, cluster: Cluster): - bbox = cluster.bbox - self.spatial_index.insert(cluster.id, bbox.as_tuple()) - self.x_intervals.insert(bbox.l, bbox.r, cluster.id) - self.y_intervals.insert(bbox.t, bbox.b, cluster.id) - self.clusters_by_id[cluster.id] = cluster - - def remove_cluster(self, cluster: Cluster): - self.spatial_index.delete(cluster.id, cluster.bbox.as_tuple()) - del self.clusters_by_id[cluster.id] - - def find_candidates(self, bbox: BoundingBox) -> Set[int]: - """Find potential overlapping cluster IDs using all indexes.""" - spatial = set(self.spatial_index.intersection(bbox.as_tuple())) - x_candidates = self.x_intervals.find_containing( - bbox.l - ) | self.x_intervals.find_containing(bbox.r) - y_candidates = self.y_intervals.find_containing( - bbox.t - ) | self.y_intervals.find_containing(bbox.b) - return spatial | x_candidates | y_candidates - - def check_overlap( - self, - bbox1: BoundingBox, - bbox2: BoundingBox, - overlap_threshold: float, - containment_threshold: float, - ) -> bool: - """Check if two bboxes overlap sufficiently.""" - area1, area2 = bbox1.area(), bbox2.area() - if area1 <= 0 or area2 <= 0: - return False - - overlap_area = bbox1.intersection_area_with(bbox2) - if overlap_area <= 0: - return False - - iou = overlap_area / (area1 + area2 - overlap_area) - containment1 = overlap_area / area1 - containment2 = overlap_area / area2 - - return ( - iou > overlap_threshold - or containment1 > containment_threshold - or containment2 > containment_threshold - ) - - -class IntervalTree: - """Memory-efficient interval tree for 1D overlap queries.""" - - def __init__(self): - self.intervals: List[Tuple[float, float, int]] = ( - [] - ) # (min, max, id) sorted by min - - def insert(self, min_val: float, max_val: float, id: int): - bisect.insort(self.intervals, (min_val, max_val, id), key=lambda x: x[0]) - - def find_containing(self, point: float) -> Set[int]: - """Find all intervals containing the point.""" - pos = bisect.bisect_left(self.intervals, (point, float("-inf"), -1)) - result = set() - - # Check intervals starting before point - for min_val, max_val, id in reversed(self.intervals[:pos]): - if min_val <= point <= max_val: - result.add(id) - else: - break - - # Check intervals starting at/after point - for min_val, max_val, id in self.intervals[pos:]: - if point <= max_val: - if min_val <= point: - result.add(id) - else: - break - - return result - - -class LayoutPostprocessor: - """Postprocesses layout predictions by cleaning up clusters and mapping cells.""" - - # Cluster type-specific parameters for overlap resolution - OVERLAP_PARAMS = { - "regular": {"area_threshold": 1.3, "conf_threshold": 0.05}, - "picture": {"area_threshold": 2.0, "conf_threshold": 0.3}, - "wrapper": {"area_threshold": 2.0, "conf_threshold": 0.2}, - } - - WRAPPER_TYPES = {DocItemLabel.FORM, DocItemLabel.KEY_VALUE_REGION} - SPECIAL_TYPES = WRAPPER_TYPES | {DocItemLabel.PICTURE} - - CONFIDENCE_THRESHOLDS = { - DocItemLabel.CAPTION: 0.35, - DocItemLabel.FOOTNOTE: 0.35, - DocItemLabel.FORMULA: 0.35, - DocItemLabel.LIST_ITEM: 0.35, - DocItemLabel.PAGE_FOOTER: 0.35, - DocItemLabel.PAGE_HEADER: 0.35, - DocItemLabel.PICTURE: 0.1, - DocItemLabel.SECTION_HEADER: 0.45, - DocItemLabel.TABLE: 0.35, - DocItemLabel.TEXT: 0.45, - DocItemLabel.TITLE: 0.45, - DocItemLabel.CODE: 0.45, - DocItemLabel.CHECKBOX_SELECTED: 0.45, - DocItemLabel.CHECKBOX_UNSELECTED: 0.45, - DocItemLabel.FORM: 0.45, - DocItemLabel.KEY_VALUE_REGION: 0.45, - DocItemLabel.DOCUMENT_INDEX: 0.45, - } - - LABEL_REMAPPING = { - DocItemLabel.DOCUMENT_INDEX: DocItemLabel.TABLE, - DocItemLabel.TITLE: DocItemLabel.SECTION_HEADER, - } - - def __init__(self, cells: List[Cell], clusters: List[Cluster]): - """Initialize processor with cells and clusters.""" - """Initialize processor with cells and spatial indices.""" - self.cells = cells - self.regular_clusters = [ - c for c in clusters if c.label not in self.SPECIAL_TYPES - ] - self.special_clusters = [c for c in clusters if c.label in self.SPECIAL_TYPES] - - # Build spatial indices once - self.regular_index = SpatialClusterIndex(self.regular_clusters) - self.picture_index = SpatialClusterIndex( - [c for c in self.special_clusters if c.label == DocItemLabel.PICTURE] - ) - self.wrapper_index = SpatialClusterIndex( - [c for c in self.special_clusters if c.label in self.WRAPPER_TYPES] - ) - - def postprocess(self) -> Tuple[List[Cluster], List[Cell]]: - """Main processing pipeline.""" - self.regular_clusters = self._process_regular_clusters() - self.special_clusters = self._process_special_clusters() - - # Remove regular clusters that are included in wrappers - contained_ids = { - child.id - for wrapper in self.special_clusters - if wrapper.label in self.SPECIAL_TYPES - for child in wrapper.children - } - self.regular_clusters = [ - c for c in self.regular_clusters if c.id not in contained_ids - ] - - # Combine and sort final clusters - final_clusters = self._sort_clusters( - self.regular_clusters + self.special_clusters - ) - return final_clusters, self.cells - - def _process_regular_clusters(self) -> List[Cluster]: - """Process regular clusters with iterative refinement.""" - clusters = [ - c - for c in self.regular_clusters - if c.confidence >= self.CONFIDENCE_THRESHOLDS[c.label] - ] - - # Apply label remapping - for cluster in clusters: - if cluster.label in self.LABEL_REMAPPING: - cluster.label = self.LABEL_REMAPPING[cluster.label] - - # Initial cell assignment - clusters = self._assign_cells_to_clusters(clusters) - - # Remove clusters with no cells - clusters = [cluster for cluster in clusters if cluster.cells] - - # Handle orphaned cells - unassigned = self._find_unassigned_cells(clusters) - if unassigned: - next_id = max((c.id for c in clusters), default=0) + 1 - orphan_clusters = [ - Cluster( - id=next_id + i, - label=DocItemLabel.TEXT, - bbox=cell.bbox, - confidence=0.0, - cells=[cell], - ) - for i, cell in enumerate(unassigned) - ] - clusters.extend(orphan_clusters) - - # Iterative refinement - prev_count = len(clusters) + 1 - for _ in range(3): # Maximum 3 iterations - if prev_count == len(clusters): - break - prev_count = len(clusters) - clusters = self._adjust_cluster_bboxes(clusters) - clusters = self._remove_overlapping_clusters(clusters, "regular") - - return clusters - - def _process_special_clusters(self) -> List[Cluster]: - special_clusters = [ - c - for c in self.special_clusters - if c.confidence >= self.CONFIDENCE_THRESHOLDS[c.label] - ] - - for special in special_clusters: - contained = [] - for cluster in self.regular_clusters: - overlap = cluster.bbox.intersection_area_with(special.bbox) - if overlap > 0: - containment = overlap / cluster.bbox.area() - if containment > 0.8: - contained.append(cluster) - - if contained: - # Sort contained clusters by minimum cell ID - contained.sort( - key=lambda cluster: ( - min(cell.id for cell in cluster.cells) - if cluster.cells - else sys.maxsize - ) - ) - special.children = contained - - # Adjust bbox only for wrapper types - if special.label in self.WRAPPER_TYPES: - special.bbox = BoundingBox( - l=min(c.bbox.l for c in contained), - t=min(c.bbox.t for c in contained), - r=max(c.bbox.r for c in contained), - b=max(c.bbox.b for c in contained), - ) - - picture_clusters = [ - c for c in special_clusters if c.label == DocItemLabel.PICTURE - ] - picture_clusters = self._remove_overlapping_clusters( - picture_clusters, "picture" - ) - - wrapper_clusters = [ - c for c in special_clusters if c.label in self.WRAPPER_TYPES - ] - wrapper_clusters = self._remove_overlapping_clusters( - wrapper_clusters, "wrapper" - ) - - return picture_clusters + wrapper_clusters - - def _remove_overlapping_clusters( - self, - clusters: List[Cluster], - cluster_type: str, - overlap_threshold: float = 0.8, - containment_threshold: float = 0.8, - ) -> List[Cluster]: - if not clusters: - return [] - - spatial_index = ( - self.regular_index - if cluster_type == "regular" - else self.picture_index if cluster_type == "picture" else self.wrapper_index - ) - - # Map of currently valid clusters - valid_clusters = {c.id: c for c in clusters} - uf = UnionFind(valid_clusters.keys()) - params = self.OVERLAP_PARAMS[cluster_type] - - for cluster in clusters: - candidates = spatial_index.find_candidates(cluster.bbox) - candidates &= valid_clusters.keys() # Only keep existing candidates - candidates.discard(cluster.id) - - for other_id in candidates: - if spatial_index.check_overlap( - cluster.bbox, - valid_clusters[other_id].bbox, - overlap_threshold, - containment_threshold, - ): - uf.union(cluster.id, other_id) - - result = [] - for group in uf.get_groups().values(): - if len(group) == 1: - result.append(valid_clusters[group[0]]) - continue - - group_clusters = [valid_clusters[cid] for cid in group] - current_best = None - - for candidate in group_clusters: - should_select = True - for other in group_clusters: - if other == candidate: - continue - - area_ratio = candidate.bbox.area() / other.bbox.area() - conf_diff = other.confidence - candidate.confidence - - if ( - area_ratio <= params["area_threshold"] - and conf_diff > params["conf_threshold"] - ): - should_select = False - break - - if should_select: - if current_best is None or ( - candidate.bbox.area() > current_best.bbox.area() - and current_best.confidence - candidate.confidence - <= params["conf_threshold"] - ): - current_best = candidate - - best = current_best if current_best else group_clusters[0] - for cluster in group_clusters: - if cluster != best: - best.cells.extend(cluster.cells) - result.append(best) - - return result - - def _select_best_cluster( - self, - clusters: List[Cluster], - area_threshold: float, - conf_threshold: float, - ) -> Cluster: - """Iteratively select best cluster based on area and confidence thresholds.""" - current_best = None - for candidate in clusters: - should_select = True - for other in clusters: - if other == candidate: - continue - - area_ratio = candidate.bbox.area() / other.bbox.area() - conf_diff = other.confidence - candidate.confidence - - if area_ratio <= area_threshold and conf_diff > conf_threshold: - should_select = False - break - - if should_select: - if current_best is None or ( - candidate.bbox.area() > current_best.bbox.area() - and current_best.confidence - candidate.confidence <= conf_threshold - ): - current_best = candidate - - return current_best if current_best else clusters[0] - - def _assign_cells_to_clusters( - self, clusters: List[Cluster], min_overlap: float = 0.2 - ) -> List[Cluster]: - """Assign cells to best overlapping cluster.""" - for cluster in clusters: - cluster.cells = [] - - for cell in self.cells: - if not cell.text.strip(): - continue - - best_overlap = min_overlap - best_cluster = None - - for cluster in clusters: - if cell.bbox.area() <= 0: - continue - - overlap = cell.bbox.intersection_area_with(cluster.bbox) - overlap_ratio = overlap / cell.bbox.area() - - if overlap_ratio > best_overlap: - best_overlap = overlap_ratio - best_cluster = cluster - - if best_cluster is not None: - best_cluster.cells.append(cell) - - return clusters - - def _find_unassigned_cells(self, clusters: List[Cluster]) -> List[Cell]: - """Find cells not assigned to any cluster.""" - assigned = {cell.id for cluster in clusters for cell in cluster.cells} - return [ - cell for cell in self.cells if cell.id not in assigned and cell.text.strip() - ] - - def _adjust_cluster_bboxes(self, clusters: List[Cluster]) -> List[Cluster]: - """Adjust cluster bounding boxes to contain their cells.""" - for cluster in clusters: - if not cluster.cells: - continue - - cells_bbox = BoundingBox( - l=min(cell.bbox.l for cell in cluster.cells), - t=min(cell.bbox.t for cell in cluster.cells), - r=max(cell.bbox.r for cell in cluster.cells), - b=max(cell.bbox.b for cell in cluster.cells), - ) - - if cluster.label == DocItemLabel.TABLE: - # For tables, take union of current bbox and cells bbox - cluster.bbox = BoundingBox( - l=min(cluster.bbox.l, cells_bbox.l), - t=min(cluster.bbox.t, cells_bbox.t), - r=max(cluster.bbox.r, cells_bbox.r), - b=max(cluster.bbox.b, cells_bbox.b), - ) - else: - cluster.bbox = cells_bbox - - return clusters - - def _sort_clusters(self, clusters: List[Cluster]) -> List[Cluster]: - """Sort clusters in reading order (top-to-bottom, left-to-right).""" - - def reading_order_key(cluster: Cluster) -> Tuple[float, float]: - if cluster.cells and cluster.label != DocItemLabel.PICTURE: - first_cell = min(cluster.cells, key=lambda c: (c.bbox.t, c.bbox.l)) - return (first_cell.bbox.t, first_cell.bbox.l) - return (cluster.bbox.t, cluster.bbox.l) - - return sorted(clusters, key=reading_order_key) diff --git a/docling/utils/layout_utils.py b/docling/utils/layout_utils.py new file mode 100644 index 00000000..ceb18047 --- /dev/null +++ b/docling/utils/layout_utils.py @@ -0,0 +1,812 @@ +import copy +import logging + +import networkx as nx +from docling_core.types.doc import DocItemLabel + +logger = logging.getLogger("layout_utils") + + +## ------------------------------- +## Geometric helper functions +## The coordinates grow left to right, and bottom to top. +## The bounding box list elements 0 to 3 are x_left, y_bottom, x_right, y_top. + + +def area(bbox): + return (bbox[2] - bbox[0]) * (bbox[3] - bbox[1]) + + +def contains(bbox_i, bbox_j): + ## Returns True if bbox_i contains bbox_j, else False + return ( + bbox_i[0] <= bbox_j[0] + and bbox_i[1] <= bbox_j[1] + and bbox_i[2] >= bbox_j[2] + and bbox_i[3] >= bbox_j[3] + ) + + +def is_intersecting(bbox_i, bbox_j): + return not ( + bbox_i[2] < bbox_j[0] + or bbox_i[0] > bbox_j[2] + or bbox_i[3] < bbox_j[1] + or bbox_i[1] > bbox_j[3] + ) + + +def bb_iou(boxA, boxB): + # determine the (x, y)-coordinates of the intersection rectangle + xA = max(boxA[0], boxB[0]) + yA = max(boxA[1], boxB[1]) + xB = min(boxA[2], boxB[2]) + yB = min(boxA[3], boxB[3]) + # compute the area of intersection rectangle + interArea = max(0, xB - xA + 1) * max(0, yB - yA + 1) + # compute the area of both the prediction and ground-truth + # rectangles + boxAArea = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1) + boxBArea = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1) + # compute the intersection over union by taking the intersection + # area and dividing it by the sum of prediction + ground-truth + # areas - the interesection area + iou = interArea / float(boxAArea + boxBArea - interArea) + # return the intersection over union value + return iou + + +def compute_intersection(bbox_i, bbox_j): + ## Returns the size of the intersection area of the two boxes + if not is_intersecting(bbox_i, bbox_j): + return 0 + ## Determine the (x, y)-coordinates of the intersection rectangle: + xA = max(bbox_i[0], bbox_j[0]) + yA = max(bbox_i[1], bbox_j[1]) + xB = min(bbox_i[2], bbox_j[2]) + yB = min(bbox_i[3], bbox_j[3]) + ## Compute the area of intersection rectangle: + interArea = (xB - xA) * (yB - yA) + if interArea < 0: + logger.debug("Warning: Negative intersection detected!") + return 0 + return interArea + + +def surrounding(bbox_i, bbox_j): + ## Computes minimal box that contains both input boxes + sbox = [] + sbox.append(min(bbox_i[0], bbox_j[0])) + sbox.append(min(bbox_i[1], bbox_j[1])) + sbox.append(max(bbox_i[2], bbox_j[2])) + sbox.append(max(bbox_i[3], bbox_j[3])) + return sbox + + +def surrounding_list(bbox_list): + ## Computes minimal box that contains all boxes in the input list + ## The list should be non-empty, but just in case it's not: + if len(bbox_list) == 0: + sbox = [0, 0, 0, 0] + else: + sbox = [] + sbox.append(min([bbox[0] for bbox in bbox_list])) + sbox.append(min([bbox[1] for bbox in bbox_list])) + sbox.append(max([bbox[2] for bbox in bbox_list])) + sbox.append(max([bbox[3] for bbox in bbox_list])) + return sbox + + +def vertical_overlap(bboxA, bboxB): + ## bbox[1] is the lower bound, bbox[3] the upper bound (larger number) + if bboxB[3] < bboxA[1]: ## B below A + return False + elif bboxA[3] < bboxB[1]: ## A below B + return False + else: + return True + + +def vertical_overlap_fraction(bboxA, bboxB): + ## Returns the vertical overlap as fraction of the lower bbox height. + ## bbox[1] is the lower bound, bbox[3] the upper bound (larger number) + ## Height 0 is permitted in the input. + heightA = bboxA[3] - bboxA[1] + heightB = bboxB[3] - bboxB[1] + min_height = min(heightA, heightB) + if bboxA[3] >= bboxB[3]: ## A starts higher or equal + if ( + bboxA[1] <= bboxB[1] + ): ## B is completely in A; this can include height of B = 0: + fraction = 1 + else: + overlap = max(bboxB[3] - bboxA[1], 0) + fraction = overlap / max(min_height, 0.001) + else: + if ( + bboxB[1] <= bboxA[1] + ): ## A is completely in B; this can include height of A = 0: + fraction = 1 + else: + overlap = max(bboxA[3] - bboxB[1], 0) + fraction = overlap / max(min_height, 0.001) + return fraction + + +## ------------------------------- +## Cluster-and-cell relations + + +def compute_enclosed_cells( + cluster_bbox, raw_cells, min_cell_intersection_with_cluster=0.2 +): + cells_in_cluster = [] + cells_in_cluster_int = [] + for ix, cell in enumerate(raw_cells): + cell_bbox = cell["bbox"] + intersection = compute_intersection(cell_bbox, cluster_bbox) + frac_area = area(cell_bbox) * min_cell_intersection_with_cluster + + if ( + intersection > frac_area and frac_area > 0 + ): # intersect > certain fraction of cell + cells_in_cluster.append(ix) + cells_in_cluster_int.append(intersection) + elif contains( + cluster_bbox, + [cell_bbox[0] + 3, cell_bbox[1] + 3, cell_bbox[2] - 3, cell_bbox[3] - 3], + ): + cells_in_cluster.append(ix) + return cells_in_cluster, cells_in_cluster_int + + +def find_clusters_around_cells(cell_count, clusters): + ## Per raw cell, find to which clusters it belongs. + ## Return list of these indices in the raw-cell order. + clusters_around_cells = [[] for _ in range(cell_count)] + for cl_ix, cluster in enumerate(clusters): + for ix in cluster["cell_ids"]: + clusters_around_cells[ix].append(cl_ix) + return clusters_around_cells + + +def find_cell_index(raw_ix, cell_array): + ## "raw_ix" is a rawcell_id. + ## "cell_array" has the structure of an (annotation) cells array. + ## Returns index of cell in cell_array that has this rawcell_id. + for ix, cell in enumerate(cell_array): + if cell["rawcell_id"] == raw_ix: + return ix + + +def find_cell_indices(cluster, cell_array): + ## "cluster" must have the structure as in a clusters array in a prediction, + ## "cell_array" that of a cells array. + ## Returns list of indices of cells in cell_array that have the rawcell_ids as in the cluster, + ## in the order of the rawcell_ids. + result = [] + for raw_ix in sorted(cluster["cell_ids"]): + ## Find the cell with this rawcell_id (if any) + for ix, cell in enumerate(cell_array): + if cell["rawcell_id"] == raw_ix: + result.append(ix) + return result + + +def find_first_cell_index(cluster, cell_array): + ## "cluster" must be a dict with key "cell_ids"; it can also be a line. + ## "cell_array" has the structure of a cells array in an annotation. + ## Returns index of cell in cell_array that has the lowest rawcell_id from the cluster. + result = [] ## We keep it a list as it can be empty (picture without text cells) + if len(cluster["cell_ids"]) == 0: + return result + raw_ix = min(cluster["cell_ids"]) + ## Find the cell with this rawcell_id (if any) + for ix, cell in enumerate(cell_array): + if cell["rawcell_id"] == raw_ix: + result.append(ix) + break ## One is enough; should be only one anyway. + if result == []: + logger.debug( + " Warning: Raw cell " + str(raw_ix) + " not found in annotation cells" + ) + return result + + +## ------------------------------- +## Cluster labels and text + + +def relabel_cluster(cluster, cl_ix, new_label, target_pred): + ## "cluster" must have the structure as in a clusters array in a prediction, + ## "cl_ix" is its index in target_pred, + ## "new_label" is the intended new label, + ## "target_pred" is the entire current target prediction. + ## Sets label on the cluster itself, and on the cells in the target_pred. + ## Returns new_label so that also the cl_label variable in the main code is easily set. + target_pred["clusters"][cl_ix]["type"] = new_label + cluster_target_cells = find_cell_indices(cluster, target_pred["cells"]) + for ix in cluster_target_cells: + target_pred["cells"][ix]["label"] = new_label + return new_label + + +def find_cluster_text(cluster, raw_cells): + ## "cluster" must be a dict with "cell_ids"; it can also be a line. + ## "raw_cells" must have the format of item["raw"]["cells"] + ## Returns the text of the cluster, with blanks between the cell contents + ## (which seem to be words or phrases without starting or trailing blanks). + ## Note that in formulas, this may give a lot more blanks than originally + cluster_text = "" + for raw_ix in sorted(cluster["cell_ids"]): + cluster_text = cluster_text + raw_cells[raw_ix]["text"] + " " + return cluster_text.rstrip() + + +def find_cluster_text_without_blanks(cluster, raw_cells): + ## "cluster" must be a dict with "cell_ids"; it can also be a line. + ## "raw_cells" must have the format of item["raw"]["cells"] + ## Returns the text of the cluster, without blanks between the cell contents + ## Interesting in formula analysis. + cluster_text = "" + for raw_ix in sorted(cluster["cell_ids"]): + cluster_text = cluster_text + raw_cells[raw_ix]["text"] + return cluster_text.rstrip() + + +## ------------------------------- +## Clusters and lines +## (Most line-oriented functions are only needed in TextAnalysisGivenClusters, +## but this one also in FormulaAnalysis) + + +def build_cluster_from_lines(lines, label, id): + ## Lines must be a non-empty list of dicts (lines) with elements "cell_ids" and "bbox" + ## (There is no condition that they are really geometrically lines) + ## A cluster in standard format is returned with given label and id + local_lines = copy.deepcopy( + lines + ) ## without this, it changes "lines" also outside this function + first_line = local_lines.pop(0) + cluster = { + "id": id, + "type": label, + "cell_ids": first_line["cell_ids"], + "bbox": first_line["bbox"], + "confidence": 0, + "created_by": "merged_cells", + } + confidence = 0 + counter = 0 + for line in local_lines: + new_cell_ids = cluster["cell_ids"] + line["cell_ids"] + cluster["cell_ids"] = new_cell_ids + cluster["bbox"] = surrounding(cluster["bbox"], line["bbox"]) + counter += 1 + confidence += line["confidence"] + confidence = confidence / counter + cluster["confidence"] = confidence + return cluster + + +## ------------------------------- +## Reading order + + +def produce_reading_order(clusters, cluster_sort_type, cell_sort_type, sort_ids): + ## In: + ## Clusters: list as in predictions. + ## cluster_sort_type: string, currently only "raw_cells". + ## cell_sort_type: string, currently only "raw_cells". + ## sort_ids: Boolean, whether the cluster ids should be adapted to their new position + ## Out: Another clusters list, sorted according to the type. + + logger.debug("---- Start cluster sorting ------") + + if cell_sort_type == "raw_cell_ids": + for cl in clusters: + sorted_cell_ids = sorted(cl["cell_ids"]) + cl["cell_ids"] = sorted_cell_ids + else: + logger.debug( + "Unknown cell_sort_type `" + + cell_sort_type + + "`, no cell sorting will happen." + ) + + if cluster_sort_type == "raw_cell_ids": + clusters_with_cells = [cl for cl in clusters if cl["cell_ids"] != []] + clusters_without_cells = [cl for cl in clusters if cl["cell_ids"] == []] + logger.debug( + "Clusters with cells: " + str([cl["id"] for cl in clusters_with_cells]) + ) + logger.debug( + " Their first cell ids: " + + str([cl["cell_ids"][0] for cl in clusters_with_cells]) + ) + logger.debug( + "Clusters without cells: " + + str([cl["id"] for cl in clusters_without_cells]) + ) + clusters_with_cells_sorted = sorted( + clusters_with_cells, key=lambda cluster: cluster["cell_ids"][0] + ) + logger.debug( + " First cell ids after sorting: " + + str([cl["cell_ids"][0] for cl in clusters_with_cells_sorted]) + ) + sorted_clusters = clusters_with_cells_sorted + clusters_without_cells + else: + logger.debug( + "Unknown cluster_sort_type: `" + + cluster_sort_type + + "`, no cluster sorting will happen." + ) + + if sort_ids: + for i, cl in enumerate(sorted_clusters): + cl["id"] = i + return sorted_clusters + + +## ------------------------------- +## Line Splitting + + +def sort_cells_horizontal(line_cell_ids, raw_cells): + ## "line_cells" should be a non-empty list of (raw) cell_ids + ## "raw_cells" has the structure of item["raw"]["cells"]. + ## Sorts the cells in the line by x0 (left start). + new_line_cell_ids = sorted( + line_cell_ids, key=lambda cell_id: raw_cells[cell_id]["bbox"][0] + ) + return new_line_cell_ids + + +def adapt_bboxes(raw_cells, clusters, orphan_cell_indices): + new_clusters = [] + for ix, cluster in enumerate(clusters): + new_cluster = copy.deepcopy(cluster) + logger.debug( + "Treating cluster " + str(ix) + ", type " + str(new_cluster["type"]) + ) + logger.debug(" with cells: " + str(new_cluster["cell_ids"])) + if len(cluster["cell_ids"]) == 0 and cluster["type"] != DocItemLabel.PICTURE: + logger.debug(" Empty non-picture, removed") + continue ## Skip this former cluster, now without cells. + new_bbox = adapt_bbox(raw_cells, new_cluster, orphan_cell_indices) + new_cluster["bbox"] = new_bbox + new_clusters.append(new_cluster) + return new_clusters + + +def adapt_bbox(raw_cells, cluster, orphan_cell_indices): + if not (cluster["type"] in [DocItemLabel.TABLE, DocItemLabel.PICTURE]): + ## A text-like cluster. The bbox only needs to be around the text cells: + logger.debug(" Initial bbox: " + str(cluster["bbox"])) + new_bbox = surrounding_list( + [raw_cells[cid]["bbox"] for cid in cluster["cell_ids"]] + ) + logger.debug(" New bounding box:" + str(new_bbox)) + if cluster["type"] == DocItemLabel.PICTURE: + ## We only make the bbox completely comprise included text cells: + logger.debug(" Picture") + if len(cluster["cell_ids"]) != 0: + min_bbox = surrounding_list( + [raw_cells[cid]["bbox"] for cid in cluster["cell_ids"]] + ) + logger.debug(" Minimum bbox: " + str(min_bbox)) + logger.debug(" Initial bbox: " + str(cluster["bbox"])) + new_bbox = surrounding(min_bbox, cluster["bbox"]) + logger.debug(" New bbox (initial and text cells): " + str(new_bbox)) + else: + logger.debug(" without text cells, no change.") + new_bbox = cluster["bbox"] + else: ## A table + ## At least we have to keep the included text cells, and we make the bbox completely comprise them + min_bbox = surrounding_list( + [raw_cells[cid]["bbox"] for cid in cluster["cell_ids"]] + ) + logger.debug(" Minimum bbox: " + str(min_bbox)) + logger.debug(" Initial bbox: " + str(cluster["bbox"])) + new_bbox = surrounding(min_bbox, cluster["bbox"]) + logger.debug(" Possibly increased bbox: " + str(new_bbox)) + + ## Now we look which non-belonging cells are covered. + ## (To decrease dependencies, we don't make use of which cells we actually removed.) + ## We don't worry about orphan cells, those could still be added to the table. + enclosed_cells = compute_enclosed_cells( + new_bbox, raw_cells, min_cell_intersection_with_cluster=0.3 + )[0] + additional_cells = set(enclosed_cells) - set(cluster["cell_ids"]) + logger.debug( + " Additional cells enclosed by Table bbox: " + str(additional_cells) + ) + spurious_cells = additional_cells - set(orphan_cell_indices) + logger.debug( + " Spurious cells enclosed by Table bbox (additional minus orphans): " + + str(spurious_cells) + ) + if len(spurious_cells) == 0: + return new_bbox + + ## Else we want to keep as much as possible, e.g., grid lines, but not the spurious cells if we can. + ## We initialize possible cuts with the current bbox. + left_cut = new_bbox[0] + right_cut = new_bbox[2] + upper_cut = new_bbox[3] + lower_cut = new_bbox[1] + + for cell_ix in spurious_cells: + cell = raw_cells[cell_ix] + # logger.debug(" Spurious cell bbox: " + str(cell["bbox"])) + is_left = cell["bbox"][2] < min_bbox[0] + is_right = cell["bbox"][0] > min_bbox[2] + is_above = cell["bbox"][1] > min_bbox[3] + is_below = cell["bbox"][3] < min_bbox[1] + # logger.debug(" Left, right, above, below? " + str([is_left, is_right, is_above, is_below])) + + if is_left: + if cell["bbox"][2] > left_cut: + ## We move the left cut to exclude this cell: + left_cut = cell["bbox"][2] + if is_right: + if cell["bbox"][0] < right_cut: + ## We move the right cut to exclude this cell: + right_cut = cell["bbox"][0] + if is_above: + if cell["bbox"][1] < upper_cut: + ## We move the upper cut to exclude this cell: + upper_cut = cell["bbox"][1] + if is_below: + if cell["bbox"][3] > lower_cut: + ## We move the left cut to exclude this cell: + lower_cut = cell["bbox"][3] + # logger.debug(" Current bbox: " + str([left_cut, lower_cut, right_cut, upper_cut])) + + new_bbox = [left_cut, lower_cut, right_cut, upper_cut] + + logger.debug(" Final bbox: " + str(new_bbox)) + return new_bbox + + +def remove_cluster_duplicates_by_conf(cluster_predictions, threshold=0.5): + DuplicateDeletedClusterIDs = [] + for cluster_1 in cluster_predictions: + for cluster_2 in cluster_predictions: + if cluster_1["id"] != cluster_2["id"]: + if_conf = False + if cluster_1["confidence"] > cluster_2["confidence"]: + if_conf = True + if if_conf == True: + if bb_iou(cluster_1["bbox"], cluster_2["bbox"]) > threshold: + DuplicateDeletedClusterIDs.append(cluster_2["id"]) + elif contains( + cluster_1["bbox"], + [ + cluster_2["bbox"][0] + 3, + cluster_2["bbox"][1] + 3, + cluster_2["bbox"][2] - 3, + cluster_2["bbox"][3] - 3, + ], + ): + DuplicateDeletedClusterIDs.append(cluster_2["id"]) + + DuplicateDeletedClusterIDs = list(set(DuplicateDeletedClusterIDs)) + + for cl_id in DuplicateDeletedClusterIDs: + for cluster in cluster_predictions: + if cl_id == cluster["id"]: + cluster_predictions.remove(cluster) + return cluster_predictions + + +# Assign orphan cells by a low confidence prediction that is below the assigned confidence +def assign_orphans_with_low_conf_pred( + cluster_predictions, cluster_predictions_low, raw_cells, orphan_cell_indices +): + for orph_id in orphan_cell_indices: + cluster_chosen = {} + iou_thresh = 0.05 + confidence = 0.05 + + # Loop over all predictions, and find the one with the highest IOU, and confidence + for cluster in cluster_predictions_low: + calc_iou = bb_iou(cluster["bbox"], raw_cells[orph_id]["bbox"]) + cluster_area = (cluster["bbox"][3] - cluster["bbox"][1]) * ( + cluster["bbox"][2] - cluster["bbox"][0] + ) + cell_area = ( + raw_cells[orph_id]["bbox"][3] - raw_cells[orph_id]["bbox"][1] + ) * (raw_cells[orph_id]["bbox"][2] - raw_cells[orph_id]["bbox"][0]) + + if ( + (iou_thresh < calc_iou) + and (cluster["confidence"] > confidence) + and (cell_area * 3 > cluster_area) + ): + cluster_chosen = cluster + iou_thresh = calc_iou + confidence = cluster["confidence"] + # If a candidate is found, assign to it the PDF cell ids, and tag that it was created by this function for tracking + if iou_thresh != 0.05 and confidence != 0.05: + cluster_chosen["cell_ids"].append(orph_id) + cluster_chosen["created_by"] = "orph_low_conf" + cluster_predictions.append(cluster_chosen) + orphan_cell_indices.remove(orph_id) + return cluster_predictions, orphan_cell_indices + + +def remove_ambigous_pdf_cell_by_conf(cluster_predictions, raw_cells, amb_cell_idxs): + for amb_cell_id in amb_cell_idxs: + highest_conf = 0 + highest_bbox_iou = 0 + cluster_chosen = None + problamatic_clusters = [] + + # Find clusters in question + for cluster in cluster_predictions: + + if amb_cell_id in cluster["cell_ids"]: + problamatic_clusters.append(amb_cell_id) + + # If the cell_id is in a cluster of high conf, and highest iou score, and smaller in area + bbox_iou_val = bb_iou(cluster["bbox"], raw_cells[amb_cell_id]["bbox"]) + + if ( + cluster["confidence"] > highest_conf + and bbox_iou_val > highest_bbox_iou + ): + cluster_chosen = cluster + highest_conf = cluster["confidence"] + highest_bbox_iou = bbox_iou_val + if cluster["id"] in problamatic_clusters: + problamatic_clusters.remove(cluster["id"]) + + # now remove the assigning of cell id from lower confidence, and threshold + for cluster in cluster_predictions: + for prob_amb_id in problamatic_clusters: + if prob_amb_id in cluster["cell_ids"]: + cluster["cell_ids"].remove(prob_amb_id) + amb_cell_idxs.remove(amb_cell_id) + + return cluster_predictions, amb_cell_idxs + + +def ranges(nums): + # Find if consecutive numbers exist within pdf cells + # Used to remove line numbers for review manuscripts + nums = sorted(set(nums)) + gaps = [[s, e] for s, e in zip(nums, nums[1:]) if s + 1 < e] + edges = iter(nums[:1] + sum(gaps, []) + nums[-1:]) + return list(zip(edges, edges)) + + +def set_orphan_as_text( + cluster_predictions, cluster_predictions_low, raw_cells, orphan_cell_indices +): + max_id = -1 + figures = [] + for cluster in cluster_predictions: + if cluster["type"] == DocItemLabel.PICTURE: + figures.append(cluster) + + if cluster["id"] > max_id: + max_id = cluster["id"] + max_id += 1 + + lines_detector = False + content_of_orphans = [] + for orph_id in orphan_cell_indices: + orph_cell = raw_cells[orph_id] + content_of_orphans.append(raw_cells[orph_id]["text"]) + + fil_content_of_orphans = [] + for cell_content in content_of_orphans: + if cell_content.isnumeric(): + try: + num = int(cell_content) + fil_content_of_orphans.append(num) + except ValueError: # ignore the cell + pass + + # line_orphans = [] + # Check if there are more than 2 pdf orphan cells, if there are more than 2, + # then check between the orphan cells if they are numeric + # and if they are a consecutive series of numbers (using ranges function) to decide + + if len(fil_content_of_orphans) > 2: + out_ranges = ranges(fil_content_of_orphans) + if len(out_ranges) > 1: + cnt_range = 0 + for ranges_ in out_ranges: + if ranges_[0] != ranges_[1]: + # If there are more than 75 (half the total line number of a review manuscript page) + # decide that there are line numbers on page to be ignored. + if len(list(range(ranges_[0], ranges_[1]))) > 75: + lines_detector = True + # line_orphans = line_orphans + list(range(ranges_[0], ranges_[1])) + + for orph_id in orphan_cell_indices: + orph_cell = raw_cells[orph_id] + if bool(orph_cell["text"] and not orph_cell["text"].isspace()): + fig_flag = False + # Do not assign orphan cells if they are inside a figure + for fig in figures: + if contains(fig["bbox"], orph_cell["bbox"]): + fig_flag = True + + # if fig_flag == False and raw_cells[orph_id]["text"] not in line_orphans: + if fig_flag == False and lines_detector == False: + # get class from low confidence detections if not set as text: + class_type = DocItemLabel.TEXT + + for cluster in cluster_predictions_low: + intersection = compute_intersection( + orph_cell["bbox"], cluster["bbox"] + ) + class_type = DocItemLabel.TEXT + if ( + cluster["confidence"] > 0.1 + and bb_iou(cluster["bbox"], orph_cell["bbox"]) > 0.4 + ): + class_type = cluster["type"] + elif contains( + cluster["bbox"], + [ + orph_cell["bbox"][0] + 3, + orph_cell["bbox"][1] + 3, + orph_cell["bbox"][2] - 3, + orph_cell["bbox"][3] - 3, + ], + ): + class_type = cluster["type"] + elif intersection > area(orph_cell["bbox"]) * 0.2: + class_type = cluster["type"] + + new_cluster = { + "id": max_id, + "bbox": orph_cell["bbox"], + "type": class_type, + "cell_ids": [orph_id], + "confidence": -1, + "created_by": "orphan_default", + } + max_id += 1 + cluster_predictions.append(new_cluster) + return cluster_predictions, orphan_cell_indices + + +def merge_cells(cluster_predictions): + # Using graph component creates clusters if orphan cells are touching or too close. + G = nx.Graph() + for cluster in cluster_predictions: + if cluster["created_by"] == "orphan_default": + G.add_node(cluster["id"]) + + for cluster_1 in cluster_predictions: + for cluster_2 in cluster_predictions: + if ( + cluster_1["id"] != cluster_2["id"] + and cluster_2["created_by"] == "orphan_default" + and cluster_1["created_by"] == "orphan_default" + ): + cl1 = copy.deepcopy(cluster_1["bbox"]) + cl2 = copy.deepcopy(cluster_2["bbox"]) + cl1[0] = cl1[0] - 2 + cl1[1] = cl1[1] - 2 + cl1[2] = cl1[2] + 2 + cl1[3] = cl1[3] + 2 + cl2[0] = cl2[0] - 2 + cl2[1] = cl2[1] - 2 + cl2[2] = cl2[2] + 2 + cl2[3] = cl2[3] + 2 + if is_intersecting(cl1, cl2): + G.add_edge(cluster_1["id"], cluster_2["id"]) + + component = sorted(map(sorted, nx.k_edge_components(G, k=1))) + max_id = -1 + for cluster_1 in cluster_predictions: + if cluster_1["id"] > max_id: + max_id = cluster_1["id"] + + for nodes in component: + if len(nodes) > 1: + max_id += 1 + lines = [] + for node in nodes: + for cluster in cluster_predictions: + if cluster["id"] == node: + lines.append(cluster) + cluster_predictions.remove(cluster) + new_merged_cluster = build_cluster_from_lines( + lines, DocItemLabel.TEXT, max_id + ) + cluster_predictions.append(new_merged_cluster) + return cluster_predictions + + +def clean_up_clusters( + cluster_predictions, + raw_cells, + merge_cells=False, + img_table=False, + one_cell_table=False, +): + DuplicateDeletedClusterIDs = [] + + for cluster_1 in cluster_predictions: + for cluster_2 in cluster_predictions: + if cluster_1["id"] != cluster_2["id"]: + # remove any artifcats created by merging clusters + if merge_cells == True: + if contains( + cluster_1["bbox"], + [ + cluster_2["bbox"][0] + 3, + cluster_2["bbox"][1] + 3, + cluster_2["bbox"][2] - 3, + cluster_2["bbox"][3] - 3, + ], + ): + cluster_1["cell_ids"] = ( + cluster_1["cell_ids"] + cluster_2["cell_ids"] + ) + DuplicateDeletedClusterIDs.append(cluster_2["id"]) + # remove clusters that might appear inside tables, or images (such as pdf cells in graphs) + elif img_table == True: + if ( + cluster_1["type"] == DocItemLabel.TEXT + and cluster_2["type"] == DocItemLabel.PICTURE + or cluster_2["type"] == DocItemLabel.TABLE + ): + if bb_iou(cluster_1["bbox"], cluster_2["bbox"]) > 0.5: + DuplicateDeletedClusterIDs.append(cluster_1["id"]) + elif contains( + [ + cluster_2["bbox"][0] - 3, + cluster_2["bbox"][1] - 3, + cluster_2["bbox"][2] + 3, + cluster_2["bbox"][3] + 3, + ], + cluster_1["bbox"], + ): + DuplicateDeletedClusterIDs.append(cluster_1["id"]) + # remove tables that have one pdf cell + if one_cell_table == True: + if ( + cluster_1["type"] == DocItemLabel.TABLE + and len(cluster_1["cell_ids"]) < 2 + ): + DuplicateDeletedClusterIDs.append(cluster_1["id"]) + + DuplicateDeletedClusterIDs = list(set(DuplicateDeletedClusterIDs)) + + for cl_id in DuplicateDeletedClusterIDs: + for cluster in cluster_predictions: + if cl_id == cluster["id"]: + cluster_predictions.remove(cluster) + return cluster_predictions + + +def assigning_cell_ids_to_clusters(clusters, raw_cells, threshold): + for cluster in clusters: + cells_in_cluster, _ = compute_enclosed_cells( + cluster["bbox"], raw_cells, min_cell_intersection_with_cluster=threshold + ) + cluster["cell_ids"] = cells_in_cluster + ## These cell_ids are ids of the raw cells. + ## They are often, but not always, the same as the "id" or the index of the "cells" list in a prediction. + return clusters + + +# Creates a map of cell_id->cluster_id +def cell_id_state_map(clusters, cell_count): + clusters_around_cells = find_clusters_around_cells(cell_count, clusters) + orphan_cell_indices = [ + ix for ix in range(cell_count) if len(clusters_around_cells[ix]) == 0 + ] # which cells are assigned no cluster? + ambiguous_cell_indices = [ + ix for ix in range(cell_count) if len(clusters_around_cells[ix]) > 1 + ] # which cells are assigned > 1 clusters? + return clusters_around_cells, orphan_cell_indices, ambiguous_cell_indices diff --git a/poetry.lock b/poetry.lock index 3b2ad8e7..4d7c9a88 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.5 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand. [[package]] name = "aiohappyeyeballs" @@ -914,13 +914,13 @@ chunking = ["semchunk (>=2.2.0,<3.0.0)", "transformers (>=4.34.0,<5.0.0)"] [[package]] name = "docling-ibm-models" -version = "3.0.0" +version = "3.1.0" description = "This package contains the AI models used by the Docling PDF conversion package" optional = false python-versions = "<4.0,>=3.9" files = [ - {file = "docling_ibm_models-3.0.0-py3-none-any.whl", hash = "sha256:61d1bc3fc36fbec687533f543e2f899117bc19e5b31ab03520af4b84e1f7327c"}, - {file = "docling_ibm_models-3.0.0.tar.gz", hash = "sha256:2a4c064c6a58cfce039e9574c52cb3cab7decd103e20e9c5ccb7834e7fa04d4f"}, + {file = "docling_ibm_models-3.1.0-py3-none-any.whl", hash = "sha256:a381a45dff16fdb2246b99c15a2e3d6ba880c573d48a1d6477d3ffb36bab807f"}, + {file = "docling_ibm_models-3.1.0.tar.gz", hash = "sha256:65d734ffa490edc4e2301d296b6e893afa536c63b7daae7bbda781bd15b3431e"}, ] [package.dependencies] @@ -7608,4 +7608,4 @@ tesserocr = ["tesserocr"] [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "4b3ccc0f6fa8a57da342674fa938be59e453a6289d230791c1a5d970ea4441de" +content-hash = "5271637a86ae221be362a288546c9fee3e3e25e5b323c997464c032c284716bd" diff --git a/pyproject.toml b/pyproject.toml index 6c087ce3..aa3fbaf6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,7 @@ packages = [{include = "docling"}] python = "^3.9" docling-core = { version = "^2.9.0", extras = ["chunking"] } pydantic = "^2.0.0" -docling-ibm-models = "^3.0.0" +docling-ibm-models = "^3.1.0" deepsearch-glm = "^1.0.0" docling-parse = "^3.0.0" filetype = "^1.2.0"