Fixes for layout processing and tableformer workaround

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
This commit is contained in:
Christoph Auer 2024-12-10 15:50:18 +01:00
parent 959f91180f
commit 4447d22c2f
4 changed files with 104 additions and 101 deletions

View File

@ -71,6 +71,10 @@ class TableStructureModel(BasePageModel):
x0, y0, x1, y1 = table_element.cluster.bbox.as_tuple()
draw.rectangle([(x0, y0), (x1, y1)], outline="red")
for cell in table_element.cluster.cells:
x0, y0, x1, y1 = cell.bbox.as_tuple()
draw.rectangle([(x0, y0), (x1, y1)], outline="green")
for tc in table_element.table_cells:
if tc.bbox is not None:
x0, y0, x1, y1 = tc.bbox.as_tuple()
@ -84,7 +88,6 @@ class TableStructureModel(BasePageModel):
text=f"{tc.start_row_offset_idx}, {tc.start_col_offset_idx}",
fill="black",
)
if show:
image.show()
else:
@ -136,41 +139,33 @@ class TableStructureModel(BasePageModel):
yield page
continue
tokens = []
for c in page.cells:
for cluster, _ in in_tables:
if c.bbox.area() > 0:
if (
c.bbox.intersection_area_with(cluster.bbox)
/ c.bbox.area()
> 0.2
):
# Only allow non empty stings (spaces) into the cells of a table
if len(c.text.strip()) > 0:
new_cell = copy.deepcopy(c)
new_cell.bbox = new_cell.bbox.scaled(
scale=self.scale
)
tokens.append(new_cell.model_dump())
page_input = {
"tokens": tokens,
"width": page.size.width * self.scale,
"height": page.size.height * self.scale,
"image": numpy.asarray(page.get_image(scale=self.scale)),
}
page_input["image"] = numpy.asarray(
page.get_image(scale=self.scale)
)
table_clusters, table_bboxes = zip(*in_tables)
if len(table_bboxes):
tf_output = self.tf_predictor.multi_table_predict(
page_input, table_bboxes, do_matching=self.do_cell_matching
)
for table_cluster, tbl_box in in_tables:
for table_cluster, table_out in zip(table_clusters, tf_output):
tokens = []
for c in table_cluster.cells:
# Only allow non empty stings (spaces) into the cells of a table
if len(c.text.strip()) > 0:
new_cell = copy.deepcopy(c)
new_cell.bbox = new_cell.bbox.scaled(
scale=self.scale
)
tokens.append(new_cell.model_dump())
page_input["tokens"] = tokens
tf_output = self.tf_predictor.multi_table_predict(
page_input, [tbl_box], do_matching=self.do_cell_matching
)
table_out = tf_output[0]
table_cells = []
for element in table_out["tf_responses"]:

View File

