mirror of
https://github.com/DS4SD/docling.git
synced 2025-08-01 15:02:21 +00:00
Fixes for layout processing and tableformer workaround
Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
This commit is contained in:
parent
959f91180f
commit
4447d22c2f
@ -71,6 +71,10 @@ class TableStructureModel(BasePageModel):
|
||||
x0, y0, x1, y1 = table_element.cluster.bbox.as_tuple()
|
||||
draw.rectangle([(x0, y0), (x1, y1)], outline="red")
|
||||
|
||||
for cell in table_element.cluster.cells:
|
||||
x0, y0, x1, y1 = cell.bbox.as_tuple()
|
||||
draw.rectangle([(x0, y0), (x1, y1)], outline="green")
|
||||
|
||||
for tc in table_element.table_cells:
|
||||
if tc.bbox is not None:
|
||||
x0, y0, x1, y1 = tc.bbox.as_tuple()
|
||||
@ -84,7 +88,6 @@ class TableStructureModel(BasePageModel):
|
||||
text=f"{tc.start_row_offset_idx}, {tc.start_col_offset_idx}",
|
||||
fill="black",
|
||||
)
|
||||
|
||||
if show:
|
||||
image.show()
|
||||
else:
|
||||
@ -136,41 +139,33 @@ class TableStructureModel(BasePageModel):
|
||||
yield page
|
||||
continue
|
||||
|
||||
tokens = []
|
||||
for c in page.cells:
|
||||
for cluster, _ in in_tables:
|
||||
if c.bbox.area() > 0:
|
||||
if (
|
||||
c.bbox.intersection_area_with(cluster.bbox)
|
||||
/ c.bbox.area()
|
||||
> 0.2
|
||||
):
|
||||
# Only allow non empty stings (spaces) into the cells of a table
|
||||
if len(c.text.strip()) > 0:
|
||||
new_cell = copy.deepcopy(c)
|
||||
new_cell.bbox = new_cell.bbox.scaled(
|
||||
scale=self.scale
|
||||
)
|
||||
|
||||
tokens.append(new_cell.model_dump())
|
||||
|
||||
page_input = {
|
||||
"tokens": tokens,
|
||||
"width": page.size.width * self.scale,
|
||||
"height": page.size.height * self.scale,
|
||||
"image": numpy.asarray(page.get_image(scale=self.scale)),
|
||||
}
|
||||
page_input["image"] = numpy.asarray(
|
||||
page.get_image(scale=self.scale)
|
||||
)
|
||||
|
||||
table_clusters, table_bboxes = zip(*in_tables)
|
||||
|
||||
if len(table_bboxes):
|
||||
tf_output = self.tf_predictor.multi_table_predict(
|
||||
page_input, table_bboxes, do_matching=self.do_cell_matching
|
||||
)
|
||||
for table_cluster, tbl_box in in_tables:
|
||||
|
||||
for table_cluster, table_out in zip(table_clusters, tf_output):
|
||||
tokens = []
|
||||
for c in table_cluster.cells:
|
||||
# Only allow non empty stings (spaces) into the cells of a table
|
||||
if len(c.text.strip()) > 0:
|
||||
new_cell = copy.deepcopy(c)
|
||||
new_cell.bbox = new_cell.bbox.scaled(
|
||||
scale=self.scale
|
||||
)
|
||||
|
||||
tokens.append(new_cell.model_dump())
|
||||
page_input["tokens"] = tokens
|
||||
|
||||
tf_output = self.tf_predictor.multi_table_predict(
|
||||
page_input, [tbl_box], do_matching=self.do_cell_matching
|
||||
)
|
||||
table_out = tf_output[0]
|
||||
table_cells = []
|
||||
for element in table_out["tf_responses"]:
|
||||
|
||||
|
@ -162,7 +162,7 @@ class LayoutPostprocessor:
|
||||
DocItemLabel.LIST_ITEM: 0.5,
|
||||
DocItemLabel.PAGE_FOOTER: 0.5,
|
||||
DocItemLabel.PAGE_HEADER: 0.5,
|
||||
DocItemLabel.PICTURE: 0.1,
|
||||
DocItemLabel.PICTURE: 0.5,
|
||||
DocItemLabel.SECTION_HEADER: 0.45,
|
||||
DocItemLabel.TABLE: 0.35,
|
||||
DocItemLabel.TEXT: 0.55, # 0.45,
|
||||
@ -279,6 +279,8 @@ class LayoutPostprocessor:
|
||||
if c.confidence >= self.CONFIDENCE_THRESHOLDS[c.label]
|
||||
]
|
||||
|
||||
special_clusters = self._handle_cross_type_overlaps(special_clusters)
|
||||
|
||||
for special in special_clusters:
|
||||
contained = []
|
||||
for cluster in self.regular_clusters:
|
||||
@ -289,14 +291,17 @@ class LayoutPostprocessor:
|
||||
contained.append(cluster)
|
||||
|
||||
if contained:
|
||||
# Sort contained clusters by minimum cell ID
|
||||
contained.sort(
|
||||
key=lambda cluster: (
|
||||
min(cell.id for cell in cluster.cells)
|
||||
if cluster.cells
|
||||
else sys.maxsize
|
||||
)
|
||||
)
|
||||
# # Sort contained clusters by minimum cell ID:
|
||||
# contained.sort(
|
||||
# key=lambda cluster: (
|
||||
# min(cell.id for cell in cluster.cells)
|
||||
# if cluster.cells
|
||||
# else sys.maxsize
|
||||
# )
|
||||
# )
|
||||
|
||||
# Sort contained clusters left-to-right, top-to-bottom
|
||||
contained = self._sort_clusters(contained)
|
||||
special.children = contained
|
||||
|
||||
# Adjust bbox only for wrapper types
|
||||
@ -324,6 +329,39 @@ class LayoutPostprocessor:
|
||||
|
||||
return picture_clusters + wrapper_clusters
|
||||
|
||||
def _handle_cross_type_overlaps(self, special_clusters) -> List[Cluster]:
|
||||
"""Handle overlaps between regular and wrapper clusters before child assignment.
|
||||
|
||||
In particular, KEY_VALUE_REGION proposals that are almost identical to a TABLE
|
||||
should be removed.
|
||||
"""
|
||||
wrappers_to_remove = set()
|
||||
|
||||
for wrapper in special_clusters:
|
||||
if wrapper.label != DocItemLabel.KEY_VALUE_REGION:
|
||||
continue # only treat KEY_VALUE_REGION for now.
|
||||
|
||||
for regular in self.regular_clusters:
|
||||
if regular.label == DocItemLabel.TABLE:
|
||||
# Calculate overlap
|
||||
overlap = regular.bbox.intersection_area_with(wrapper.bbox)
|
||||
wrapper_area = wrapper.bbox.area()
|
||||
overlap_ratio = overlap / wrapper_area
|
||||
|
||||
# If wrapper is mostly overlapping with a TABLE, remove the wrapper
|
||||
if overlap_ratio > 0.8: # 80% overlap threshold
|
||||
wrappers_to_remove.add(wrapper.id)
|
||||
break
|
||||
|
||||
# Filter out the identified wrappers
|
||||
special_clusters = [
|
||||
cluster
|
||||
for cluster in special_clusters
|
||||
if cluster.id not in wrappers_to_remove
|
||||
]
|
||||
|
||||
return special_clusters
|
||||
|
||||
def _should_prefer_cluster(
|
||||
self, candidate: Cluster, other: Cluster, params: dict
|
||||
) -> bool:
|
||||
@ -443,6 +481,7 @@ class LayoutPostprocessor:
|
||||
if cluster != best:
|
||||
best.cells.extend(cluster.cells)
|
||||
|
||||
best.cells = self._deduplicate_cells(best.cells)
|
||||
best.cells = self._sort_cells(best.cells)
|
||||
result.append(best)
|
||||
|
||||
@ -478,6 +517,16 @@ class LayoutPostprocessor:
|
||||
|
||||
return current_best if current_best else clusters[0]
|
||||
|
||||
def _deduplicate_cells(self, cells: List[Cell]) -> List[Cell]:
|
||||
"""Ensure each cell appears only once, maintaining order of first appearance."""
|
||||
seen_ids = set()
|
||||
unique_cells = []
|
||||
for cell in cells:
|
||||
if cell.id not in seen_ids:
|
||||
seen_ids.add(cell.id)
|
||||
unique_cells.append(cell)
|
||||
return unique_cells
|
||||
|
||||
def _assign_cells_to_clusters(
|
||||
self, clusters: List[Cluster], min_overlap: float = 0.2
|
||||
) -> List[Cluster]:
|
||||
@ -506,6 +555,10 @@ class LayoutPostprocessor:
|
||||
if best_cluster is not None:
|
||||
best_cluster.cells.append(cell)
|
||||
|
||||
# Deduplicate cells in each cluster after assignment
|
||||
for cluster in clusters:
|
||||
cluster.cells = self._deduplicate_cells(cluster.cells)
|
||||
|
||||
return clusters
|
||||
|
||||
def _find_unassigned_cells(self, clusters: List[Cluster]) -> List[Cell]:
|
||||
@ -547,11 +600,4 @@ class LayoutPostprocessor:
|
||||
|
||||
def _sort_clusters(self, clusters: List[Cluster]) -> List[Cluster]:
|
||||
"""Sort clusters in reading order (top-to-bottom, left-to-right)."""
|
||||
|
||||
def reading_order_key(cluster: Cluster) -> Tuple[float, float]:
|
||||
if cluster.cells and cluster.label != DocItemLabel.PICTURE:
|
||||
first_cell = min(cluster.cells, key=lambda c: (c.bbox.t, c.bbox.l))
|
||||
return (first_cell.bbox.t, first_cell.bbox.l)
|
||||
return (cluster.bbox.t, cluster.bbox.l)
|
||||
|
||||
return sorted(clusters, key=reading_order_key)
|
||||
return sorted(clusters, key=lambda cluster: (cluster.bbox.t, cluster.bbox.l))
|
||||
|
73
poetry.lock
generated
73
poetry.lock
generated
@ -922,27 +922,29 @@ name = "docling-core"
|
||||
version = "2.9.0"
|
||||
description = "A python library to define and validate data types in Docling."
|
||||
optional = false
|
||||
python-versions = "<4.0,>=3.9"
|
||||
files = [
|
||||
{file = "docling_core-2.9.0-py3-none-any.whl", hash = "sha256:b44b077db5d2ac8a900f30a15abe329c165b1f2eb7f1c90d1275c423c1c3d668"},
|
||||
{file = "docling_core-2.9.0.tar.gz", hash = "sha256:1bf12fe67ee4852330e9bac33fe62b45598ff885481e03a88fa8e1bf48252424"},
|
||||
]
|
||||
python-versions = "^3.9"
|
||||
files = []
|
||||
develop = false
|
||||
|
||||
[package.dependencies]
|
||||
jsonref = ">=1.1.0,<2.0.0"
|
||||
jsonschema = ">=4.16.0,<5.0.0"
|
||||
pandas = ">=2.1.4,<3.0.0"
|
||||
pillow = ">=10.3.0,<11.0.0"
|
||||
pydantic = ">=2.6.0,<2.10.0 || >2.10.0,<2.10.1 || >2.10.1,<2.10.2 || >2.10.2,<3.0.0"
|
||||
jsonref = "^1.1.0"
|
||||
jsonschema = "^4.16.0"
|
||||
pandas = "^2.1.4"
|
||||
pillow = "^10.3.0"
|
||||
pydantic = ">=2.6.0,<3.0.0,!=2.10.0,!=2.10.1,!=2.10.2"
|
||||
pyyaml = ">=5.1,<7.0.0"
|
||||
semchunk = {version = ">=2.2.0,<3.0.0", optional = true, markers = "extra == \"chunking\""}
|
||||
tabulate = ">=0.9.0,<0.10.0"
|
||||
transformers = {version = ">=4.34.0,<5.0.0", optional = true, markers = "extra == \"chunking\""}
|
||||
typing-extensions = ">=4.12.2,<5.0.0"
|
||||
tabulate = "^0.9.0"
|
||||
typing-extensions = "^4.12.2"
|
||||
|
||||
[package.extras]
|
||||
chunking = ["semchunk (>=2.2.0,<3.0.0)", "transformers (>=4.34.0,<5.0.0)"]
|
||||
|
||||
[package.source]
|
||||
type = "git"
|
||||
url = "ssh://git@github.com/DS4SD/docling-core.git"
|
||||
reference = "cau/include-picture-contents"
|
||||
resolved_reference = "012f8ac38a2ba7e77110b3f7ad57af2a984232e5"
|
||||
|
||||
[[package]]
|
||||
name = "docling-ibm-models"
|
||||
version = "2.0.7"
|
||||
@ -2855,32 +2857,6 @@ files = [
|
||||
{file = "more_itertools-10.5.0-py3-none-any.whl", hash = "sha256:037b0d3203ce90cca8ab1defbbdac29d5f993fc20131f3664dc8d6acfa872aef"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mpire"
|
||||
version = "2.10.2"
|
||||
description = "A Python package for easy multiprocessing, but faster than multiprocessing"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
files = [
|
||||
{file = "mpire-2.10.2-py3-none-any.whl", hash = "sha256:d627707f7a8d02aa4c7f7d59de399dec5290945ddf7fbd36cbb1d6ebb37a51fb"},
|
||||
{file = "mpire-2.10.2.tar.gz", hash = "sha256:f66a321e93fadff34585a4bfa05e95bd946cf714b442f51c529038eb45773d97"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
multiprocess = [
|
||||
{version = "*", optional = true, markers = "python_version < \"3.11\" and extra == \"dill\""},
|
||||
{version = ">=0.70.15", optional = true, markers = "python_version >= \"3.11\" and extra == \"dill\""},
|
||||
]
|
||||
pygments = ">=2.0"
|
||||
pywin32 = {version = ">=301", markers = "platform_system == \"Windows\""}
|
||||
tqdm = ">=4.27"
|
||||
|
||||
[package.extras]
|
||||
dashboard = ["flask"]
|
||||
dill = ["multiprocess", "multiprocess (>=0.70.15)"]
|
||||
docs = ["docutils (==0.17.1)", "sphinx (==3.2.1)", "sphinx-autodoc-typehints (==1.11.0)", "sphinx-rtd-theme (==0.5.0)", "sphinx-versions (==1.0.1)", "sphinxcontrib-images (==0.9.2)"]
|
||||
testing = ["ipywidgets", "multiprocess", "multiprocess (>=0.70.15)", "numpy", "pywin32 (>=301)", "rich"]
|
||||
|
||||
[[package]]
|
||||
name = "mpmath"
|
||||
version = "1.3.0"
|
||||
@ -6170,21 +6146,6 @@ files = [
|
||||
cryptography = ">=2.0"
|
||||
jeepney = ">=0.6"
|
||||
|
||||
[[package]]
|
||||
name = "semchunk"
|
||||
version = "2.2.0"
|
||||
description = "A fast and lightweight Python library for splitting text into semantically meaningful chunks."
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
files = [
|
||||
{file = "semchunk-2.2.0-py3-none-any.whl", hash = "sha256:7db19ca90ddb48f99265e789e07a7bb111ae25185f9cc3d44b94e1e61b9067fc"},
|
||||
{file = "semchunk-2.2.0.tar.gz", hash = "sha256:4de761ce614036fa3bea61adbe47e3ade7c96ac9b062f223b3ac353dbfd26743"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
mpire = {version = "*", extras = ["dill"]}
|
||||
tqdm = "*"
|
||||
|
||||
[[package]]
|
||||
name = "semver"
|
||||
version = "2.13.0"
|
||||
@ -7723,4 +7684,4 @@ tesserocr = ["tesserocr"]
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.9"
|
||||
content-hash = "6917af8d76aa1f85a159f0ab9546478b4bef194ae726c79196bac087c7368fef"
|
||||
content-hash = "c991515ef231d9eeead33cc876e8cb93fe31e949a5ab92918a4b77257d2700a3"
|
||||
|
@ -28,7 +28,8 @@ python = "^3.9"
|
||||
docling-ibm-models = { git = "ssh://git@github.com/DS4SD/docling-ibm-models.git", branch = "nli/performance" }
|
||||
deepsearch-glm = "^1.0.0"
|
||||
docling-parse = "^3.0.0"
|
||||
docling-core = { version = "^2.9.0", extras = ["chunking"] }
|
||||
#docling-core = { version = "^2.9.0", extras = ["chunking"] }
|
||||
docling-core = { git = "ssh://git@github.com/DS4SD/docling-core.git", branch = "cau/include-picture-contents" }
|
||||
pydantic = "^2.0.0"
|
||||
filetype = "^1.2.0"
|
||||
pypdfium2 = "^4.30.0"
|
||||
|
Loading…
Reference in New Issue
Block a user