From 57d51ede04f99c467d303e04537105b47bca02ec Mon Sep 17 00:00:00 2001 From: Christoph Auer Date: Wed, 11 Dec 2024 17:08:35 +0100 Subject: [PATCH] Many layout processing improvements, add document index type Signed-off-by: Christoph Auer Signed-off-by: Christoph Auer --- docling/datamodel/document.py | 2 +- docling/models/layout_model.py | 56 ++++++++++++++++++++----- docling/models/page_assemble_model.py | 2 +- docling/models/table_structure_model.py | 5 ++- docling/utils/glm_utils.py | 4 +- docling/utils/layout_postprocessor.py | 44 +++++++++++++++---- 6 files changed, 90 insertions(+), 23 deletions(-) diff --git a/docling/datamodel/document.py b/docling/datamodel/document.py index bae4ab2a..140eade9 100644 --- a/docling/datamodel/document.py +++ b/docling/datamodel/document.py @@ -63,7 +63,7 @@ _log = logging.getLogger(__name__) layout_label_to_ds_type = { DocItemLabel.TITLE: "title", - DocItemLabel.DOCUMENT_INDEX: "table-of-contents", + DocItemLabel.DOCUMENT_INDEX: "table", DocItemLabel.SECTION_HEADER: "subtitle-level-1", DocItemLabel.CHECKBOX_SELECTED: "checkbox-selected", DocItemLabel.CHECKBOX_UNSELECTED: "checkbox-unselected", diff --git a/docling/models/layout_model.py b/docling/models/layout_model.py index 3d7f2269..2caa3866 100644 --- a/docling/models/layout_model.py +++ b/docling/models/layout_model.py @@ -7,7 +7,7 @@ from typing import Iterable, List from docling_core.types.doc import CoordOrigin, DocItemLabel from docling_ibm_models.layoutmodel.layout_predictor import LayoutPredictor -from PIL import Image, ImageDraw +from PIL import Image, ImageDraw, ImageFont from docling.datamodel.base_models import ( BoundingBox, @@ -44,7 +44,7 @@ class LayoutModel(BasePageModel): ] PAGE_HEADER_LABELS = [DocItemLabel.PAGE_HEADER, DocItemLabel.PAGE_FOOTER] - TABLE_LABEL = DocItemLabel.TABLE + TABLE_LABELS = [DocItemLabel.TABLE, DocItemLabel.DOCUMENT_INDEX] FIGURE_LABEL = DocItemLabel.PICTURE FORMULA_LABEL = DocItemLabel.FORMULA CONTAINER_LABELS = [DocItemLabel.FORM, DocItemLabel.KEY_VALUE_REGION] @@ -62,6 +62,7 @@ class LayoutModel(BasePageModel): Draws a page image side by side with clusters filtered into two categories: - Left: Clusters excluding FORM, KEY_VALUE_REGION, and PICTURE. - Right: Clusters including FORM, KEY_VALUE_REGION, and PICTURE. + Includes label names and confidence scores for each cluster. """ label_to_color = { DocItemLabel.TEXT: (255, 255, 153), # Light Yellow @@ -103,9 +104,18 @@ class LayoutModel(BasePageModel): # Function to draw clusters on an image def draw_clusters(image, clusters): draw = ImageDraw.Draw(image, "RGBA") + + # Create a smaller font for the labels + try: + font = ImageFont.truetype("arial.ttf", 12) + except OSError: + # Fallback to default font if arial is not available + font = ImageFont.load_default() + for c_tl in clusters: all_clusters = [c_tl, *c_tl.children] for c in all_clusters: + # Draw cells first (underneath) cell_color = (0, 0, 0, 40) # Transparent black for cells for tc in c.cells: cx0, cy0, cx1, cy1 = tc.bbox.as_tuple() @@ -115,21 +125,44 @@ class LayoutModel(BasePageModel): fill=cell_color, ) + # Draw cluster rectangle x0, y0, x1, y1 = c.bbox.as_tuple() - cluster_fill_color = ( - *list(label_to_color.get(c.label)), # type: ignore - 70, - ) - cluster_outline_color = ( - *list(label_to_color.get(c.label)), # type: ignore - 255, - ) + cluster_fill_color = (*list(label_to_color.get(c.label)), 70) + cluster_outline_color = (*list(label_to_color.get(c.label)), 255) draw.rectangle( [(x0, y0), (x1, y1)], outline=cluster_outline_color, fill=cluster_fill_color, ) + # Add label name and confidence + label_text = f"{c.label.name} ({c.confidence:.2f})" + + # Create semi-transparent background for text + text_bbox = draw.textbbox((x0, y0), label_text, font=font) + text_bg_padding = 2 + draw.rectangle( + [ + ( + text_bbox[0] - text_bg_padding, + text_bbox[1] - text_bg_padding, + ), + ( + text_bbox[2] + text_bg_padding, + text_bbox[3] + text_bg_padding, + ), + ], + fill=(255, 255, 255, 180), # Semi-transparent white + ) + + # Draw text + draw.text( + (x0, y0), + label_text, + fill=(0, 0, 0, 255), # Solid black + font=font, + ) + # Draw clusters on both images draw_clusters(left_image, left_clusters) draw_clusters(right_image, right_clusters) @@ -277,8 +310,9 @@ class LayoutModel(BasePageModel): ) # Apply postprocessing + processed_clusters, processed_cells = LayoutPostprocessor( - page.cells, clusters + page.cells, clusters, page.size ).postprocess() # processed_clusters, processed_cells = clusters, page.cells diff --git a/docling/models/page_assemble_model.py b/docling/models/page_assemble_model.py index 4c27400f..3e202e20 100644 --- a/docling/models/page_assemble_model.py +++ b/docling/models/page_assemble_model.py @@ -95,7 +95,7 @@ class PageAssembleModel(BasePageModel): headers.append(text_el) else: body.append(text_el) - elif cluster.label == LayoutModel.TABLE_LABEL: + elif cluster.label in LayoutModel.TABLE_LABELS: tbl = None if page.predictions.tablestructure: tbl = page.predictions.tablestructure.table_map.get( diff --git a/docling/models/table_structure_model.py b/docling/models/table_structure_model.py index 851fa039..94f347c1 100644 --- a/docling/models/table_structure_model.py +++ b/docling/models/table_structure_model.py @@ -133,7 +133,8 @@ class TableStructureModel(BasePageModel): ], ) for cluster in page.predictions.layout.clusters - if cluster.label == DocItemLabel.TABLE + if cluster.label + in [DocItemLabel.TABLE, DocItemLabel.DOCUMENT_INDEX] ] if not len(in_tables): yield page @@ -198,7 +199,7 @@ class TableStructureModel(BasePageModel): id=table_cluster.id, page_no=page.page_no, cluster=table_cluster, - label=DocItemLabel.TABLE, + label=table_cluster.label, ) page.predictions.tablestructure.table_map[ diff --git a/docling/utils/glm_utils.py b/docling/utils/glm_utils.py index 13681017..3289017b 100644 --- a/docling/utils/glm_utils.py +++ b/docling/utils/glm_utils.py @@ -169,6 +169,8 @@ def to_docling_document(doc_glm, update_name_label=False) -> DoclingDocument: current_list = None text = "" caption_refs = [] + item_label = DocItemLabel(pelem["name"]) + for caption in obj["captions"]: text += caption["text"] @@ -254,7 +256,7 @@ def to_docling_document(doc_glm, update_name_label=False) -> DoclingDocument: ), ) - tbl = doc.add_table(data=tbl_data, prov=prov) + tbl = doc.add_table(data=tbl_data, prov=prov, label=item_label) tbl.captions.extend(caption_refs) elif ptype in ["form", "key_value_region"]: diff --git a/docling/utils/layout_postprocessor.py b/docling/utils/layout_postprocessor.py index 9adc371b..298c838c 100644 --- a/docling/utils/layout_postprocessor.py +++ b/docling/utils/layout_postprocessor.py @@ -4,7 +4,7 @@ import sys from collections import defaultdict from typing import Dict, List, Set, Tuple -from docling_core.types.doc import DocItemLabel +from docling_core.types.doc import DocItemLabel, Size from rtree import index from docling.datamodel.base_models import BoundingBox, Cell, Cluster @@ -152,7 +152,12 @@ class LayoutPostprocessor: "wrapper": {"area_threshold": 2.0, "conf_threshold": 0.2}, } - WRAPPER_TYPES = {DocItemLabel.FORM, DocItemLabel.KEY_VALUE_REGION} + WRAPPER_TYPES = { + DocItemLabel.FORM, + DocItemLabel.KEY_VALUE_REGION, + DocItemLabel.TABLE, + DocItemLabel.DOCUMENT_INDEX, + } SPECIAL_TYPES = WRAPPER_TYPES | {DocItemLabel.PICTURE} CONFIDENCE_THRESHOLDS = { @@ -164,7 +169,7 @@ class LayoutPostprocessor: DocItemLabel.PAGE_HEADER: 0.5, DocItemLabel.PICTURE: 0.5, DocItemLabel.SECTION_HEADER: 0.45, - DocItemLabel.TABLE: 0.35, + DocItemLabel.TABLE: 0.5, DocItemLabel.TEXT: 0.55, # 0.45, DocItemLabel.TITLE: 0.45, DocItemLabel.CODE: 0.45, @@ -176,14 +181,15 @@ class LayoutPostprocessor: } LABEL_REMAPPING = { - DocItemLabel.DOCUMENT_INDEX: DocItemLabel.TABLE, + # DocItemLabel.DOCUMENT_INDEX: DocItemLabel.TABLE, DocItemLabel.TITLE: DocItemLabel.SECTION_HEADER, } - def __init__(self, cells: List[Cell], clusters: List[Cluster]): + def __init__(self, cells: List[Cell], clusters: List[Cluster], page_size: Size): """Initialize processor with cells and clusters.""" """Initialize processor with cells and spatial indices.""" self.cells = cells + self.page_size = page_size self.regular_clusters = [ c for c in clusters if c.label not in self.SPECIAL_TYPES ] @@ -281,6 +287,19 @@ class LayoutPostprocessor: special_clusters = self._handle_cross_type_overlaps(special_clusters) + # Calculate page area from known page size + page_area = self.page_size.width * self.page_size.height + if page_area > 0: + # Filter out full-page pictures + special_clusters = [ + cluster + for cluster in special_clusters + if not ( + cluster.label == DocItemLabel.PICTURE + and cluster.bbox.area() / page_area > 0.90 + ) + ] + for special in special_clusters: contained = [] for cluster in self.regular_clusters: @@ -313,6 +332,13 @@ class LayoutPostprocessor: b=max(c.bbox.b for c in contained), ) + # Collect all cells from children + all_cells = [] + for child in contained: + all_cells.extend(child.cells) + special.cells = self._deduplicate_cells(all_cells) + special.cells = self._sort_cells(special.cells) + picture_clusters = [ c for c in special_clusters if c.label == DocItemLabel.PICTURE ] @@ -338,7 +364,7 @@ class LayoutPostprocessor: wrappers_to_remove = set() for wrapper in special_clusters: - if wrapper.label != DocItemLabel.KEY_VALUE_REGION: + if wrapper.label not in self.WRAPPER_TYPES: continue # only treat KEY_VALUE_REGION for now. for regular in self.regular_clusters: @@ -348,8 +374,12 @@ class LayoutPostprocessor: wrapper_area = wrapper.bbox.area() overlap_ratio = overlap / wrapper_area + conf_diff = wrapper.confidence - regular.confidence + # If wrapper is mostly overlapping with a TABLE, remove the wrapper - if overlap_ratio > 0.8: # 80% overlap threshold + if ( + overlap_ratio > 0.9 and conf_diff < 0.1 + ): # self.OVERLAP_PARAMS["wrapper"]["conf_threshold"]): # 80% overlap threshold wrappers_to_remove.add(wrapper.id) break