@ -162,7 +162,7 @@ class LayoutPostprocessor:
DocItemLabel.LIST_ITEM: 0.5,
DocItemLabel.PAGE_FOOTER: 0.5,
DocItemLabel.PAGE_HEADER: 0.5,
DocItemLabel.PICTURE: 0.1,
DocItemLabel.PICTURE: 0.5,
DocItemLabel.SECTION_HEADER: 0.45,
DocItemLabel.TABLE: 0.35,
DocItemLabel.TEXT: 0.55, # 0.45,
@ -279,6 +279,8 @@ class LayoutPostprocessor:
if c.confidence >= self.CONFIDENCE_THRESHOLDS[c.label]
]
special_clusters = self._handle_cross_type_overlaps(special_clusters)
for special in special_clusters:
contained = []
for cluster in self.regular_clusters:
@ -289,14 +291,17 @@ class LayoutPostprocessor:
contained.append(cluster)
if contained:
# Sort contained clusters by minimum cell ID
contained.sort(
key=lambda cluster: (
min(cell.id for cell in cluster.cells)
if cluster.cells
else sys.maxsize
)
)
# # Sort contained clusters by minimum cell ID:
# contained.sort(
# key=lambda cluster: (
# min(cell.id for cell in cluster.cells)
# if cluster.cells
# else sys.maxsize
# )
# )
# Sort contained clusters left-to-right, top-to-bottom
contained = self._sort_clusters(contained)
special.children = contained
# Adjust bbox only for wrapper types
@ -324,6 +329,39 @@ class LayoutPostprocessor:
return picture_clusters + wrapper_clusters
def _handle_cross_type_overlaps(self, special_clusters) -> List[Cluster]:
"""Handle overlaps between regular and wrapper clusters before child assignment.
In particular, KEY_VALUE_REGION proposals that are almost identical to a TABLE
should be removed.
"""
wrappers_to_remove = set()
for wrapper in special_clusters:
if wrapper.label != DocItemLabel.KEY_VALUE_REGION:
continue # only treat KEY_VALUE_REGION for now.
for regular in self.regular_clusters:
if regular.label == DocItemLabel.TABLE:
# Calculate overlap
overlap = regular.bbox.intersection_area_with(wrapper.bbox)
wrapper_area = wrapper.bbox.area()
overlap_ratio = overlap / wrapper_area
# If wrapper is mostly overlapping with a TABLE, remove the wrapper
if overlap_ratio > 0.8: # 80% overlap threshold
wrappers_to_remove.add(wrapper.id)
break
# Filter out the identified wrappers
special_clusters = [
cluster
for cluster in special_clusters
if cluster.id not in wrappers_to_remove
]
return special_clusters
def _should_prefer_cluster(
self, candidate: Cluster, other: Cluster, params: dict
) -> bool:
@ -443,6 +481,7 @@ class LayoutPostprocessor:
if cluster != best:
best.cells.extend(cluster.cells)
best.cells = self._deduplicate_cells(best.cells)
best.cells = self._sort_cells(best.cells)
result.append(best)
@ -478,6 +517,16 @@ class LayoutPostprocessor:
return current_best if current_best else clusters[0]
def _deduplicate_cells(self, cells: List[Cell]) -> List[Cell]:
"""Ensure each cell appears only once, maintaining order of first appearance."""
seen_ids = set()
unique_cells = []
for cell in cells:
if cell.id not in seen_ids:
seen_ids.add(cell.id)
unique_cells.append(cell)
return unique_cells
def _assign_cells_to_clusters(
self, clusters: List[Cluster], min_overlap: float = 0.2
) -> List[Cluster]:
@ -506,6 +555,10 @@ class LayoutPostprocessor:
if best_cluster is not None:
best_cluster.cells.append(cell)
# Deduplicate cells in each cluster after assignment
for cluster in clusters:
cluster.cells = self._deduplicate_cells(cluster.cells)
return clusters
def _find_unassigned_cells(self, clusters: List[Cluster]) -> List[Cell]:
@ -547,11 +600,4 @@ class LayoutPostprocessor:
def _sort_clusters(self, clusters: List[Cluster]) -> List[Cluster]:
"""Sort clusters in reading order (top-to-bottom, left-to-right)."""
def reading_order_key(cluster: Cluster) -> Tuple[float, float]:
if cluster.cells and cluster.label != DocItemLabel.PICTURE:
first_cell = min(cluster.cells, key=lambda c: (c.bbox.t, c.bbox.l))
return (first_cell.bbox.t, first_cell.bbox.l)
return (cluster.bbox.t, cluster.bbox.l)
return sorted(clusters, key=reading_order_key)
return sorted(clusters, key=lambda cluster: (cluster.bbox.t, cluster.bbox.l))

73
poetry.lock generated
View File

