From 4447d22c2f6657d42f1e3775cbc0146c3df8f392 Mon Sep 17 00:00:00 2001 From: Christoph Auer Date: Tue, 10 Dec 2024 15:50:18 +0100 Subject: [PATCH] Fixes for layout processing and tableformer workaround Signed-off-by: Christoph Auer --- docling/models/table_structure_model.py | 49 +++++++-------- docling/utils/layout_postprocessor.py | 80 +++++++++++++++++++------ poetry.lock | 73 ++++++---------------- pyproject.toml | 3 +- 4 files changed, 104 insertions(+), 101 deletions(-) diff --git a/docling/models/table_structure_model.py b/docling/models/table_structure_model.py index e5ab1fe2..851fa039 100644 --- a/docling/models/table_structure_model.py +++ b/docling/models/table_structure_model.py @@ -71,6 +71,10 @@ class TableStructureModel(BasePageModel): x0, y0, x1, y1 = table_element.cluster.bbox.as_tuple() draw.rectangle([(x0, y0), (x1, y1)], outline="red") + for cell in table_element.cluster.cells: + x0, y0, x1, y1 = cell.bbox.as_tuple() + draw.rectangle([(x0, y0), (x1, y1)], outline="green") + for tc in table_element.table_cells: if tc.bbox is not None: x0, y0, x1, y1 = tc.bbox.as_tuple() @@ -84,7 +88,6 @@ class TableStructureModel(BasePageModel): text=f"{tc.start_row_offset_idx}, {tc.start_col_offset_idx}", fill="black", ) - if show: image.show() else: @@ -136,41 +139,33 @@ class TableStructureModel(BasePageModel): yield page continue - tokens = [] - for c in page.cells: - for cluster, _ in in_tables: - if c.bbox.area() > 0: - if ( - c.bbox.intersection_area_with(cluster.bbox) - / c.bbox.area() - > 0.2 - ): - # Only allow non empty stings (spaces) into the cells of a table - if len(c.text.strip()) > 0: - new_cell = copy.deepcopy(c) - new_cell.bbox = new_cell.bbox.scaled( - scale=self.scale - ) - - tokens.append(new_cell.model_dump()) - page_input = { - "tokens": tokens, "width": page.size.width * self.scale, "height": page.size.height * self.scale, + "image": numpy.asarray(page.get_image(scale=self.scale)), } - page_input["image"] = numpy.asarray( - 
page.get_image(scale=self.scale) - ) table_clusters, table_bboxes = zip(*in_tables) if len(table_bboxes): - tf_output = self.tf_predictor.multi_table_predict( - page_input, table_bboxes, do_matching=self.do_cell_matching - ) + for table_cluster, tbl_box in in_tables: - for table_cluster, table_out in zip(table_clusters, tf_output): + tokens = [] + for c in table_cluster.cells: + # Only allow non-empty strings (not just whitespace) into the cells of a table + if len(c.text.strip()) > 0: + new_cell = copy.deepcopy(c) + new_cell.bbox = new_cell.bbox.scaled( + scale=self.scale + ) + + tokens.append(new_cell.model_dump()) + page_input["tokens"] = tokens + + tf_output = self.tf_predictor.multi_table_predict( + page_input, [tbl_box], do_matching=self.do_cell_matching + ) + table_out = tf_output[0] table_cells = [] for element in table_out["tf_responses"]: diff --git a/docling/utils/layout_postprocessor.py b/docling/utils/layout_postprocessor.py index 652a8911..9adc371b 100644 --- a/docling/utils/layout_postprocessor.py +++ b/docling/utils/layout_postprocessor.py @@ -162,7 +162,7 @@ class LayoutPostprocessor: DocItemLabel.LIST_ITEM: 0.5, DocItemLabel.PAGE_FOOTER: 0.5, DocItemLabel.PAGE_HEADER: 0.5, - DocItemLabel.PICTURE: 0.1, + DocItemLabel.PICTURE: 0.5, DocItemLabel.SECTION_HEADER: 0.45, DocItemLabel.TABLE: 0.35, DocItemLabel.TEXT: 0.55, # 0.45, @@ -279,6 +279,8 @@ class LayoutPostprocessor: if c.confidence >= self.CONFIDENCE_THRESHOLDS[c.label] ] + special_clusters = self._handle_cross_type_overlaps(special_clusters) + for special in special_clusters: contained = [] for cluster in self.regular_clusters: @@ -289,14 +291,17 @@ class LayoutPostprocessor: contained.append(cluster) if contained: - # Sort contained clusters by minimum cell ID - contained.sort( - key=lambda cluster: ( - min(cell.id for cell in cluster.cells) - if cluster.cells - else sys.maxsize - ) - ) + # # Sort contained clusters by minimum cell ID: + # contained.sort( + # key=lambda cluster: ( + # min(cell.id for cell
in cluster.cells) + # if cluster.cells + # else sys.maxsize + # ) + # ) + + # Sort contained clusters top-to-bottom, left-to-right + contained = self._sort_clusters(contained) special.children = contained # Adjust bbox only for wrapper types @@ -324,6 +329,39 @@ class LayoutPostprocessor: return picture_clusters + wrapper_clusters + def _handle_cross_type_overlaps(self, special_clusters) -> List[Cluster]: + """Handle overlaps between regular and wrapper clusters before child assignment. + + In particular, KEY_VALUE_REGION proposals that are almost identical to a TABLE + should be removed. + """ + wrappers_to_remove = set() + + for wrapper in special_clusters: + if wrapper.label != DocItemLabel.KEY_VALUE_REGION: + continue # only treat KEY_VALUE_REGION for now. + + for regular in self.regular_clusters: + if regular.label == DocItemLabel.TABLE: + # Calculate overlap + overlap = regular.bbox.intersection_area_with(wrapper.bbox) + wrapper_area = wrapper.bbox.area() + overlap_ratio = overlap / wrapper_area + + # If wrapper is mostly overlapping with a TABLE, remove the wrapper + if overlap_ratio > 0.8: # 80% overlap threshold + wrappers_to_remove.add(wrapper.id) + break + + # Filter out the identified wrappers + special_clusters = [ + cluster + for cluster in special_clusters + if cluster.id not in wrappers_to_remove + ] + + return special_clusters + def _should_prefer_cluster( self, candidate: Cluster, other: Cluster, params: dict ) -> bool: @@ -443,6 +481,7 @@ class LayoutPostprocessor: if cluster != best: best.cells.extend(cluster.cells) + best.cells = self._deduplicate_cells(best.cells) best.cells = self._sort_cells(best.cells) result.append(best) @@ -478,6 +517,16 @@ class LayoutPostprocessor: return current_best if current_best else clusters[0] + def _deduplicate_cells(self, cells: List[Cell]) -> List[Cell]: + """Ensure each cell appears only once, maintaining order of first appearance.""" + seen_ids = set() + unique_cells = [] + for cell in cells: + if cell.id
not in seen_ids: + seen_ids.add(cell.id) + unique_cells.append(cell) + return unique_cells + def _assign_cells_to_clusters( self, clusters: List[Cluster], min_overlap: float = 0.2 ) -> List[Cluster]: @@ -506,6 +555,10 @@ class LayoutPostprocessor: if best_cluster is not None: best_cluster.cells.append(cell) + # Deduplicate cells in each cluster after assignment + for cluster in clusters: + cluster.cells = self._deduplicate_cells(cluster.cells) + return clusters def _find_unassigned_cells(self, clusters: List[Cluster]) -> List[Cell]: @@ -547,11 +600,4 @@ class LayoutPostprocessor: def _sort_clusters(self, clusters: List[Cluster]) -> List[Cluster]: """Sort clusters in reading order (top-to-bottom, left-to-right).""" - - def reading_order_key(cluster: Cluster) -> Tuple[float, float]: - if cluster.cells and cluster.label != DocItemLabel.PICTURE: - first_cell = min(cluster.cells, key=lambda c: (c.bbox.t, c.bbox.l)) - return (first_cell.bbox.t, first_cell.bbox.l) - return (cluster.bbox.t, cluster.bbox.l) - - return sorted(clusters, key=reading_order_key) + return sorted(clusters, key=lambda cluster: (cluster.bbox.t, cluster.bbox.l)) diff --git a/poetry.lock b/poetry.lock index 5a84d3b4..ec79701b 100644 --- a/poetry.lock +++ b/poetry.lock @@ -922,27 +922,29 @@ name = "docling-core" version = "2.9.0" description = "A python library to define and validate data types in Docling." 
optional = false -python-versions = "<4.0,>=3.9" -files = [ - {file = "docling_core-2.9.0-py3-none-any.whl", hash = "sha256:b44b077db5d2ac8a900f30a15abe329c165b1f2eb7f1c90d1275c423c1c3d668"}, - {file = "docling_core-2.9.0.tar.gz", hash = "sha256:1bf12fe67ee4852330e9bac33fe62b45598ff885481e03a88fa8e1bf48252424"}, -] +python-versions = "^3.9" +files = [] +develop = false [package.dependencies] -jsonref = ">=1.1.0,<2.0.0" -jsonschema = ">=4.16.0,<5.0.0" -pandas = ">=2.1.4,<3.0.0" -pillow = ">=10.3.0,<11.0.0" -pydantic = ">=2.6.0,<2.10.0 || >2.10.0,<2.10.1 || >2.10.1,<2.10.2 || >2.10.2,<3.0.0" +jsonref = "^1.1.0" +jsonschema = "^4.16.0" +pandas = "^2.1.4" +pillow = "^10.3.0" +pydantic = ">=2.6.0,<3.0.0,!=2.10.0,!=2.10.1,!=2.10.2" pyyaml = ">=5.1,<7.0.0" -semchunk = {version = ">=2.2.0,<3.0.0", optional = true, markers = "extra == \"chunking\""} -tabulate = ">=0.9.0,<0.10.0" -transformers = {version = ">=4.34.0,<5.0.0", optional = true, markers = "extra == \"chunking\""} -typing-extensions = ">=4.12.2,<5.0.0" +tabulate = "^0.9.0" +typing-extensions = "^4.12.2" [package.extras] chunking = ["semchunk (>=2.2.0,<3.0.0)", "transformers (>=4.34.0,<5.0.0)"] +[package.source] +type = "git" +url = "ssh://git@github.com/DS4SD/docling-core.git" +reference = "cau/include-picture-contents" +resolved_reference = "012f8ac38a2ba7e77110b3f7ad57af2a984232e5" + [[package]] name = "docling-ibm-models" version = "2.0.7" @@ -2855,32 +2857,6 @@ files = [ {file = "more_itertools-10.5.0-py3-none-any.whl", hash = "sha256:037b0d3203ce90cca8ab1defbbdac29d5f993fc20131f3664dc8d6acfa872aef"}, ] -[[package]] -name = "mpire" -version = "2.10.2" -description = "A Python package for easy multiprocessing, but faster than multiprocessing" -optional = false -python-versions = "*" -files = [ - {file = "mpire-2.10.2-py3-none-any.whl", hash = "sha256:d627707f7a8d02aa4c7f7d59de399dec5290945ddf7fbd36cbb1d6ebb37a51fb"}, - {file = "mpire-2.10.2.tar.gz", hash = 
"sha256:f66a321e93fadff34585a4bfa05e95bd946cf714b442f51c529038eb45773d97"}, -] - -[package.dependencies] -multiprocess = [ - {version = "*", optional = true, markers = "python_version < \"3.11\" and extra == \"dill\""}, - {version = ">=0.70.15", optional = true, markers = "python_version >= \"3.11\" and extra == \"dill\""}, -] -pygments = ">=2.0" -pywin32 = {version = ">=301", markers = "platform_system == \"Windows\""} -tqdm = ">=4.27" - -[package.extras] -dashboard = ["flask"] -dill = ["multiprocess", "multiprocess (>=0.70.15)"] -docs = ["docutils (==0.17.1)", "sphinx (==3.2.1)", "sphinx-autodoc-typehints (==1.11.0)", "sphinx-rtd-theme (==0.5.0)", "sphinx-versions (==1.0.1)", "sphinxcontrib-images (==0.9.2)"] -testing = ["ipywidgets", "multiprocess", "multiprocess (>=0.70.15)", "numpy", "pywin32 (>=301)", "rich"] - [[package]] name = "mpmath" version = "1.3.0" @@ -6170,21 +6146,6 @@ files = [ cryptography = ">=2.0" jeepney = ">=0.6" -[[package]] -name = "semchunk" -version = "2.2.0" -description = "A fast and lightweight Python library for splitting text into semantically meaningful chunks." 
-optional = false -python-versions = ">=3.9" -files = [ - {file = "semchunk-2.2.0-py3-none-any.whl", hash = "sha256:7db19ca90ddb48f99265e789e07a7bb111ae25185f9cc3d44b94e1e61b9067fc"}, - {file = "semchunk-2.2.0.tar.gz", hash = "sha256:4de761ce614036fa3bea61adbe47e3ade7c96ac9b062f223b3ac353dbfd26743"}, -] - -[package.dependencies] -mpire = {version = "*", extras = ["dill"]} -tqdm = "*" - [[package]] name = "semver" version = "2.13.0" @@ -7723,4 +7684,4 @@ tesserocr = ["tesserocr"] [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "6917af8d76aa1f85a159f0ab9546478b4bef194ae726c79196bac087c7368fef" +content-hash = "c991515ef231d9eeead33cc876e8cb93fe31e949a5ab92918a4b77257d2700a3" diff --git a/pyproject.toml b/pyproject.toml index 122675ac..fb9fac32 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,8 @@ python = "^3.9" docling-ibm-models = { git = "ssh://git@github.com/DS4SD/docling-ibm-models.git", branch = "nli/performance" } deepsearch-glm = "^1.0.0" docling-parse = "^3.0.0" -docling-core = { version = "^2.9.0", extras = ["chunking"] } +#docling-core = { version = "^2.9.0", extras = ["chunking"] } +docling-core = { git = "ssh://git@github.com/DS4SD/docling-core.git", branch = "cau/include-picture-contents" } pydantic = "^2.0.0" filetype = "^1.2.0" pypdfium2 = "^4.30.0"