mirror of
https://github.com/DS4SD/docling.git
synced 2025-07-27 12:34:22 +00:00
Upgraded Layout Postprocessing, sending old code back to ERZ
Signed-off-by: Christoph Auer <cau@zurich.ibm.com> Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
This commit is contained in:
parent
20a2cd0f53
commit
e0cf80a919
@ -117,6 +117,7 @@ class Cluster(BaseModel):
|
|||||||
bbox: BoundingBox
|
bbox: BoundingBox
|
||||||
confidence: float = 1.0
|
confidence: float = 1.0
|
||||||
cells: List[Cell] = []
|
cells: List[Cell] = []
|
||||||
|
children: List["Cluster"] = [] # Add child cluster support
|
||||||
|
|
||||||
|
|
||||||
class BasePageElement(BaseModel):
|
class BasePageElement(BaseModel):
|
||||||
|
@ -31,6 +31,7 @@ class DebugSettings(BaseModel):
|
|||||||
visualize_cells: bool = False
|
visualize_cells: bool = False
|
||||||
visualize_ocr: bool = False
|
visualize_ocr: bool = False
|
||||||
visualize_layout: bool = False
|
visualize_layout: bool = False
|
||||||
|
visualize_raw_layout: bool = False
|
||||||
visualize_tables: bool = False
|
visualize_tables: bool = False
|
||||||
|
|
||||||
profile_pipeline_timings: bool = False
|
profile_pipeline_timings: bool = False
|
||||||
|
@ -7,7 +7,7 @@ from typing import Iterable, List
|
|||||||
|
|
||||||
from docling_core.types.doc import CoordOrigin, DocItemLabel
|
from docling_core.types.doc import CoordOrigin, DocItemLabel
|
||||||
from docling_ibm_models.layoutmodel.layout_predictor import LayoutPredictor
|
from docling_ibm_models.layoutmodel.layout_predictor import LayoutPredictor
|
||||||
from PIL import ImageDraw
|
from PIL import Image, ImageDraw
|
||||||
|
|
||||||
from docling.datamodel.base_models import (
|
from docling.datamodel.base_models import (
|
||||||
BoundingBox,
|
BoundingBox,
|
||||||
@ -19,7 +19,7 @@ from docling.datamodel.base_models import (
|
|||||||
from docling.datamodel.document import ConversionResult
|
from docling.datamodel.document import ConversionResult
|
||||||
from docling.datamodel.settings import settings
|
from docling.datamodel.settings import settings
|
||||||
from docling.models.base_model import BasePageModel
|
from docling.models.base_model import BasePageModel
|
||||||
from docling.utils import layout_utils as lu
|
from docling.utils.layout_postprocessor import LayoutPostprocessor
|
||||||
from docling.utils.profiling import TimeRecorder
|
from docling.utils.profiling import TimeRecorder
|
||||||
|
|
||||||
_log = logging.getLogger(__name__)
|
_log = logging.getLogger(__name__)
|
||||||
@ -49,230 +49,101 @@ class LayoutModel(BasePageModel):
|
|||||||
def __init__(self, artifacts_path: Path):
|
def __init__(self, artifacts_path: Path):
|
||||||
self.layout_predictor = LayoutPredictor(artifacts_path) # TODO temporary
|
self.layout_predictor = LayoutPredictor(artifacts_path) # TODO temporary
|
||||||
|
|
||||||
def postprocess(self, clusters_in: List[Cluster], cells: List[Cell], page_height):
|
def draw_clusters_and_cells_side_by_side(
|
||||||
MIN_INTERSECTION = 0.2
|
self, conv_res, page, clusters, mode_prefix: str, show: bool = False
|
||||||
CLASS_THRESHOLDS = {
|
):
|
||||||
DocItemLabel.CAPTION: 0.35,
|
"""
|
||||||
DocItemLabel.FOOTNOTE: 0.35,
|
Draws a page image side by side with clusters filtered into two categories:
|
||||||
DocItemLabel.FORMULA: 0.35,
|
- Left: Clusters excluding FORM, KEY_VALUE_REGION, and PICTURE.
|
||||||
DocItemLabel.LIST_ITEM: 0.35,
|
- Right: Clusters including FORM, KEY_VALUE_REGION, and PICTURE.
|
||||||
DocItemLabel.PAGE_FOOTER: 0.35,
|
"""
|
||||||
DocItemLabel.PAGE_HEADER: 0.35,
|
label_to_color = {
|
||||||
DocItemLabel.PICTURE: 0.2, # low threshold adjust to capture chemical structures for examples.
|
DocItemLabel.TEXT: (255, 255, 153), # Light Yellow
|
||||||
DocItemLabel.SECTION_HEADER: 0.45,
|
DocItemLabel.CAPTION: (255, 204, 153), # Light Orange
|
||||||
DocItemLabel.TABLE: 0.35,
|
DocItemLabel.LIST_ITEM: (153, 153, 255), # Light Purple
|
||||||
DocItemLabel.TEXT: 0.45,
|
DocItemLabel.FORMULA: (192, 192, 192), # Gray
|
||||||
DocItemLabel.TITLE: 0.45,
|
DocItemLabel.TABLE: (255, 204, 204), # Light Pink
|
||||||
DocItemLabel.DOCUMENT_INDEX: 0.45,
|
DocItemLabel.PICTURE: (255, 204, 164), # Light Beige
|
||||||
DocItemLabel.CODE: 0.45,
|
DocItemLabel.SECTION_HEADER: (255, 153, 153), # Light Red
|
||||||
DocItemLabel.CHECKBOX_SELECTED: 0.45,
|
DocItemLabel.PAGE_HEADER: (204, 255, 204), # Light Green
|
||||||
DocItemLabel.CHECKBOX_UNSELECTED: 0.45,
|
DocItemLabel.PAGE_FOOTER: (
|
||||||
DocItemLabel.FORM: 0.45,
|
204,
|
||||||
DocItemLabel.KEY_VALUE_REGION: 0.45,
|
255,
|
||||||
|
204,
|
||||||
|
), # Light Green (same as Page-Header)
|
||||||
|
DocItemLabel.TITLE: (255, 153, 153), # Light Red (same as Section-Header)
|
||||||
|
DocItemLabel.FOOTNOTE: (200, 200, 255), # Light Blue
|
||||||
|
DocItemLabel.DOCUMENT_INDEX: (220, 220, 220), # Light Gray
|
||||||
|
DocItemLabel.CODE: (255, 223, 186), # Peach
|
||||||
|
DocItemLabel.CHECKBOX_SELECTED: (255, 182, 193), # Pale Green
|
||||||
|
DocItemLabel.CHECKBOX_UNSELECTED: (255, 182, 193), # Light Pink
|
||||||
|
DocItemLabel.FORM: (200, 255, 255), # Light Cyan
|
||||||
|
DocItemLabel.KEY_VALUE_REGION: (183, 65, 14), # Rusty orange
|
||||||
}
|
}
|
||||||
|
|
||||||
CLASS_REMAPPINGS = {
|
# Filter clusters for left and right images
|
||||||
DocItemLabel.DOCUMENT_INDEX: DocItemLabel.TABLE,
|
exclude_labels = {
|
||||||
DocItemLabel.TITLE: DocItemLabel.SECTION_HEADER,
|
DocItemLabel.FORM,
|
||||||
|
DocItemLabel.KEY_VALUE_REGION,
|
||||||
|
DocItemLabel.PICTURE,
|
||||||
}
|
}
|
||||||
|
left_clusters = [c for c in clusters if c.label not in exclude_labels]
|
||||||
|
right_clusters = [c for c in clusters if c.label in exclude_labels]
|
||||||
|
|
||||||
_log.debug("================= Start postprocess function ====================")
|
# Create a deep copy of the original image for both sides
|
||||||
start_time = time.time()
|
left_image = copy.deepcopy(page.image)
|
||||||
# Apply Confidence Threshold to cluster predictions
|
right_image = copy.deepcopy(page.image)
|
||||||
# confidence = self.conf_threshold
|
|
||||||
clusters_mod = []
|
|
||||||
|
|
||||||
for cluster in clusters_in:
|
# Function to draw clusters on an image
|
||||||
confidence = CLASS_THRESHOLDS[cluster.label]
|
def draw_clusters(image, clusters):
|
||||||
if cluster.confidence >= confidence:
|
draw = ImageDraw.Draw(image, "RGBA")
|
||||||
# annotation["created_by"] = "high_conf_pred"
|
for c in clusters:
|
||||||
|
cell_color = (0, 0, 0, 40) # Transparent black for cells
|
||||||
|
for tc in c.cells:
|
||||||
|
cx0, cy0, cx1, cy1 = tc.bbox.as_tuple()
|
||||||
|
draw.rectangle(
|
||||||
|
[(cx0, cy0), (cx1, cy1)],
|
||||||
|
outline=None,
|
||||||
|
fill=cell_color,
|
||||||
|
)
|
||||||
|
|
||||||
# Remap class labels where needed.
|
x0, y0, x1, y1 = c.bbox.as_tuple()
|
||||||
if cluster.label in CLASS_REMAPPINGS.keys():
|
cluster_fill_color = (
|
||||||
cluster.label = CLASS_REMAPPINGS[cluster.label]
|
*list(label_to_color.get(c.label)), # type: ignore
|
||||||
clusters_mod.append(cluster)
|
70,
|
||||||
|
)
|
||||||
|
cluster_outline_color = (
|
||||||
|
*list(label_to_color.get(c.label)), # type: ignore
|
||||||
|
255,
|
||||||
|
)
|
||||||
|
draw.rectangle(
|
||||||
|
[(x0, y0), (x1, y1)],
|
||||||
|
outline=cluster_outline_color,
|
||||||
|
fill=cluster_fill_color,
|
||||||
|
)
|
||||||
|
|
||||||
# map to dictionary clusters and cells, with bottom left origin
|
# Draw clusters on both images
|
||||||
clusters_orig = [
|
draw_clusters(left_image, left_clusters)
|
||||||
{
|
draw_clusters(right_image, right_clusters)
|
||||||
"id": c.id,
|
|
||||||
"bbox": list(
|
|
||||||
c.bbox.to_bottom_left_origin(page_height).as_tuple()
|
|
||||||
), # TODO
|
|
||||||
"confidence": c.confidence,
|
|
||||||
"cell_ids": [],
|
|
||||||
"type": c.label,
|
|
||||||
}
|
|
||||||
for c in clusters_in
|
|
||||||
]
|
|
||||||
|
|
||||||
clusters_out = [
|
# Combine the images side by side
|
||||||
{
|
combined_width = left_image.width * 2
|
||||||
"id": c.id,
|
combined_height = left_image.height
|
||||||
"bbox": list(
|
combined_image = Image.new("RGB", (combined_width, combined_height))
|
||||||
c.bbox.to_bottom_left_origin(page_height).as_tuple()
|
combined_image.paste(left_image, (0, 0))
|
||||||
), # TODO
|
combined_image.paste(right_image, (left_image.width, 0))
|
||||||
"confidence": c.confidence,
|
|
||||||
"created_by": "high_conf_pred",
|
|
||||||
"cell_ids": [],
|
|
||||||
"type": c.label,
|
|
||||||
}
|
|
||||||
for c in clusters_mod
|
|
||||||
]
|
|
||||||
|
|
||||||
del clusters_mod
|
if show:
|
||||||
|
combined_image.show()
|
||||||
raw_cells = [
|
else:
|
||||||
{
|
out_path: Path = (
|
||||||
"id": c.id,
|
Path(settings.debug.debug_output_path)
|
||||||
"bbox": list(
|
/ f"debug_{conv_res.input.file.stem}"
|
||||||
c.bbox.to_bottom_left_origin(page_height).as_tuple()
|
|
||||||
), # TODO
|
|
||||||
"text": c.text,
|
|
||||||
}
|
|
||||||
for c in cells
|
|
||||||
]
|
|
||||||
cell_count = len(raw_cells)
|
|
||||||
|
|
||||||
_log.debug("---- 0. Treat cluster overlaps ------")
|
|
||||||
clusters_out = lu.remove_cluster_duplicates_by_conf(clusters_out, 0.8)
|
|
||||||
|
|
||||||
_log.debug(
|
|
||||||
"---- 1. Initially assign cells to clusters based on minimum intersection ------"
|
|
||||||
)
|
|
||||||
## Check for cells included in or touched by clusters:
|
|
||||||
clusters_out = lu.assigning_cell_ids_to_clusters(
|
|
||||||
clusters_out, raw_cells, MIN_INTERSECTION
|
|
||||||
)
|
|
||||||
|
|
||||||
_log.debug("---- 2. Assign Orphans with Low Confidence Detections")
|
|
||||||
# Creates a map of cell_id->cluster_id
|
|
||||||
(
|
|
||||||
clusters_around_cells,
|
|
||||||
orphan_cell_indices,
|
|
||||||
ambiguous_cell_indices,
|
|
||||||
) = lu.cell_id_state_map(clusters_out, cell_count)
|
|
||||||
|
|
||||||
# Assign orphan cells with lower confidence predictions
|
|
||||||
clusters_out, orphan_cell_indices = lu.assign_orphans_with_low_conf_pred(
|
|
||||||
clusters_out, clusters_orig, raw_cells, orphan_cell_indices
|
|
||||||
)
|
|
||||||
|
|
||||||
# Refresh the cell_ids assignment, after creating new clusters using low conf predictions
|
|
||||||
clusters_out = lu.assigning_cell_ids_to_clusters(
|
|
||||||
clusters_out, raw_cells, MIN_INTERSECTION
|
|
||||||
)
|
|
||||||
|
|
||||||
_log.debug("---- 3. Settle Ambigous Cells")
|
|
||||||
# Creates an update map after assignment of cell_id->cluster_id
|
|
||||||
(
|
|
||||||
clusters_around_cells,
|
|
||||||
orphan_cell_indices,
|
|
||||||
ambiguous_cell_indices,
|
|
||||||
) = lu.cell_id_state_map(clusters_out, cell_count)
|
|
||||||
|
|
||||||
# Settle pdf cells that belong to multiple clusters
|
|
||||||
clusters_out, ambiguous_cell_indices = lu.remove_ambigous_pdf_cell_by_conf(
|
|
||||||
clusters_out, raw_cells, ambiguous_cell_indices
|
|
||||||
)
|
|
||||||
|
|
||||||
_log.debug("---- 4. Set Orphans as Text")
|
|
||||||
(
|
|
||||||
clusters_around_cells,
|
|
||||||
orphan_cell_indices,
|
|
||||||
ambiguous_cell_indices,
|
|
||||||
) = lu.cell_id_state_map(clusters_out, cell_count)
|
|
||||||
|
|
||||||
clusters_out, orphan_cell_indices = lu.set_orphan_as_text(
|
|
||||||
clusters_out, clusters_orig, raw_cells, orphan_cell_indices
|
|
||||||
)
|
|
||||||
|
|
||||||
_log.debug("---- 5. Merge Cells & and adapt the bounding boxes")
|
|
||||||
# Merge cells orphan cells
|
|
||||||
clusters_out = lu.merge_cells(clusters_out)
|
|
||||||
|
|
||||||
# Clean up clusters that remain from merged and unreasonable clusters
|
|
||||||
clusters_out = lu.clean_up_clusters(
|
|
||||||
clusters_out,
|
|
||||||
raw_cells,
|
|
||||||
merge_cells=True,
|
|
||||||
img_table=True,
|
|
||||||
one_cell_table=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
new_clusters = lu.adapt_bboxes(raw_cells, clusters_out, orphan_cell_indices)
|
|
||||||
clusters_out = new_clusters
|
|
||||||
|
|
||||||
## We first rebuild where every cell is now:
|
|
||||||
## Now we write into a prediction cells list, not into the raw cells list.
|
|
||||||
## As we don't need previous labels, we best overwrite any old list, because that might
|
|
||||||
## have been sorted differently.
|
|
||||||
(
|
|
||||||
clusters_around_cells,
|
|
||||||
orphan_cell_indices,
|
|
||||||
ambiguous_cell_indices,
|
|
||||||
) = lu.cell_id_state_map(clusters_out, cell_count)
|
|
||||||
|
|
||||||
target_cells = []
|
|
||||||
for ix, cell in enumerate(raw_cells):
|
|
||||||
new_cell = {
|
|
||||||
"id": ix,
|
|
||||||
"rawcell_id": ix,
|
|
||||||
"label": "None",
|
|
||||||
"bbox": cell["bbox"],
|
|
||||||
"text": cell["text"],
|
|
||||||
}
|
|
||||||
for cluster_index in clusters_around_cells[
|
|
||||||
ix
|
|
||||||
]: # By previous analysis, this is always 1 cluster.
|
|
||||||
new_cell["label"] = clusters_out[cluster_index]["type"]
|
|
||||||
target_cells.append(new_cell)
|
|
||||||
# _log.debug("New label of cell " + str(ix) + " is " + str(new_cell["label"]))
|
|
||||||
cells_out = target_cells
|
|
||||||
|
|
||||||
## -------------------------------
|
|
||||||
## Sort clusters into reasonable reading order, and sort the cells inside each cluster
|
|
||||||
_log.debug("---- 5. Sort clusters in reading order ------")
|
|
||||||
sorted_clusters = lu.produce_reading_order(
|
|
||||||
clusters_out, "raw_cell_ids", "raw_cell_ids", True
|
|
||||||
)
|
|
||||||
clusters_out = sorted_clusters
|
|
||||||
|
|
||||||
# end_time = timer()
|
|
||||||
_log.debug("---- End of postprocessing function ------")
|
|
||||||
end_time = time.time() - start_time
|
|
||||||
_log.debug(f"Finished post processing in seconds={end_time:.3f}")
|
|
||||||
|
|
||||||
cells_out_new = [
|
|
||||||
Cell(
|
|
||||||
id=c["id"], # type: ignore
|
|
||||||
bbox=BoundingBox.from_tuple(
|
|
||||||
coord=c["bbox"], origin=CoordOrigin.BOTTOMLEFT # type: ignore
|
|
||||||
).to_top_left_origin(page_height),
|
|
||||||
text=c["text"], # type: ignore
|
|
||||||
)
|
)
|
||||||
for c in cells_out
|
out_path.mkdir(parents=True, exist_ok=True)
|
||||||
]
|
|
||||||
|
|
||||||
del cells_out
|
out_file = out_path / f"{mode_prefix}_layout_page_{page.page_no:05}.png"
|
||||||
|
combined_image.save(str(out_file), format="png")
|
||||||
clusters_out_new = []
|
|
||||||
for c in clusters_out:
|
|
||||||
cluster_cells = [
|
|
||||||
ccell for ccell in cells_out_new if ccell.id in c["cell_ids"] # type: ignore
|
|
||||||
]
|
|
||||||
c_new = Cluster(
|
|
||||||
id=c["id"], # type: ignore
|
|
||||||
bbox=BoundingBox.from_tuple(
|
|
||||||
coord=c["bbox"], origin=CoordOrigin.BOTTOMLEFT # type: ignore
|
|
||||||
).to_top_left_origin(page_height),
|
|
||||||
confidence=c["confidence"], # type: ignore
|
|
||||||
label=DocItemLabel(c["type"]),
|
|
||||||
cells=cluster_cells,
|
|
||||||
)
|
|
||||||
clusters_out_new.append(c_new)
|
|
||||||
|
|
||||||
return clusters_out_new, cells_out_new
|
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self, conv_res: ConversionResult, page_batch: Iterable[Page]
|
self, conv_res: ConversionResult, page_batch: Iterable[Page]
|
||||||
@ -304,44 +175,97 @@ class LayoutModel(BasePageModel):
|
|||||||
cells=[],
|
cells=[],
|
||||||
)
|
)
|
||||||
clusters.append(cluster)
|
clusters.append(cluster)
|
||||||
|
#
|
||||||
# Map cells to clusters
|
# # Map cells to clusters
|
||||||
# TODO: Remove, postprocess should take care of it anyway.
|
# # TODO: Remove, postprocess should take care of it anyway.
|
||||||
for cell in page.cells:
|
# for cell in page.cells:
|
||||||
for cluster in clusters:
|
# for cluster in clusters:
|
||||||
if not cell.bbox.area() > 0:
|
# if not cell.bbox.area() > 0:
|
||||||
overlap_frac = 0.0
|
# overlap_frac = 0.0
|
||||||
else:
|
# else:
|
||||||
overlap_frac = (
|
# overlap_frac = (
|
||||||
cell.bbox.intersection_area_with(cluster.bbox)
|
# cell.bbox.intersection_area_with(cluster.bbox)
|
||||||
/ cell.bbox.area()
|
# / cell.bbox.area()
|
||||||
)
|
# )
|
||||||
|
#
|
||||||
if overlap_frac > 0.5:
|
# if overlap_frac > 0.5:
|
||||||
cluster.cells.append(cell)
|
# cluster.cells.append(cell)
|
||||||
|
|
||||||
# Pre-sort clusters
|
# Pre-sort clusters
|
||||||
# clusters = self.sort_clusters_by_cell_order(clusters)
|
# clusters = self.sort_clusters_by_cell_order(clusters)
|
||||||
|
|
||||||
# DEBUG code:
|
# DEBUG code:
|
||||||
def draw_clusters_and_cells(show: bool = False):
|
def draw_clusters_and_cells(
|
||||||
|
clusters, mode_prefix: str, show: bool = False
|
||||||
|
):
|
||||||
|
label_to_color = {
|
||||||
|
DocItemLabel.TEXT: (255, 255, 153), # Light Yellow
|
||||||
|
DocItemLabel.CAPTION: (255, 204, 153), # Light Orange
|
||||||
|
DocItemLabel.LIST_ITEM: (153, 153, 255), # Light Purple
|
||||||
|
DocItemLabel.FORMULA: (192, 192, 192), # Gray
|
||||||
|
DocItemLabel.TABLE: (255, 204, 204), # Light Pink
|
||||||
|
DocItemLabel.PICTURE: (255, 255, 204), # Light Beige
|
||||||
|
DocItemLabel.SECTION_HEADER: (255, 153, 153), # Light Red
|
||||||
|
DocItemLabel.PAGE_HEADER: (204, 255, 204), # Light Green
|
||||||
|
DocItemLabel.PAGE_FOOTER: (
|
||||||
|
204,
|
||||||
|
255,
|
||||||
|
204,
|
||||||
|
), # Light Green (same as Page-Header)
|
||||||
|
DocItemLabel.TITLE: (
|
||||||
|
255,
|
||||||
|
153,
|
||||||
|
153,
|
||||||
|
), # Light Red (same as Section-Header)
|
||||||
|
DocItemLabel.FOOTNOTE: (200, 200, 255), # Light Blue
|
||||||
|
DocItemLabel.DOCUMENT_INDEX: (220, 220, 220), # Light Gray
|
||||||
|
DocItemLabel.CODE: (255, 223, 186), # Peach
|
||||||
|
DocItemLabel.CHECKBOX_SELECTED: (
|
||||||
|
255,
|
||||||
|
182,
|
||||||
|
193,
|
||||||
|
), # Pale Green
|
||||||
|
DocItemLabel.CHECKBOX_UNSELECTED: (
|
||||||
|
255,
|
||||||
|
182,
|
||||||
|
193,
|
||||||
|
), # Light Pink
|
||||||
|
DocItemLabel.FORM: (200, 255, 255), # Light Cyan
|
||||||
|
DocItemLabel.KEY_VALUE_REGION: (
|
||||||
|
183,
|
||||||
|
65,
|
||||||
|
14,
|
||||||
|
), # Rusty orange
|
||||||
|
}
|
||||||
|
|
||||||
image = copy.deepcopy(page.image)
|
image = copy.deepcopy(page.image)
|
||||||
if image is not None:
|
if image is not None:
|
||||||
draw = ImageDraw.Draw(image)
|
draw = ImageDraw.Draw(image, "RGBA")
|
||||||
for c in clusters:
|
for c in clusters:
|
||||||
x0, y0, x1, y1 = c.bbox.as_tuple()
|
cell_color = (0, 0, 0, 40)
|
||||||
draw.rectangle([(x0, y0), (x1, y1)], outline="green")
|
|
||||||
|
|
||||||
cell_color = (
|
|
||||||
random.randint(30, 140),
|
|
||||||
random.randint(30, 140),
|
|
||||||
random.randint(30, 140),
|
|
||||||
)
|
|
||||||
for tc in c.cells: # [:1]:
|
for tc in c.cells: # [:1]:
|
||||||
x0, y0, x1, y1 = tc.bbox.as_tuple()
|
cx0, cy0, cx1, cy1 = tc.bbox.as_tuple()
|
||||||
draw.rectangle(
|
draw.rectangle(
|
||||||
[(x0, y0), (x1, y1)], outline=cell_color
|
[(cx0, cy0), (cx1, cy1)],
|
||||||
|
outline=None,
|
||||||
|
fill=cell_color,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
x0, y0, x1, y1 = c.bbox.as_tuple()
|
||||||
|
cluster_fill_color = (
|
||||||
|
*list(label_to_color.get(c.label)), # type: ignore
|
||||||
|
70,
|
||||||
|
)
|
||||||
|
cluster_outline_color = (
|
||||||
|
*list(label_to_color.get(c.label)), # type: ignore
|
||||||
|
255,
|
||||||
|
)
|
||||||
|
draw.rectangle(
|
||||||
|
[(x0, y0), (x1, y1)],
|
||||||
|
outline=cluster_outline_color,
|
||||||
|
fill=cluster_fill_color,
|
||||||
|
)
|
||||||
|
|
||||||
if show:
|
if show:
|
||||||
image.show()
|
image.show()
|
||||||
else:
|
else:
|
||||||
@ -352,19 +276,42 @@ class LayoutModel(BasePageModel):
|
|||||||
out_path.mkdir(parents=True, exist_ok=True)
|
out_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
out_file = (
|
out_file = (
|
||||||
out_path / f"layout_page_{page.page_no:05}.png"
|
out_path
|
||||||
|
/ f"{mode_prefix}_layout_page_{page.page_no:05}.png"
|
||||||
)
|
)
|
||||||
image.save(str(out_file), format="png")
|
image.save(str(out_file), format="png")
|
||||||
|
|
||||||
# draw_clusters_and_cells()
|
if settings.debug.visualize_raw_layout:
|
||||||
|
self.draw_clusters_and_cells_side_by_side(
|
||||||
|
conv_res, page, clusters, mode_prefix="raw"
|
||||||
|
)
|
||||||
|
|
||||||
clusters, page.cells = self.postprocess(
|
# Apply postprocessing
|
||||||
clusters, page.cells, page.size.height
|
processed_clusters, processed_cells = LayoutPostprocessor(
|
||||||
|
page.cells, clusters
|
||||||
|
).postprocess()
|
||||||
|
# processed_clusters, processed_cells = clusters, page.cells
|
||||||
|
|
||||||
|
page.cells = processed_cells
|
||||||
|
page.predictions.layout = LayoutPrediction(
|
||||||
|
clusters=processed_clusters
|
||||||
)
|
)
|
||||||
|
|
||||||
page.predictions.layout = LayoutPrediction(clusters=clusters)
|
# def check_for_overlaps(clusters):
|
||||||
|
# for i, cluster in enumerate(clusters):
|
||||||
|
# for j, other_cluster in enumerate(clusters):
|
||||||
|
# if i >= j or cluster.label in [DocItemLabel.FORM, DocItemLabel.KEY_VALUE_REGION, DocItemLabel.PICTURE]:
|
||||||
|
# continue
|
||||||
|
#
|
||||||
|
# overlap_area = cluster.bbox.intersection_area_with(other_cluster.bbox)
|
||||||
|
# if overlap_area > 0:
|
||||||
|
# print(f"Overlap detected between cluster {i} and cluster {j}")
|
||||||
|
# print(f"Cluster {i} bbox: {cluster.bbox}, Cluster {j} bbox: {other_cluster.bbox}")
|
||||||
|
# check_for_overlaps(processed_clusters)
|
||||||
|
|
||||||
if settings.debug.visualize_layout:
|
if settings.debug.visualize_layout:
|
||||||
draw_clusters_and_cells()
|
self.draw_clusters_and_cells_side_by_side(
|
||||||
|
conv_res, page, processed_clusters, mode_prefix="postprocessed"
|
||||||
|
)
|
||||||
|
|
||||||
yield page
|
yield page
|
||||||
|
444
docling/utils/layout_postprocessor.py
Normal file
444
docling/utils/layout_postprocessor.py
Normal file
@ -0,0 +1,444 @@
|
|||||||
|
import logging
|
||||||
|
from bisect import bisect_left
|
||||||
|
from typing import Dict, List, Set, Tuple
|
||||||
|
|
||||||
|
from docling_core.types.doc import DocItemLabel
|
||||||
|
from rtree import index
|
||||||
|
|
||||||
|
from docling.datamodel.base_models import BoundingBox, Cell, Cluster
|
||||||
|
|
||||||
|
_log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class UnionFind:
|
||||||
|
"""Union-Find (Disjoint-Set) data structure."""
|
||||||
|
|
||||||
|
def __init__(self, elements):
|
||||||
|
self.parent = {elem: elem for elem in elements}
|
||||||
|
self.rank = {elem: 0 for elem in elements}
|
||||||
|
|
||||||
|
def find(self, x):
|
||||||
|
if self.parent[x] != x:
|
||||||
|
self.parent[x] = self.find(self.parent[x]) # Path compression
|
||||||
|
return self.parent[x]
|
||||||
|
|
||||||
|
def union(self, x, y):
|
||||||
|
root_x = self.find(x)
|
||||||
|
root_y = self.find(y)
|
||||||
|
if root_x != root_y:
|
||||||
|
if self.rank[root_x] > self.rank[root_y]:
|
||||||
|
self.parent[root_y] = root_x
|
||||||
|
elif self.rank[root_x] < self.rank[root_y]:
|
||||||
|
self.parent[root_x] = root_y
|
||||||
|
else:
|
||||||
|
self.parent[root_y] = root_x
|
||||||
|
self.rank[root_x] += 1
|
||||||
|
|
||||||
|
def groups(self):
|
||||||
|
"""Return groups as {root: [elements]}."""
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
components = defaultdict(list)
|
||||||
|
for elem in self.parent:
|
||||||
|
root = self.find(elem)
|
||||||
|
components[root].append(elem)
|
||||||
|
return components
|
||||||
|
|
||||||
|
|
||||||
|
class SpatialClusterIndex:
|
||||||
|
"""Helper class to manage spatial indexes for clusters and find overlaps efficiently."""
|
||||||
|
|
||||||
|
def __init__(self, clusters: List[Cluster]):
|
||||||
|
# Create spatial index
|
||||||
|
p = index.Property()
|
||||||
|
p.dimension = 2
|
||||||
|
self.spatial_index = index.Index(properties=p)
|
||||||
|
|
||||||
|
# Initialize interval trees
|
||||||
|
self.x_intervals = IntervalTree()
|
||||||
|
self.y_intervals = IntervalTree()
|
||||||
|
|
||||||
|
# Map to store clusters by ID
|
||||||
|
self.clusters_by_id = {} # type: ignore
|
||||||
|
|
||||||
|
# Populate indexes and maps
|
||||||
|
for cluster in clusters:
|
||||||
|
self.add_cluster(cluster)
|
||||||
|
|
||||||
|
def add_cluster(self, cluster: Cluster):
|
||||||
|
"""Add a cluster to all indexes."""
|
||||||
|
self.spatial_index.insert(cluster.id, cluster.bbox.as_tuple())
|
||||||
|
self.x_intervals.insert(cluster.bbox.l, cluster.bbox.r, cluster.id)
|
||||||
|
self.y_intervals.insert(cluster.bbox.t, cluster.bbox.b, cluster.id)
|
||||||
|
self.clusters_by_id[cluster.id] = cluster
|
||||||
|
|
||||||
|
def remove_cluster(self, cluster: Cluster):
|
||||||
|
"""Remove a cluster from all indexes."""
|
||||||
|
self.spatial_index.delete(cluster.id, cluster.bbox.as_tuple())
|
||||||
|
# Note: IntervalTree doesn't support deletion, but we handle this
|
||||||
|
# by checking clusters_by_id membership
|
||||||
|
del self.clusters_by_id[cluster.id]
|
||||||
|
|
||||||
|
def find_candidates(self, bbox: BoundingBox) -> Set[int]:
|
||||||
|
"""Find all potential overlapping cluster IDs using all indexes."""
|
||||||
|
bbox_tuple = bbox.as_tuple()
|
||||||
|
spatial_candidates = set(self.spatial_index.intersection(bbox_tuple))
|
||||||
|
x_candidates = self.x_intervals.find_containing(
|
||||||
|
bbox.l
|
||||||
|
) | self.x_intervals.find_containing(bbox.r)
|
||||||
|
y_candidates = self.y_intervals.find_containing(
|
||||||
|
bbox.t
|
||||||
|
) | self.y_intervals.find_containing(bbox.b)
|
||||||
|
return spatial_candidates | x_candidates | y_candidates
|
||||||
|
|
||||||
|
def check_overlap(
|
||||||
|
self,
|
||||||
|
bbox1: BoundingBox,
|
||||||
|
bbox2: BoundingBox,
|
||||||
|
overlap_threshold: float,
|
||||||
|
containment_threshold: float,
|
||||||
|
) -> bool:
|
||||||
|
"""Check if two bboxes overlap sufficiently."""
|
||||||
|
area1 = bbox1.area()
|
||||||
|
area2 = bbox2.area()
|
||||||
|
if area1 <= 0 or area2 <= 0:
|
||||||
|
return False
|
||||||
|
|
||||||
|
overlap_area = bbox1.intersection_area_with(bbox2)
|
||||||
|
if overlap_area <= 0:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check both IoU and containment
|
||||||
|
iou = overlap_area / (area1 + area2 - overlap_area)
|
||||||
|
containment_ratio1 = overlap_area / area1
|
||||||
|
containment_ratio2 = overlap_area / area2
|
||||||
|
|
||||||
|
return (
|
||||||
|
iou > overlap_threshold
|
||||||
|
or containment_ratio1 > containment_threshold
|
||||||
|
or containment_ratio2 > containment_threshold
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class IntervalTree:
|
||||||
|
def __init__(self):
|
||||||
|
self.intervals = [] # List of (min, max, box_id) sorted by min
|
||||||
|
|
||||||
|
def insert(self, min_val: float, max_val: float, box_id: int):
|
||||||
|
self.intervals.append((min_val, max_val, box_id))
|
||||||
|
self.intervals.sort(key=lambda x: x[0])
|
||||||
|
|
||||||
|
def find_containing(self, point: float) -> Set[int]:
|
||||||
|
pos = bisect_left(self.intervals, (point, float("-inf"), -1))
|
||||||
|
result = set()
|
||||||
|
|
||||||
|
i = pos - 1
|
||||||
|
while i >= 0:
|
||||||
|
min_val, max_val, box_id = self.intervals[i]
|
||||||
|
if min_val <= point <= max_val:
|
||||||
|
result.add(box_id)
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
i -= 1
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class LayoutPostprocessor:
|
||||||
|
"""Postprocesses layout predictions by cleaning up clusters and mapping cells."""
|
||||||
|
|
||||||
|
# Cluster type-specific parameters for overlap resolution
|
||||||
|
OVERLAP_PARAMS = {
|
||||||
|
"regular": {"area_threshold": 1.3, "conf_threshold": 0.15},
|
||||||
|
"picture": {"area_threshold": 2.0, "conf_threshold": 0.3},
|
||||||
|
"wrapper": {"area_threshold": 2.0, "conf_threshold": 0.2},
|
||||||
|
}
|
||||||
|
|
||||||
|
WRAPPER_TYPES = {DocItemLabel.FORM, DocItemLabel.KEY_VALUE_REGION}
|
||||||
|
SPECIAL_TYPES = WRAPPER_TYPES | {DocItemLabel.PICTURE}
|
||||||
|
|
||||||
|
CONFIDENCE_THRESHOLDS = {
|
||||||
|
DocItemLabel.CAPTION: 0.35,
|
||||||
|
DocItemLabel.FOOTNOTE: 0.35,
|
||||||
|
DocItemLabel.FORMULA: 0.35,
|
||||||
|
DocItemLabel.LIST_ITEM: 0.35,
|
||||||
|
DocItemLabel.PAGE_FOOTER: 0.35,
|
||||||
|
DocItemLabel.PAGE_HEADER: 0.35,
|
||||||
|
DocItemLabel.PICTURE: 0.1,
|
||||||
|
DocItemLabel.SECTION_HEADER: 0.45,
|
||||||
|
DocItemLabel.TABLE: 0.35,
|
||||||
|
DocItemLabel.TEXT: 0.45,
|
||||||
|
DocItemLabel.TITLE: 0.45,
|
||||||
|
DocItemLabel.CODE: 0.45,
|
||||||
|
DocItemLabel.CHECKBOX_SELECTED: 0.45,
|
||||||
|
DocItemLabel.CHECKBOX_UNSELECTED: 0.45,
|
||||||
|
DocItemLabel.FORM: 0.45,
|
||||||
|
DocItemLabel.KEY_VALUE_REGION: 0.45,
|
||||||
|
DocItemLabel.DOCUMENT_INDEX: 0.45,
|
||||||
|
}
|
||||||
|
|
||||||
|
LABEL_REMAPPING = {
|
||||||
|
DocItemLabel.DOCUMENT_INDEX: DocItemLabel.TABLE,
|
||||||
|
DocItemLabel.TITLE: DocItemLabel.SECTION_HEADER,
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self, cells: List[Cell], clusters: List[Cluster]):
|
||||||
|
"""Initialize processor with cells and clusters."""
|
||||||
|
self.cells = cells
|
||||||
|
self.regular_clusters = [
|
||||||
|
c for c in clusters if c.label not in self.SPECIAL_TYPES
|
||||||
|
]
|
||||||
|
self.special_clusters = [c for c in clusters if c.label in self.SPECIAL_TYPES]
|
||||||
|
|
||||||
|
def postprocess(self) -> Tuple[List[Cluster], List[Cell]]:
|
||||||
|
"""Main processing pipeline."""
|
||||||
|
regular_clusters = self._process_regular_clusters()
|
||||||
|
special_clusters = self._process_special_clusters()
|
||||||
|
final_clusters = self._sort_clusters(regular_clusters + special_clusters)
|
||||||
|
return final_clusters, self.cells
|
||||||
|
|
||||||
|
def _process_regular_clusters(self) -> List[Cluster]:
|
||||||
|
"""Process regular clusters with iterative refinement."""
|
||||||
|
clusters = [
|
||||||
|
c
|
||||||
|
for c in self.regular_clusters
|
||||||
|
if c.confidence >= self.CONFIDENCE_THRESHOLDS[c.label]
|
||||||
|
]
|
||||||
|
|
||||||
|
# Apply label remapping
|
||||||
|
for cluster in clusters:
|
||||||
|
if cluster.label in self.LABEL_REMAPPING:
|
||||||
|
cluster.label = self.LABEL_REMAPPING[cluster.label]
|
||||||
|
|
||||||
|
# Initial cell assignment
|
||||||
|
clusters = self._assign_cells_to_clusters(clusters)
|
||||||
|
|
||||||
|
# Handle orphaned cells
|
||||||
|
unassigned = self._find_unassigned_cells(clusters)
|
||||||
|
if unassigned:
|
||||||
|
next_id = max((c.id for c in clusters), default=0) + 1
|
||||||
|
orphan_clusters = [
|
||||||
|
Cluster(
|
||||||
|
id=next_id + i,
|
||||||
|
label=DocItemLabel.TEXT,
|
||||||
|
bbox=cell.bbox,
|
||||||
|
confidence=0.0,
|
||||||
|
cells=[cell],
|
||||||
|
)
|
||||||
|
for i, cell in enumerate(unassigned)
|
||||||
|
]
|
||||||
|
clusters.extend(orphan_clusters)
|
||||||
|
|
||||||
|
# Iterative refinement
|
||||||
|
prev_count = len(clusters) + 1
|
||||||
|
for _ in range(3): # Maximum 3 iterations
|
||||||
|
if prev_count == len(clusters):
|
||||||
|
break
|
||||||
|
prev_count = len(clusters)
|
||||||
|
clusters = self._adjust_cluster_bboxes(clusters)
|
||||||
|
clusters = self._remove_overlapping_clusters(clusters, "regular")
|
||||||
|
|
||||||
|
return clusters
|
||||||
|
|
||||||
|
def _process_special_clusters(self) -> List[Cluster]:
|
||||||
|
"""Process special clusters (pictures and wrappers)."""
|
||||||
|
# Handle pictures
|
||||||
|
picture_clusters = [
|
||||||
|
c
|
||||||
|
for c in self.special_clusters
|
||||||
|
if c.label == DocItemLabel.PICTURE
|
||||||
|
and c.confidence >= self.CONFIDENCE_THRESHOLDS[c.label]
|
||||||
|
]
|
||||||
|
picture_clusters = self._remove_overlapping_clusters(
|
||||||
|
picture_clusters, "picture"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Process wrapper clusters
|
||||||
|
wrapper_clusters = []
|
||||||
|
for wrapper in (
|
||||||
|
c for c in self.special_clusters if c.label in self.WRAPPER_TYPES
|
||||||
|
):
|
||||||
|
if wrapper.confidence < self.CONFIDENCE_THRESHOLDS[wrapper.label]:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Find contained regular clusters
|
||||||
|
contained = []
|
||||||
|
for cluster in self.regular_clusters:
|
||||||
|
overlap = cluster.bbox.intersection_area_with(wrapper.bbox)
|
||||||
|
if overlap > 0:
|
||||||
|
containment = overlap / cluster.bbox.area()
|
||||||
|
if containment > 0.8: # High containment threshold for wrappers
|
||||||
|
contained.append(cluster)
|
||||||
|
|
||||||
|
if contained:
|
||||||
|
wrapper.children = contained
|
||||||
|
wrapper.bbox = BoundingBox(
|
||||||
|
l=min(c.bbox.l for c in contained),
|
||||||
|
t=min(c.bbox.t for c in contained),
|
||||||
|
r=max(c.bbox.r for c in contained),
|
||||||
|
b=max(c.bbox.b for c in contained),
|
||||||
|
)
|
||||||
|
wrapper_clusters.append(wrapper)
|
||||||
|
|
||||||
|
return picture_clusters + self._remove_overlapping_clusters(
|
||||||
|
wrapper_clusters, "wrapper"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _remove_overlapping_clusters(
|
||||||
|
self,
|
||||||
|
clusters: List[Cluster],
|
||||||
|
cluster_type: str,
|
||||||
|
overlap_threshold: float = 0.8,
|
||||||
|
containment_threshold: float = 0.8,
|
||||||
|
) -> List[Cluster]:
|
||||||
|
"""Remove overlapping clusters using efficient spatial indexing."""
|
||||||
|
if not clusters:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Initialize spatial index
|
||||||
|
spatial_index = SpatialClusterIndex(clusters)
|
||||||
|
uf = UnionFind(spatial_index.clusters_by_id.keys())
|
||||||
|
|
||||||
|
# Group overlapping clusters using spatial index
|
||||||
|
for cluster in clusters:
|
||||||
|
candidates = spatial_index.find_candidates(cluster.bbox)
|
||||||
|
candidates.discard(cluster.id) # Remove self
|
||||||
|
|
||||||
|
for other_id in candidates:
|
||||||
|
if spatial_index.check_overlap(
|
||||||
|
cluster.bbox,
|
||||||
|
spatial_index.clusters_by_id[other_id].bbox,
|
||||||
|
overlap_threshold,
|
||||||
|
containment_threshold,
|
||||||
|
):
|
||||||
|
uf.union(cluster.id, other_id)
|
||||||
|
|
||||||
|
# Process each group using type-specific parameters
|
||||||
|
params = self.OVERLAP_PARAMS[cluster_type]
|
||||||
|
result = []
|
||||||
|
|
||||||
|
for group in uf.groups().values():
|
||||||
|
if len(group) == 1:
|
||||||
|
result.append(spatial_index.clusters_by_id[group[0]])
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Get clusters in group
|
||||||
|
group_clusters = [spatial_index.clusters_by_id[cid] for cid in group]
|
||||||
|
|
||||||
|
# Find best cluster using area and confidence
|
||||||
|
best = self._select_best_cluster(
|
||||||
|
group_clusters, params["area_threshold"], params["conf_threshold"]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Merge cells from other clusters into best
|
||||||
|
for cluster in group_clusters:
|
||||||
|
if cluster != best:
|
||||||
|
best.cells.extend(cluster.cells)
|
||||||
|
|
||||||
|
result.append(best)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _select_best_cluster(
|
||||||
|
self,
|
||||||
|
clusters: List[Cluster],
|
||||||
|
area_threshold: float,
|
||||||
|
conf_threshold: float,
|
||||||
|
) -> Cluster:
|
||||||
|
"""Iteratively select best cluster based on area and confidence thresholds."""
|
||||||
|
current_best = None
|
||||||
|
for candidate in clusters:
|
||||||
|
should_select = True
|
||||||
|
for other in clusters:
|
||||||
|
if other == candidate:
|
||||||
|
continue
|
||||||
|
|
||||||
|
area_ratio = candidate.bbox.area() / other.bbox.area()
|
||||||
|
conf_diff = other.confidence - candidate.confidence
|
||||||
|
|
||||||
|
if area_ratio <= area_threshold and conf_diff > conf_threshold:
|
||||||
|
should_select = False
|
||||||
|
break
|
||||||
|
|
||||||
|
if should_select:
|
||||||
|
if current_best is None or (
|
||||||
|
candidate.bbox.area() > current_best.bbox.area()
|
||||||
|
and current_best.confidence - candidate.confidence <= conf_threshold
|
||||||
|
):
|
||||||
|
current_best = candidate
|
||||||
|
|
||||||
|
return current_best if current_best else clusters[0]
|
||||||
|
|
||||||
|
def _assign_cells_to_clusters(
|
||||||
|
self, clusters: List[Cluster], min_overlap: float = 0.2
|
||||||
|
) -> List[Cluster]:
|
||||||
|
"""Assign cells to best overlapping cluster."""
|
||||||
|
for cluster in clusters:
|
||||||
|
cluster.cells = []
|
||||||
|
|
||||||
|
for cell in self.cells:
|
||||||
|
if not cell.text.strip():
|
||||||
|
continue
|
||||||
|
|
||||||
|
best_overlap = min_overlap
|
||||||
|
best_cluster = None
|
||||||
|
|
||||||
|
for cluster in clusters:
|
||||||
|
if cell.bbox.area() <= 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
overlap = cell.bbox.intersection_area_with(cluster.bbox)
|
||||||
|
overlap_ratio = overlap / cell.bbox.area()
|
||||||
|
|
||||||
|
if overlap_ratio > best_overlap:
|
||||||
|
best_overlap = overlap_ratio
|
||||||
|
best_cluster = cluster
|
||||||
|
|
||||||
|
if best_cluster is not None:
|
||||||
|
best_cluster.cells.append(cell)
|
||||||
|
|
||||||
|
return clusters
|
||||||
|
|
||||||
|
def _find_unassigned_cells(self, clusters: List[Cluster]) -> List[Cell]:
|
||||||
|
"""Find cells not assigned to any cluster."""
|
||||||
|
assigned = {cell.id for cluster in clusters for cell in cluster.cells}
|
||||||
|
return [
|
||||||
|
cell for cell in self.cells if cell.id not in assigned and cell.text.strip()
|
||||||
|
]
|
||||||
|
|
||||||
|
def _adjust_cluster_bboxes(self, clusters: List[Cluster]) -> List[Cluster]:
|
||||||
|
"""Adjust cluster bounding boxes to contain their cells."""
|
||||||
|
for cluster in clusters:
|
||||||
|
if not cluster.cells:
|
||||||
|
continue
|
||||||
|
|
||||||
|
cells_bbox = BoundingBox(
|
||||||
|
l=min(cell.bbox.l for cell in cluster.cells),
|
||||||
|
t=min(cell.bbox.t for cell in cluster.cells),
|
||||||
|
r=max(cell.bbox.r for cell in cluster.cells),
|
||||||
|
b=max(cell.bbox.b for cell in cluster.cells),
|
||||||
|
)
|
||||||
|
|
||||||
|
if cluster.label == DocItemLabel.TABLE:
|
||||||
|
# For tables, take union of current bbox and cells bbox
|
||||||
|
cluster.bbox = BoundingBox(
|
||||||
|
l=min(cluster.bbox.l, cells_bbox.l),
|
||||||
|
t=min(cluster.bbox.t, cells_bbox.t),
|
||||||
|
r=max(cluster.bbox.r, cells_bbox.r),
|
||||||
|
b=max(cluster.bbox.b, cells_bbox.b),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
cluster.bbox = cells_bbox
|
||||||
|
|
||||||
|
return clusters
|
||||||
|
|
||||||
|
def _sort_clusters(self, clusters: List[Cluster]) -> List[Cluster]:
|
||||||
|
"""Sort clusters in reading order (top-to-bottom, left-to-right)."""
|
||||||
|
|
||||||
|
def reading_order_key(cluster: Cluster) -> Tuple[float, float]:
|
||||||
|
if cluster.cells and cluster.label != DocItemLabel.PICTURE:
|
||||||
|
first_cell = min(cluster.cells, key=lambda c: (c.bbox.t, c.bbox.l))
|
||||||
|
return (first_cell.bbox.t, first_cell.bbox.l)
|
||||||
|
return (cluster.bbox.t, cluster.bbox.l)
|
||||||
|
|
||||||
|
return sorted(clusters, key=reading_order_key)
|
@ -1,812 +0,0 @@
|
|||||||
import copy
|
|
||||||
import logging
|
|
||||||
|
|
||||||
import networkx as nx
|
|
||||||
from docling_core.types.doc import DocItemLabel
|
|
||||||
|
|
||||||
logger = logging.getLogger("layout_utils")
|
|
||||||
|
|
||||||
|
|
||||||
## -------------------------------
|
|
||||||
## Geometric helper functions
|
|
||||||
## The coordinates grow left to right, and bottom to top.
|
|
||||||
## The bounding box list elements 0 to 3 are x_left, y_bottom, x_right, y_top.
|
|
||||||
|
|
||||||
|
|
||||||
def area(bbox):
|
|
||||||
return (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
|
|
||||||
|
|
||||||
|
|
||||||
def contains(bbox_i, bbox_j):
|
|
||||||
## Returns True if bbox_i contains bbox_j, else False
|
|
||||||
return (
|
|
||||||
bbox_i[0] <= bbox_j[0]
|
|
||||||
and bbox_i[1] <= bbox_j[1]
|
|
||||||
and bbox_i[2] >= bbox_j[2]
|
|
||||||
and bbox_i[3] >= bbox_j[3]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def is_intersecting(bbox_i, bbox_j):
|
|
||||||
return not (
|
|
||||||
bbox_i[2] < bbox_j[0]
|
|
||||||
or bbox_i[0] > bbox_j[2]
|
|
||||||
or bbox_i[3] < bbox_j[1]
|
|
||||||
or bbox_i[1] > bbox_j[3]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def bb_iou(boxA, boxB):
|
|
||||||
# determine the (x, y)-coordinates of the intersection rectangle
|
|
||||||
xA = max(boxA[0], boxB[0])
|
|
||||||
yA = max(boxA[1], boxB[1])
|
|
||||||
xB = min(boxA[2], boxB[2])
|
|
||||||
yB = min(boxA[3], boxB[3])
|
|
||||||
# compute the area of intersection rectangle
|
|
||||||
interArea = max(0, xB - xA + 1) * max(0, yB - yA + 1)
|
|
||||||
# compute the area of both the prediction and ground-truth
|
|
||||||
# rectangles
|
|
||||||
boxAArea = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1)
|
|
||||||
boxBArea = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1)
|
|
||||||
# compute the intersection over union by taking the intersection
|
|
||||||
# area and dividing it by the sum of prediction + ground-truth
|
|
||||||
# areas - the interesection area
|
|
||||||
iou = interArea / float(boxAArea + boxBArea - interArea)
|
|
||||||
# return the intersection over union value
|
|
||||||
return iou
|
|
||||||
|
|
||||||
|
|
||||||
def compute_intersection(bbox_i, bbox_j):
|
|
||||||
## Returns the size of the intersection area of the two boxes
|
|
||||||
if not is_intersecting(bbox_i, bbox_j):
|
|
||||||
return 0
|
|
||||||
## Determine the (x, y)-coordinates of the intersection rectangle:
|
|
||||||
xA = max(bbox_i[0], bbox_j[0])
|
|
||||||
yA = max(bbox_i[1], bbox_j[1])
|
|
||||||
xB = min(bbox_i[2], bbox_j[2])
|
|
||||||
yB = min(bbox_i[3], bbox_j[3])
|
|
||||||
## Compute the area of intersection rectangle:
|
|
||||||
interArea = (xB - xA) * (yB - yA)
|
|
||||||
if interArea < 0:
|
|
||||||
logger.debug("Warning: Negative intersection detected!")
|
|
||||||
return 0
|
|
||||||
return interArea
|
|
||||||
|
|
||||||
|
|
||||||
def surrounding(bbox_i, bbox_j):
|
|
||||||
## Computes minimal box that contains both input boxes
|
|
||||||
sbox = []
|
|
||||||
sbox.append(min(bbox_i[0], bbox_j[0]))
|
|
||||||
sbox.append(min(bbox_i[1], bbox_j[1]))
|
|
||||||
sbox.append(max(bbox_i[2], bbox_j[2]))
|
|
||||||
sbox.append(max(bbox_i[3], bbox_j[3]))
|
|
||||||
return sbox
|
|
||||||
|
|
||||||
|
|
||||||
def surrounding_list(bbox_list):
|
|
||||||
## Computes minimal box that contains all boxes in the input list
|
|
||||||
## The list should be non-empty, but just in case it's not:
|
|
||||||
if len(bbox_list) == 0:
|
|
||||||
sbox = [0, 0, 0, 0]
|
|
||||||
else:
|
|
||||||
sbox = []
|
|
||||||
sbox.append(min([bbox[0] for bbox in bbox_list]))
|
|
||||||
sbox.append(min([bbox[1] for bbox in bbox_list]))
|
|
||||||
sbox.append(max([bbox[2] for bbox in bbox_list]))
|
|
||||||
sbox.append(max([bbox[3] for bbox in bbox_list]))
|
|
||||||
return sbox
|
|
||||||
|
|
||||||
|
|
||||||
def vertical_overlap(bboxA, bboxB):
|
|
||||||
## bbox[1] is the lower bound, bbox[3] the upper bound (larger number)
|
|
||||||
if bboxB[3] < bboxA[1]: ## B below A
|
|
||||||
return False
|
|
||||||
elif bboxA[3] < bboxB[1]: ## A below B
|
|
||||||
return False
|
|
||||||
else:
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
def vertical_overlap_fraction(bboxA, bboxB):
|
|
||||||
## Returns the vertical overlap as fraction of the lower bbox height.
|
|
||||||
## bbox[1] is the lower bound, bbox[3] the upper bound (larger number)
|
|
||||||
## Height 0 is permitted in the input.
|
|
||||||
heightA = bboxA[3] - bboxA[1]
|
|
||||||
heightB = bboxB[3] - bboxB[1]
|
|
||||||
min_height = min(heightA, heightB)
|
|
||||||
if bboxA[3] >= bboxB[3]: ## A starts higher or equal
|
|
||||||
if (
|
|
||||||
bboxA[1] <= bboxB[1]
|
|
||||||
): ## B is completely in A; this can include height of B = 0:
|
|
||||||
fraction = 1
|
|
||||||
else:
|
|
||||||
overlap = max(bboxB[3] - bboxA[1], 0)
|
|
||||||
fraction = overlap / max(min_height, 0.001)
|
|
||||||
else:
|
|
||||||
if (
|
|
||||||
bboxB[1] <= bboxA[1]
|
|
||||||
): ## A is completely in B; this can include height of A = 0:
|
|
||||||
fraction = 1
|
|
||||||
else:
|
|
||||||
overlap = max(bboxA[3] - bboxB[1], 0)
|
|
||||||
fraction = overlap / max(min_height, 0.001)
|
|
||||||
return fraction
|
|
||||||
|
|
||||||
|
|
||||||
## -------------------------------
|
|
||||||
## Cluster-and-cell relations
|
|
||||||
|
|
||||||
|
|
||||||
def compute_enclosed_cells(
|
|
||||||
cluster_bbox, raw_cells, min_cell_intersection_with_cluster=0.2
|
|
||||||
):
|
|
||||||
cells_in_cluster = []
|
|
||||||
cells_in_cluster_int = []
|
|
||||||
for ix, cell in enumerate(raw_cells):
|
|
||||||
cell_bbox = cell["bbox"]
|
|
||||||
intersection = compute_intersection(cell_bbox, cluster_bbox)
|
|
||||||
frac_area = area(cell_bbox) * min_cell_intersection_with_cluster
|
|
||||||
|
|
||||||
if (
|
|
||||||
intersection > frac_area and frac_area > 0
|
|
||||||
): # intersect > certain fraction of cell
|
|
||||||
cells_in_cluster.append(ix)
|
|
||||||
cells_in_cluster_int.append(intersection)
|
|
||||||
elif contains(
|
|
||||||
cluster_bbox,
|
|
||||||
[cell_bbox[0] + 3, cell_bbox[1] + 3, cell_bbox[2] - 3, cell_bbox[3] - 3],
|
|
||||||
):
|
|
||||||
cells_in_cluster.append(ix)
|
|
||||||
return cells_in_cluster, cells_in_cluster_int
|
|
||||||
|
|
||||||
|
|
||||||
def find_clusters_around_cells(cell_count, clusters):
|
|
||||||
## Per raw cell, find to which clusters it belongs.
|
|
||||||
## Return list of these indices in the raw-cell order.
|
|
||||||
clusters_around_cells = [[] for _ in range(cell_count)]
|
|
||||||
for cl_ix, cluster in enumerate(clusters):
|
|
||||||
for ix in cluster["cell_ids"]:
|
|
||||||
clusters_around_cells[ix].append(cl_ix)
|
|
||||||
return clusters_around_cells
|
|
||||||
|
|
||||||
|
|
||||||
def find_cell_index(raw_ix, cell_array):
|
|
||||||
## "raw_ix" is a rawcell_id.
|
|
||||||
## "cell_array" has the structure of an (annotation) cells array.
|
|
||||||
## Returns index of cell in cell_array that has this rawcell_id.
|
|
||||||
for ix, cell in enumerate(cell_array):
|
|
||||||
if cell["rawcell_id"] == raw_ix:
|
|
||||||
return ix
|
|
||||||
|
|
||||||
|
|
||||||
def find_cell_indices(cluster, cell_array):
|
|
||||||
## "cluster" must have the structure as in a clusters array in a prediction,
|
|
||||||
## "cell_array" that of a cells array.
|
|
||||||
## Returns list of indices of cells in cell_array that have the rawcell_ids as in the cluster,
|
|
||||||
## in the order of the rawcell_ids.
|
|
||||||
result = []
|
|
||||||
for raw_ix in sorted(cluster["cell_ids"]):
|
|
||||||
## Find the cell with this rawcell_id (if any)
|
|
||||||
for ix, cell in enumerate(cell_array):
|
|
||||||
if cell["rawcell_id"] == raw_ix:
|
|
||||||
result.append(ix)
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
def find_first_cell_index(cluster, cell_array):
|
|
||||||
## "cluster" must be a dict with key "cell_ids"; it can also be a line.
|
|
||||||
## "cell_array" has the structure of a cells array in an annotation.
|
|
||||||
## Returns index of cell in cell_array that has the lowest rawcell_id from the cluster.
|
|
||||||
result = [] ## We keep it a list as it can be empty (picture without text cells)
|
|
||||||
if len(cluster["cell_ids"]) == 0:
|
|
||||||
return result
|
|
||||||
raw_ix = min(cluster["cell_ids"])
|
|
||||||
## Find the cell with this rawcell_id (if any)
|
|
||||||
for ix, cell in enumerate(cell_array):
|
|
||||||
if cell["rawcell_id"] == raw_ix:
|
|
||||||
result.append(ix)
|
|
||||||
break ## One is enough; should be only one anyway.
|
|
||||||
if result == []:
|
|
||||||
logger.debug(
|
|
||||||
" Warning: Raw cell " + str(raw_ix) + " not found in annotation cells"
|
|
||||||
)
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
## -------------------------------
|
|
||||||
## Cluster labels and text
|
|
||||||
|
|
||||||
|
|
||||||
def relabel_cluster(cluster, cl_ix, new_label, target_pred):
|
|
||||||
## "cluster" must have the structure as in a clusters array in a prediction,
|
|
||||||
## "cl_ix" is its index in target_pred,
|
|
||||||
## "new_label" is the intended new label,
|
|
||||||
## "target_pred" is the entire current target prediction.
|
|
||||||
## Sets label on the cluster itself, and on the cells in the target_pred.
|
|
||||||
## Returns new_label so that also the cl_label variable in the main code is easily set.
|
|
||||||
target_pred["clusters"][cl_ix]["type"] = new_label
|
|
||||||
cluster_target_cells = find_cell_indices(cluster, target_pred["cells"])
|
|
||||||
for ix in cluster_target_cells:
|
|
||||||
target_pred["cells"][ix]["label"] = new_label
|
|
||||||
return new_label
|
|
||||||
|
|
||||||
|
|
||||||
def find_cluster_text(cluster, raw_cells):
|
|
||||||
## "cluster" must be a dict with "cell_ids"; it can also be a line.
|
|
||||||
## "raw_cells" must have the format of item["raw"]["cells"]
|
|
||||||
## Returns the text of the cluster, with blanks between the cell contents
|
|
||||||
## (which seem to be words or phrases without starting or trailing blanks).
|
|
||||||
## Note that in formulas, this may give a lot more blanks than originally
|
|
||||||
cluster_text = ""
|
|
||||||
for raw_ix in sorted(cluster["cell_ids"]):
|
|
||||||
cluster_text = cluster_text + raw_cells[raw_ix]["text"] + " "
|
|
||||||
return cluster_text.rstrip()
|
|
||||||
|
|
||||||
|
|
||||||
def find_cluster_text_without_blanks(cluster, raw_cells):
|
|
||||||
## "cluster" must be a dict with "cell_ids"; it can also be a line.
|
|
||||||
## "raw_cells" must have the format of item["raw"]["cells"]
|
|
||||||
## Returns the text of the cluster, without blanks between the cell contents
|
|
||||||
## Interesting in formula analysis.
|
|
||||||
cluster_text = ""
|
|
||||||
for raw_ix in sorted(cluster["cell_ids"]):
|
|
||||||
cluster_text = cluster_text + raw_cells[raw_ix]["text"]
|
|
||||||
return cluster_text.rstrip()
|
|
||||||
|
|
||||||
|
|
||||||
## -------------------------------
|
|
||||||
## Clusters and lines
|
|
||||||
## (Most line-oriented functions are only needed in TextAnalysisGivenClusters,
|
|
||||||
## but this one also in FormulaAnalysis)
|
|
||||||
|
|
||||||
|
|
||||||
def build_cluster_from_lines(lines, label, id):
|
|
||||||
## Lines must be a non-empty list of dicts (lines) with elements "cell_ids" and "bbox"
|
|
||||||
## (There is no condition that they are really geometrically lines)
|
|
||||||
## A cluster in standard format is returned with given label and id
|
|
||||||
local_lines = copy.deepcopy(
|
|
||||||
lines
|
|
||||||
) ## without this, it changes "lines" also outside this function
|
|
||||||
first_line = local_lines.pop(0)
|
|
||||||
cluster = {
|
|
||||||
"id": id,
|
|
||||||
"type": label,
|
|
||||||
"cell_ids": first_line["cell_ids"],
|
|
||||||
"bbox": first_line["bbox"],
|
|
||||||
"confidence": 0,
|
|
||||||
"created_by": "merged_cells",
|
|
||||||
}
|
|
||||||
confidence = 0
|
|
||||||
counter = 0
|
|
||||||
for line in local_lines:
|
|
||||||
new_cell_ids = cluster["cell_ids"] + line["cell_ids"]
|
|
||||||
cluster["cell_ids"] = new_cell_ids
|
|
||||||
cluster["bbox"] = surrounding(cluster["bbox"], line["bbox"])
|
|
||||||
counter += 1
|
|
||||||
confidence += line["confidence"]
|
|
||||||
confidence = confidence / counter
|
|
||||||
cluster["confidence"] = confidence
|
|
||||||
return cluster
|
|
||||||
|
|
||||||
|
|
||||||
## -------------------------------
|
|
||||||
## Reading order
|
|
||||||
|
|
||||||
|
|
||||||
def produce_reading_order(clusters, cluster_sort_type, cell_sort_type, sort_ids):
|
|
||||||
## In:
|
|
||||||
## Clusters: list as in predictions.
|
|
||||||
## cluster_sort_type: string, currently only "raw_cells".
|
|
||||||
## cell_sort_type: string, currently only "raw_cells".
|
|
||||||
## sort_ids: Boolean, whether the cluster ids should be adapted to their new position
|
|
||||||
## Out: Another clusters list, sorted according to the type.
|
|
||||||
|
|
||||||
logger.debug("---- Start cluster sorting ------")
|
|
||||||
|
|
||||||
if cell_sort_type == "raw_cell_ids":
|
|
||||||
for cl in clusters:
|
|
||||||
sorted_cell_ids = sorted(cl["cell_ids"])
|
|
||||||
cl["cell_ids"] = sorted_cell_ids
|
|
||||||
else:
|
|
||||||
logger.debug(
|
|
||||||
"Unknown cell_sort_type `"
|
|
||||||
+ cell_sort_type
|
|
||||||
+ "`, no cell sorting will happen."
|
|
||||||
)
|
|
||||||
|
|
||||||
if cluster_sort_type == "raw_cell_ids":
|
|
||||||
clusters_with_cells = [cl for cl in clusters if cl["cell_ids"] != []]
|
|
||||||
clusters_without_cells = [cl for cl in clusters if cl["cell_ids"] == []]
|
|
||||||
logger.debug(
|
|
||||||
"Clusters with cells: " + str([cl["id"] for cl in clusters_with_cells])
|
|
||||||
)
|
|
||||||
logger.debug(
|
|
||||||
" Their first cell ids: "
|
|
||||||
+ str([cl["cell_ids"][0] for cl in clusters_with_cells])
|
|
||||||
)
|
|
||||||
logger.debug(
|
|
||||||
"Clusters without cells: "
|
|
||||||
+ str([cl["id"] for cl in clusters_without_cells])
|
|
||||||
)
|
|
||||||
clusters_with_cells_sorted = sorted(
|
|
||||||
clusters_with_cells, key=lambda cluster: cluster["cell_ids"][0]
|
|
||||||
)
|
|
||||||
logger.debug(
|
|
||||||
" First cell ids after sorting: "
|
|
||||||
+ str([cl["cell_ids"][0] for cl in clusters_with_cells_sorted])
|
|
||||||
)
|
|
||||||
sorted_clusters = clusters_with_cells_sorted + clusters_without_cells
|
|
||||||
else:
|
|
||||||
logger.debug(
|
|
||||||
"Unknown cluster_sort_type: `"
|
|
||||||
+ cluster_sort_type
|
|
||||||
+ "`, no cluster sorting will happen."
|
|
||||||
)
|
|
||||||
|
|
||||||
if sort_ids:
|
|
||||||
for i, cl in enumerate(sorted_clusters):
|
|
||||||
cl["id"] = i
|
|
||||||
return sorted_clusters
|
|
||||||
|
|
||||||
|
|
||||||
## -------------------------------
|
|
||||||
## Line Splitting
|
|
||||||
|
|
||||||
|
|
||||||
def sort_cells_horizontal(line_cell_ids, raw_cells):
|
|
||||||
## "line_cells" should be a non-empty list of (raw) cell_ids
|
|
||||||
## "raw_cells" has the structure of item["raw"]["cells"].
|
|
||||||
## Sorts the cells in the line by x0 (left start).
|
|
||||||
new_line_cell_ids = sorted(
|
|
||||||
line_cell_ids, key=lambda cell_id: raw_cells[cell_id]["bbox"][0]
|
|
||||||
)
|
|
||||||
return new_line_cell_ids
|
|
||||||
|
|
||||||
|
|
||||||
def adapt_bboxes(raw_cells, clusters, orphan_cell_indices):
|
|
||||||
new_clusters = []
|
|
||||||
for ix, cluster in enumerate(clusters):
|
|
||||||
new_cluster = copy.deepcopy(cluster)
|
|
||||||
logger.debug(
|
|
||||||
"Treating cluster " + str(ix) + ", type " + str(new_cluster["type"])
|
|
||||||
)
|
|
||||||
logger.debug(" with cells: " + str(new_cluster["cell_ids"]))
|
|
||||||
if len(cluster["cell_ids"]) == 0 and cluster["type"] != DocItemLabel.PICTURE:
|
|
||||||
logger.debug(" Empty non-picture, removed")
|
|
||||||
continue ## Skip this former cluster, now without cells.
|
|
||||||
new_bbox = adapt_bbox(raw_cells, new_cluster, orphan_cell_indices)
|
|
||||||
new_cluster["bbox"] = new_bbox
|
|
||||||
new_clusters.append(new_cluster)
|
|
||||||
return new_clusters
|
|
||||||
|
|
||||||
|
|
||||||
def adapt_bbox(raw_cells, cluster, orphan_cell_indices):
|
|
||||||
if not (cluster["type"] in [DocItemLabel.TABLE, DocItemLabel.PICTURE]):
|
|
||||||
## A text-like cluster. The bbox only needs to be around the text cells:
|
|
||||||
logger.debug(" Initial bbox: " + str(cluster["bbox"]))
|
|
||||||
new_bbox = surrounding_list(
|
|
||||||
[raw_cells[cid]["bbox"] for cid in cluster["cell_ids"]]
|
|
||||||
)
|
|
||||||
logger.debug(" New bounding box:" + str(new_bbox))
|
|
||||||
if cluster["type"] == DocItemLabel.PICTURE:
|
|
||||||
## We only make the bbox completely comprise included text cells:
|
|
||||||
logger.debug(" Picture")
|
|
||||||
if len(cluster["cell_ids"]) != 0:
|
|
||||||
min_bbox = surrounding_list(
|
|
||||||
[raw_cells[cid]["bbox"] for cid in cluster["cell_ids"]]
|
|
||||||
)
|
|
||||||
logger.debug(" Minimum bbox: " + str(min_bbox))
|
|
||||||
logger.debug(" Initial bbox: " + str(cluster["bbox"]))
|
|
||||||
new_bbox = surrounding(min_bbox, cluster["bbox"])
|
|
||||||
logger.debug(" New bbox (initial and text cells): " + str(new_bbox))
|
|
||||||
else:
|
|
||||||
logger.debug(" without text cells, no change.")
|
|
||||||
new_bbox = cluster["bbox"]
|
|
||||||
else: ## A table
|
|
||||||
## At least we have to keep the included text cells, and we make the bbox completely comprise them
|
|
||||||
min_bbox = surrounding_list(
|
|
||||||
[raw_cells[cid]["bbox"] for cid in cluster["cell_ids"]]
|
|
||||||
)
|
|
||||||
logger.debug(" Minimum bbox: " + str(min_bbox))
|
|
||||||
logger.debug(" Initial bbox: " + str(cluster["bbox"]))
|
|
||||||
new_bbox = surrounding(min_bbox, cluster["bbox"])
|
|
||||||
logger.debug(" Possibly increased bbox: " + str(new_bbox))
|
|
||||||
|
|
||||||
## Now we look which non-belonging cells are covered.
|
|
||||||
## (To decrease dependencies, we don't make use of which cells we actually removed.)
|
|
||||||
## We don't worry about orphan cells, those could still be added to the table.
|
|
||||||
enclosed_cells = compute_enclosed_cells(
|
|
||||||
new_bbox, raw_cells, min_cell_intersection_with_cluster=0.3
|
|
||||||
)[0]
|
|
||||||
additional_cells = set(enclosed_cells) - set(cluster["cell_ids"])
|
|
||||||
logger.debug(
|
|
||||||
" Additional cells enclosed by Table bbox: " + str(additional_cells)
|
|
||||||
)
|
|
||||||
spurious_cells = additional_cells - set(orphan_cell_indices)
|
|
||||||
logger.debug(
|
|
||||||
" Spurious cells enclosed by Table bbox (additional minus orphans): "
|
|
||||||
+ str(spurious_cells)
|
|
||||||
)
|
|
||||||
if len(spurious_cells) == 0:
|
|
||||||
return new_bbox
|
|
||||||
|
|
||||||
## Else we want to keep as much as possible, e.g., grid lines, but not the spurious cells if we can.
|
|
||||||
## We initialize possible cuts with the current bbox.
|
|
||||||
left_cut = new_bbox[0]
|
|
||||||
right_cut = new_bbox[2]
|
|
||||||
upper_cut = new_bbox[3]
|
|
||||||
lower_cut = new_bbox[1]
|
|
||||||
|
|
||||||
for cell_ix in spurious_cells:
|
|
||||||
cell = raw_cells[cell_ix]
|
|
||||||
# logger.debug(" Spurious cell bbox: " + str(cell["bbox"]))
|
|
||||||
is_left = cell["bbox"][2] < min_bbox[0]
|
|
||||||
is_right = cell["bbox"][0] > min_bbox[2]
|
|
||||||
is_above = cell["bbox"][1] > min_bbox[3]
|
|
||||||
is_below = cell["bbox"][3] < min_bbox[1]
|
|
||||||
# logger.debug(" Left, right, above, below? " + str([is_left, is_right, is_above, is_below]))
|
|
||||||
|
|
||||||
if is_left:
|
|
||||||
if cell["bbox"][2] > left_cut:
|
|
||||||
## We move the left cut to exclude this cell:
|
|
||||||
left_cut = cell["bbox"][2]
|
|
||||||
if is_right:
|
|
||||||
if cell["bbox"][0] < right_cut:
|
|
||||||
## We move the right cut to exclude this cell:
|
|
||||||
right_cut = cell["bbox"][0]
|
|
||||||
if is_above:
|
|
||||||
if cell["bbox"][1] < upper_cut:
|
|
||||||
## We move the upper cut to exclude this cell:
|
|
||||||
upper_cut = cell["bbox"][1]
|
|
||||||
if is_below:
|
|
||||||
if cell["bbox"][3] > lower_cut:
|
|
||||||
## We move the left cut to exclude this cell:
|
|
||||||
lower_cut = cell["bbox"][3]
|
|
||||||
# logger.debug(" Current bbox: " + str([left_cut, lower_cut, right_cut, upper_cut]))
|
|
||||||
|
|
||||||
new_bbox = [left_cut, lower_cut, right_cut, upper_cut]
|
|
||||||
|
|
||||||
logger.debug(" Final bbox: " + str(new_bbox))
|
|
||||||
return new_bbox
|
|
||||||
|
|
||||||
|
|
||||||
def remove_cluster_duplicates_by_conf(cluster_predictions, threshold=0.5):
|
|
||||||
DuplicateDeletedClusterIDs = []
|
|
||||||
for cluster_1 in cluster_predictions:
|
|
||||||
for cluster_2 in cluster_predictions:
|
|
||||||
if cluster_1["id"] != cluster_2["id"]:
|
|
||||||
if_conf = False
|
|
||||||
if cluster_1["confidence"] > cluster_2["confidence"]:
|
|
||||||
if_conf = True
|
|
||||||
if if_conf == True:
|
|
||||||
if bb_iou(cluster_1["bbox"], cluster_2["bbox"]) > threshold:
|
|
||||||
DuplicateDeletedClusterIDs.append(cluster_2["id"])
|
|
||||||
elif contains(
|
|
||||||
cluster_1["bbox"],
|
|
||||||
[
|
|
||||||
cluster_2["bbox"][0] + 3,
|
|
||||||
cluster_2["bbox"][1] + 3,
|
|
||||||
cluster_2["bbox"][2] - 3,
|
|
||||||
cluster_2["bbox"][3] - 3,
|
|
||||||
],
|
|
||||||
):
|
|
||||||
DuplicateDeletedClusterIDs.append(cluster_2["id"])
|
|
||||||
|
|
||||||
DuplicateDeletedClusterIDs = list(set(DuplicateDeletedClusterIDs))
|
|
||||||
|
|
||||||
for cl_id in DuplicateDeletedClusterIDs:
|
|
||||||
for cluster in cluster_predictions:
|
|
||||||
if cl_id == cluster["id"]:
|
|
||||||
cluster_predictions.remove(cluster)
|
|
||||||
return cluster_predictions
|
|
||||||
|
|
||||||
|
|
||||||
# Assign orphan cells by a low confidence prediction that is below the assigned confidence
|
|
||||||
def assign_orphans_with_low_conf_pred(
|
|
||||||
cluster_predictions, cluster_predictions_low, raw_cells, orphan_cell_indices
|
|
||||||
):
|
|
||||||
for orph_id in orphan_cell_indices:
|
|
||||||
cluster_chosen = {}
|
|
||||||
iou_thresh = 0.05
|
|
||||||
confidence = 0.05
|
|
||||||
|
|
||||||
# Loop over all predictions, and find the one with the highest IOU, and confidence
|
|
||||||
for cluster in cluster_predictions_low:
|
|
||||||
calc_iou = bb_iou(cluster["bbox"], raw_cells[orph_id]["bbox"])
|
|
||||||
cluster_area = (cluster["bbox"][3] - cluster["bbox"][1]) * (
|
|
||||||
cluster["bbox"][2] - cluster["bbox"][0]
|
|
||||||
)
|
|
||||||
cell_area = (
|
|
||||||
raw_cells[orph_id]["bbox"][3] - raw_cells[orph_id]["bbox"][1]
|
|
||||||
) * (raw_cells[orph_id]["bbox"][2] - raw_cells[orph_id]["bbox"][0])
|
|
||||||
|
|
||||||
if (
|
|
||||||
(iou_thresh < calc_iou)
|
|
||||||
and (cluster["confidence"] > confidence)
|
|
||||||
and (cell_area * 3 > cluster_area)
|
|
||||||
):
|
|
||||||
cluster_chosen = cluster
|
|
||||||
iou_thresh = calc_iou
|
|
||||||
confidence = cluster["confidence"]
|
|
||||||
# If a candidate is found, assign to it the PDF cell ids, and tag that it was created by this function for tracking
|
|
||||||
if iou_thresh != 0.05 and confidence != 0.05:
|
|
||||||
cluster_chosen["cell_ids"].append(orph_id)
|
|
||||||
cluster_chosen["created_by"] = "orph_low_conf"
|
|
||||||
cluster_predictions.append(cluster_chosen)
|
|
||||||
orphan_cell_indices.remove(orph_id)
|
|
||||||
return cluster_predictions, orphan_cell_indices
|
|
||||||
|
|
||||||
|
|
||||||
def remove_ambigous_pdf_cell_by_conf(cluster_predictions, raw_cells, amb_cell_idxs):
|
|
||||||
for amb_cell_id in amb_cell_idxs:
|
|
||||||
highest_conf = 0
|
|
||||||
highest_bbox_iou = 0
|
|
||||||
cluster_chosen = None
|
|
||||||
problamatic_clusters = []
|
|
||||||
|
|
||||||
# Find clusters in question
|
|
||||||
for cluster in cluster_predictions:
|
|
||||||
|
|
||||||
if amb_cell_id in cluster["cell_ids"]:
|
|
||||||
problamatic_clusters.append(amb_cell_id)
|
|
||||||
|
|
||||||
# If the cell_id is in a cluster of high conf, and highest iou score, and smaller in area
|
|
||||||
bbox_iou_val = bb_iou(cluster["bbox"], raw_cells[amb_cell_id]["bbox"])
|
|
||||||
|
|
||||||
if (
|
|
||||||
cluster["confidence"] > highest_conf
|
|
||||||
and bbox_iou_val > highest_bbox_iou
|
|
||||||
):
|
|
||||||
cluster_chosen = cluster
|
|
||||||
highest_conf = cluster["confidence"]
|
|
||||||
highest_bbox_iou = bbox_iou_val
|
|
||||||
if cluster["id"] in problamatic_clusters:
|
|
||||||
problamatic_clusters.remove(cluster["id"])
|
|
||||||
|
|
||||||
# now remove the assigning of cell id from lower confidence, and threshold
|
|
||||||
for cluster in cluster_predictions:
|
|
||||||
for prob_amb_id in problamatic_clusters:
|
|
||||||
if prob_amb_id in cluster["cell_ids"]:
|
|
||||||
cluster["cell_ids"].remove(prob_amb_id)
|
|
||||||
amb_cell_idxs.remove(amb_cell_id)
|
|
||||||
|
|
||||||
return cluster_predictions, amb_cell_idxs
|
|
||||||
|
|
||||||
|
|
||||||
def ranges(nums):
|
|
||||||
# Find if consecutive numbers exist within pdf cells
|
|
||||||
# Used to remove line numbers for review manuscripts
|
|
||||||
nums = sorted(set(nums))
|
|
||||||
gaps = [[s, e] for s, e in zip(nums, nums[1:]) if s + 1 < e]
|
|
||||||
edges = iter(nums[:1] + sum(gaps, []) + nums[-1:])
|
|
||||||
return list(zip(edges, edges))
|
|
||||||
|
|
||||||
|
|
||||||
def set_orphan_as_text(
|
|
||||||
cluster_predictions, cluster_predictions_low, raw_cells, orphan_cell_indices
|
|
||||||
):
|
|
||||||
max_id = -1
|
|
||||||
figures = []
|
|
||||||
for cluster in cluster_predictions:
|
|
||||||
if cluster["type"] == DocItemLabel.PICTURE:
|
|
||||||
figures.append(cluster)
|
|
||||||
|
|
||||||
if cluster["id"] > max_id:
|
|
||||||
max_id = cluster["id"]
|
|
||||||
max_id += 1
|
|
||||||
|
|
||||||
lines_detector = False
|
|
||||||
content_of_orphans = []
|
|
||||||
for orph_id in orphan_cell_indices:
|
|
||||||
orph_cell = raw_cells[orph_id]
|
|
||||||
content_of_orphans.append(raw_cells[orph_id]["text"])
|
|
||||||
|
|
||||||
fil_content_of_orphans = []
|
|
||||||
for cell_content in content_of_orphans:
|
|
||||||
if cell_content.isnumeric():
|
|
||||||
try:
|
|
||||||
num = int(cell_content)
|
|
||||||
fil_content_of_orphans.append(num)
|
|
||||||
except ValueError: # ignore the cell
|
|
||||||
pass
|
|
||||||
|
|
||||||
# line_orphans = []
|
|
||||||
# Check if there are more than 2 pdf orphan cells, if there are more than 2,
|
|
||||||
# then check between the orphan cells if they are numeric
|
|
||||||
# and if they are a consecutive series of numbers (using ranges function) to decide
|
|
||||||
|
|
||||||
if len(fil_content_of_orphans) > 2:
|
|
||||||
out_ranges = ranges(fil_content_of_orphans)
|
|
||||||
if len(out_ranges) > 1:
|
|
||||||
cnt_range = 0
|
|
||||||
for ranges_ in out_ranges:
|
|
||||||
if ranges_[0] != ranges_[1]:
|
|
||||||
# If there are more than 75 (half the total line number of a review manuscript page)
|
|
||||||
# decide that there are line numbers on page to be ignored.
|
|
||||||
if len(list(range(ranges_[0], ranges_[1]))) > 75:
|
|
||||||
lines_detector = True
|
|
||||||
# line_orphans = line_orphans + list(range(ranges_[0], ranges_[1]))
|
|
||||||
|
|
||||||
for orph_id in orphan_cell_indices:
|
|
||||||
orph_cell = raw_cells[orph_id]
|
|
||||||
if bool(orph_cell["text"] and not orph_cell["text"].isspace()):
|
|
||||||
fig_flag = False
|
|
||||||
# Do not assign orphan cells if they are inside a figure
|
|
||||||
for fig in figures:
|
|
||||||
if contains(fig["bbox"], orph_cell["bbox"]):
|
|
||||||
fig_flag = True
|
|
||||||
|
|
||||||
# if fig_flag == False and raw_cells[orph_id]["text"] not in line_orphans:
|
|
||||||
if fig_flag == False and lines_detector == False:
|
|
||||||
# get class from low confidence detections if not set as text:
|
|
||||||
class_type = DocItemLabel.TEXT
|
|
||||||
|
|
||||||
for cluster in cluster_predictions_low:
|
|
||||||
intersection = compute_intersection(
|
|
||||||
orph_cell["bbox"], cluster["bbox"]
|
|
||||||
)
|
|
||||||
class_type = DocItemLabel.TEXT
|
|
||||||
if (
|
|
||||||
cluster["confidence"] > 0.1
|
|
||||||
and bb_iou(cluster["bbox"], orph_cell["bbox"]) > 0.4
|
|
||||||
):
|
|
||||||
class_type = cluster["type"]
|
|
||||||
elif contains(
|
|
||||||
cluster["bbox"],
|
|
||||||
[
|
|
||||||
orph_cell["bbox"][0] + 3,
|
|
||||||
orph_cell["bbox"][1] + 3,
|
|
||||||
orph_cell["bbox"][2] - 3,
|
|
||||||
orph_cell["bbox"][3] - 3,
|
|
||||||
],
|
|
||||||
):
|
|
||||||
class_type = cluster["type"]
|
|
||||||
elif intersection > area(orph_cell["bbox"]) * 0.2:
|
|
||||||
class_type = cluster["type"]
|
|
||||||
|
|
||||||
new_cluster = {
|
|
||||||
"id": max_id,
|
|
||||||
"bbox": orph_cell["bbox"],
|
|
||||||
"type": class_type,
|
|
||||||
"cell_ids": [orph_id],
|
|
||||||
"confidence": -1,
|
|
||||||
"created_by": "orphan_default",
|
|
||||||
}
|
|
||||||
max_id += 1
|
|
||||||
cluster_predictions.append(new_cluster)
|
|
||||||
return cluster_predictions, orphan_cell_indices
|
|
||||||
|
|
||||||
|
|
||||||
def merge_cells(cluster_predictions):
|
|
||||||
# Using graph component creates clusters if orphan cells are touching or too close.
|
|
||||||
G = nx.Graph()
|
|
||||||
for cluster in cluster_predictions:
|
|
||||||
if cluster["created_by"] == "orphan_default":
|
|
||||||
G.add_node(cluster["id"])
|
|
||||||
|
|
||||||
for cluster_1 in cluster_predictions:
|
|
||||||
for cluster_2 in cluster_predictions:
|
|
||||||
if (
|
|
||||||
cluster_1["id"] != cluster_2["id"]
|
|
||||||
and cluster_2["created_by"] == "orphan_default"
|
|
||||||
and cluster_1["created_by"] == "orphan_default"
|
|
||||||
):
|
|
||||||
cl1 = copy.deepcopy(cluster_1["bbox"])
|
|
||||||
cl2 = copy.deepcopy(cluster_2["bbox"])
|
|
||||||
cl1[0] = cl1[0] - 2
|
|
||||||
cl1[1] = cl1[1] - 2
|
|
||||||
cl1[2] = cl1[2] + 2
|
|
||||||
cl1[3] = cl1[3] + 2
|
|
||||||
cl2[0] = cl2[0] - 2
|
|
||||||
cl2[1] = cl2[1] - 2
|
|
||||||
cl2[2] = cl2[2] + 2
|
|
||||||
cl2[3] = cl2[3] + 2
|
|
||||||
if is_intersecting(cl1, cl2):
|
|
||||||
G.add_edge(cluster_1["id"], cluster_2["id"])
|
|
||||||
|
|
||||||
component = sorted(map(sorted, nx.k_edge_components(G, k=1)))
|
|
||||||
max_id = -1
|
|
||||||
for cluster_1 in cluster_predictions:
|
|
||||||
if cluster_1["id"] > max_id:
|
|
||||||
max_id = cluster_1["id"]
|
|
||||||
|
|
||||||
for nodes in component:
|
|
||||||
if len(nodes) > 1:
|
|
||||||
max_id += 1
|
|
||||||
lines = []
|
|
||||||
for node in nodes:
|
|
||||||
for cluster in cluster_predictions:
|
|
||||||
if cluster["id"] == node:
|
|
||||||
lines.append(cluster)
|
|
||||||
cluster_predictions.remove(cluster)
|
|
||||||
new_merged_cluster = build_cluster_from_lines(
|
|
||||||
lines, DocItemLabel.TEXT, max_id
|
|
||||||
)
|
|
||||||
cluster_predictions.append(new_merged_cluster)
|
|
||||||
return cluster_predictions
|
|
||||||
|
|
||||||
|
|
||||||
def clean_up_clusters(
|
|
||||||
cluster_predictions,
|
|
||||||
raw_cells,
|
|
||||||
merge_cells=False,
|
|
||||||
img_table=False,
|
|
||||||
one_cell_table=False,
|
|
||||||
):
|
|
||||||
DuplicateDeletedClusterIDs = []
|
|
||||||
|
|
||||||
for cluster_1 in cluster_predictions:
|
|
||||||
for cluster_2 in cluster_predictions:
|
|
||||||
if cluster_1["id"] != cluster_2["id"]:
|
|
||||||
# remove any artifcats created by merging clusters
|
|
||||||
if merge_cells == True:
|
|
||||||
if contains(
|
|
||||||
cluster_1["bbox"],
|
|
||||||
[
|
|
||||||
cluster_2["bbox"][0] + 3,
|
|
||||||
cluster_2["bbox"][1] + 3,
|
|
||||||
cluster_2["bbox"][2] - 3,
|
|
||||||
cluster_2["bbox"][3] - 3,
|
|
||||||
],
|
|
||||||
):
|
|
||||||
cluster_1["cell_ids"] = (
|
|
||||||
cluster_1["cell_ids"] + cluster_2["cell_ids"]
|
|
||||||
)
|
|
||||||
DuplicateDeletedClusterIDs.append(cluster_2["id"])
|
|
||||||
# remove clusters that might appear inside tables, or images (such as pdf cells in graphs)
|
|
||||||
elif img_table == True:
|
|
||||||
if (
|
|
||||||
cluster_1["type"] == DocItemLabel.TEXT
|
|
||||||
and cluster_2["type"] == DocItemLabel.PICTURE
|
|
||||||
or cluster_2["type"] == DocItemLabel.TABLE
|
|
||||||
):
|
|
||||||
if bb_iou(cluster_1["bbox"], cluster_2["bbox"]) > 0.5:
|
|
||||||
DuplicateDeletedClusterIDs.append(cluster_1["id"])
|
|
||||||
elif contains(
|
|
||||||
[
|
|
||||||
cluster_2["bbox"][0] - 3,
|
|
||||||
cluster_2["bbox"][1] - 3,
|
|
||||||
cluster_2["bbox"][2] + 3,
|
|
||||||
cluster_2["bbox"][3] + 3,
|
|
||||||
],
|
|
||||||
cluster_1["bbox"],
|
|
||||||
):
|
|
||||||
DuplicateDeletedClusterIDs.append(cluster_1["id"])
|
|
||||||
# remove tables that have one pdf cell
|
|
||||||
if one_cell_table == True:
|
|
||||||
if (
|
|
||||||
cluster_1["type"] == DocItemLabel.TABLE
|
|
||||||
and len(cluster_1["cell_ids"]) < 2
|
|
||||||
):
|
|
||||||
DuplicateDeletedClusterIDs.append(cluster_1["id"])
|
|
||||||
|
|
||||||
DuplicateDeletedClusterIDs = list(set(DuplicateDeletedClusterIDs))
|
|
||||||
|
|
||||||
for cl_id in DuplicateDeletedClusterIDs:
|
|
||||||
for cluster in cluster_predictions:
|
|
||||||
if cl_id == cluster["id"]:
|
|
||||||
cluster_predictions.remove(cluster)
|
|
||||||
return cluster_predictions
|
|
||||||
|
|
||||||
|
|
||||||
def assigning_cell_ids_to_clusters(clusters, raw_cells, threshold):
|
|
||||||
for cluster in clusters:
|
|
||||||
cells_in_cluster, _ = compute_enclosed_cells(
|
|
||||||
cluster["bbox"], raw_cells, min_cell_intersection_with_cluster=threshold
|
|
||||||
)
|
|
||||||
cluster["cell_ids"] = cells_in_cluster
|
|
||||||
## These cell_ids are ids of the raw cells.
|
|
||||||
## They are often, but not always, the same as the "id" or the index of the "cells" list in a prediction.
|
|
||||||
return clusters
|
|
||||||
|
|
||||||
|
|
||||||
# Creates a map of cell_id->cluster_id
|
|
||||||
def cell_id_state_map(clusters, cell_count):
|
|
||||||
clusters_around_cells = find_clusters_around_cells(cell_count, clusters)
|
|
||||||
orphan_cell_indices = [
|
|
||||||
ix for ix in range(cell_count) if len(clusters_around_cells[ix]) == 0
|
|
||||||
] # which cells are assigned no cluster?
|
|
||||||
ambiguous_cell_indices = [
|
|
||||||
ix for ix in range(cell_count) if len(clusters_around_cells[ix]) > 1
|
|
||||||
] # which cells are assigned > 1 clusters?
|
|
||||||
return clusters_around_cells, orphan_cell_indices, ambiguous_cell_indices
|
|
@ -74,6 +74,10 @@ def main():
|
|||||||
pipeline_options.do_ocr = True
|
pipeline_options.do_ocr = True
|
||||||
pipeline_options.do_table_structure = True
|
pipeline_options.do_table_structure = True
|
||||||
pipeline_options.table_structure_options.do_cell_matching = True
|
pipeline_options.table_structure_options.do_cell_matching = True
|
||||||
|
pipeline_options.ocr_options.lang = "es"
|
||||||
|
pipeline_options.accelerator_options = AcceleratorOptions(
|
||||||
|
num_threads=4, device=Device.AUTO
|
||||||
|
)
|
||||||
|
|
||||||
doc_converter = DocumentConverter(
|
doc_converter = DocumentConverter(
|
||||||
format_options={
|
format_options={
|
||||||
|
Loading…
Reference in New Issue
Block a user