@ -922,27 +922,29 @@ name = "docling-core"
version = "2.9.0"
description = "A python library to define and validate data types in Docling."
optional = false
python-versions = "<4.0,>=3.9"
files = [
{file = "docling_core-2.9.0-py3-none-any.whl", hash = "sha256:b44b077db5d2ac8a900f30a15abe329c165b1f2eb7f1c90d1275c423c1c3d668"},
{file = "docling_core-2.9.0.tar.gz", hash = "sha256:1bf12fe67ee4852330e9bac33fe62b45598ff885481e03a88fa8e1bf48252424"},
]
python-versions = "^3.9"
files = []
develop = false
[package.dependencies]
jsonref = ">=1.1.0,<2.0.0"
jsonschema = ">=4.16.0,<5.0.0"
pandas = ">=2.1.4,<3.0.0"
pillow = ">=10.3.0,<11.0.0"
pydantic = ">=2.6.0,<2.10.0 || >2.10.0,<2.10.1 || >2.10.1,<2.10.2 || >2.10.2,<3.0.0"
jsonref = "^1.1.0"
jsonschema = "^4.16.0"
pandas = "^2.1.4"
pillow = "^10.3.0"
pydantic = ">=2.6.0,<3.0.0,!=2.10.0,!=2.10.1,!=2.10.2"
pyyaml = ">=5.1,<7.0.0"
semchunk = {version = ">=2.2.0,<3.0.0", optional = true, markers = "extra == \"chunking\""}
tabulate = ">=0.9.0,<0.10.0"
transformers = {version = ">=4.34.0,<5.0.0", optional = true, markers = "extra == \"chunking\""}
typing-extensions = ">=4.12.2,<5.0.0"
tabulate = "^0.9.0"
typing-extensions = "^4.12.2"
[package.extras]
chunking = ["semchunk (>=2.2.0,<3.0.0)", "transformers (>=4.34.0,<5.0.0)"]
[package.source]
type = "git"
url = "ssh://git@github.com/DS4SD/docling-core.git"
reference = "cau/include-picture-contents"
resolved_reference = "012f8ac38a2ba7e77110b3f7ad57af2a984232e5"
[[package]]
name = "docling-ibm-models"
version = "2.0.7"
@ -2855,32 +2857,6 @@ files = [
{file = "more_itertools-10.5.0-py3-none-any.whl", hash = "sha256:037b0d3203ce90cca8ab1defbbdac29d5f993fc20131f3664dc8d6acfa872aef"},
]
[[package]]
name = "mpire"
version = "2.10.2"
description = "A Python package for easy multiprocessing, but faster than multiprocessing"
optional = false
python-versions = "*"
files = [
{file = "mpire-2.10.2-py3-none-any.whl", hash = "sha256:d627707f7a8d02aa4c7f7d59de399dec5290945ddf7fbd36cbb1d6ebb37a51fb"},
{file = "mpire-2.10.2.tar.gz", hash = "sha256:f66a321e93fadff34585a4bfa05e95bd946cf714b442f51c529038eb45773d97"},
]
[package.dependencies]
multiprocess = [
{version = "*", optional = true, markers = "python_version < \"3.11\" and extra == \"dill\""},
{version = ">=0.70.15", optional = true, markers = "python_version >= \"3.11\" and extra == \"dill\""},
]
pygments = ">=2.0"
pywin32 = {version = ">=301", markers = "platform_system == \"Windows\""}
tqdm = ">=4.27"
[package.extras]
dashboard = ["flask"]
dill = ["multiprocess", "multiprocess (>=0.70.15)"]
docs = ["docutils (==0.17.1)", "sphinx (==3.2.1)", "sphinx-autodoc-typehints (==1.11.0)", "sphinx-rtd-theme (==0.5.0)", "sphinx-versions (==1.0.1)", "sphinxcontrib-images (==0.9.2)"]
testing = ["ipywidgets", "multiprocess", "multiprocess (>=0.70.15)", "numpy", "pywin32 (>=301)", "rich"]
[[package]]
name = "mpmath"
version = "1.3.0"
@ -6170,21 +6146,6 @@ files = [
cryptography = ">=2.0"
jeepney = ">=0.6"
[[package]]
name = "semchunk"
version = "2.2.0"
description = "A fast and lightweight Python library for splitting text into semantically meaningful chunks."
optional = false
python-versions = ">=3.9"
files = [
{file = "semchunk-2.2.0-py3-none-any.whl", hash = "sha256:7db19ca90ddb48f99265e789e07a7bb111ae25185f9cc3d44b94e1e61b9067fc"},
{file = "semchunk-2.2.0.tar.gz", hash = "sha256:4de761ce614036fa3bea61adbe47e3ade7c96ac9b062f223b3ac353dbfd26743"},
]
[package.dependencies]
mpire = {version = "*", extras = ["dill"]}
tqdm = "*"
[[package]]
name = "semver"
version = "2.13.0"
@ -7723,4 +7684,4 @@ tesserocr = ["tesserocr"]
[metadata]
lock-version = "2.0"
python-versions = "^3.9"
content-hash = "6917af8d76aa1f85a159f0ab9546478b4bef194ae726c79196bac087c7368fef"
content-hash = "c991515ef231d9eeead33cc876e8cb93fe31e949a5ab92918a4b77257d2700a3"

View File

@ -28,7 +28,8 @@ python = "^3.9"
docling-ibm-models = { git = "ssh://git@github.com/DS4SD/docling-ibm-models.git", branch = "nli/performance" }
deepsearch-glm = "^1.0.0"
docling-parse = "^3.0.0"
docling-core = { version = "^2.9.0", extras = ["chunking"] }
#docling-core = { version = "^2.9.0", extras = ["chunking"] }
docling-core = { git = "ssh://git@github.com/DS4SD/docling-core.git", branch = "cau/include-picture-contents" }
pydantic = "^2.0.0"
filetype = "^1.2.0"
pypdfium2 = "^4.30.0"