feat: Updated Layout processing with forms and key-value areas (#530)

* Upgraded Layout Postprocessing, sending old code back to ERZ

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Implement hierarchical cluster layout processing

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Pass nested cluster processing through full pipeline

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Pass nested clusters through GLM as payload

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Move to_docling_document from ds-glm to this repo

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Clean up imports again

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* feat(Accelerator): Introduce options to control the num_threads and device from API, envvars, CLI.
- Introduce the AcceleratorOptions, AcceleratorDevice and use them to set the device where the models run.
- Introduce the accelerator_utils with function to decide the device and resolve the AUTO setting.
- Refactor how the docling-ibm-models are called to match the new model init signatures.
- Translate the accelerator options to the specific inputs for third-party models.
- Extend the docling CLI with parameters to set the num_threads and device.
- Add new unit tests.
- Write a new example showing how to use the accelerator options (see the sketch below).
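
For illustration, a minimal sketch of how the new options are meant to be used from the API (the option names follow this change; the exact import paths and converter wiring are assumptions based on the docling API):

    from docling.datamodel.base_models import InputFormat
    from docling.datamodel.pipeline_options import (
        AcceleratorDevice,
        AcceleratorOptions,
        PdfPipelineOptions,
    )
    from docling.document_converter import DocumentConverter, PdfFormatOption

    # Run the models on 8 threads and let AUTO pick CUDA/MPS/CPU as available.
    accelerator_options = AcceleratorOptions(
        num_threads=8, device=AcceleratorDevice.AUTO
    )

    pipeline_options = PdfPipelineOptions()
    pipeline_options.accelerator_options = accelerator_options

    converter = DocumentConverter(
        format_options={
            InputFormat.PDF: PdfFormatOption(pipeline_options=pipeline_options)
        }
    )
    result = converter.convert("sample.pdf")

The same settings can also be supplied through environment variables or the new CLI parameters mentioned above.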

* fix: Improve the pydantic objects in the pipeline_options and imports.

Signed-off-by: Nikos Livathinos <nli@zurich.ibm.com>

* fix: TableStructureModel: Refactor the artifacts path to use the new structure for fast/accurate model

Signed-off-by: Nikos Livathinos <nli@zurich.ibm.com>

* Updated test ground-truth

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Updated test ground-truth (again), bugfix for empty layout

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* fix: Do proper check to set the device in EasyOCR, RapidOCR.

Signed-off-by: Nikos Livathinos <nli@zurich.ibm.com>

* fix: Correct the way to set GPU for EasyOCR, RapidOCR

Signed-off-by: Nikos Livathinos <nli@zurich.ibm.com>

* fix: OCR AcceleratorDevice

Signed-off-by: Nikos Livathinos <nli@zurich.ibm.com>
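
For context, a rough sketch of the kind of check these OCR fixes introduce, mapping the resolved accelerator device onto EasyOCR's boolean gpu flag (decide_device appears in the layout model diff below; the helper name and the CUDA-only check here are illustrative assumptions):

    import easyocr

    from docling.datamodel.pipeline_options import AcceleratorOptions
    from docling.utils.accelerator_utils import decide_device

    def build_reader(langs: list, accelerator_options: AcceleratorOptions) -> easyocr.Reader:
        # decide_device resolves AcceleratorDevice.AUTO into a concrete device
        # string such as "cpu", "cuda:0" or "mps".
        device = decide_device(accelerator_options.device)
        # EasyOCR only exposes a boolean switch, so translate the device to it;
        # treating only CUDA as GPU-capable is an assumption of this sketch.
        use_gpu = device.startswith("cuda")
        return easyocr.Reader(langs, gpu=use_gpu)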

* Merge pull request #556 from DS4SD/cau/layout-processing-improvement

feat: layout processing improvements and bugfixes

* Update lockfile

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Update tests

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Update HF model ref, reset test generate

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Repin to release package versions

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Many layout processing improvements, add document index type

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Update pinnings to docling-core

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Update test GT

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Fix table box snapping

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Fixes for cluster pre-ordering

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Introduce OCR confidence, propagate it to orphan cells in post-processing

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
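
The EasyOCR diff further down applies this threshold per result line; as a minimal sketch of the idea (the option name confidence_threshold appears in the diff, the default value here is illustrative):

    def filter_ocr_lines(ocr_result, confidence_threshold: float = 0.5):
        # EasyOCR returns result lines of the form (bbox, text, confidence);
        # keep only lines whose confidence reaches the configured threshold,
        # so low-confidence OCR text is not promoted to clusters later on.
        return [line for line in ocr_result if line[2] >= confidence_threshold]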

* Fix form and key value area groups

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Adjust confidence in EasyOcr

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Roll back CLI changes from main

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Update test GT

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Update docling-core pinning

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Annoying fixes for older Python versions

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Updated test GT for legacy

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

* Comment cleanup

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>

---------

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
Signed-off-by: Nikos Livathinos <nli@zurich.ibm.com>
Co-authored-by: Nikos Livathinos <nli@zurich.ibm.com>
Author: Christoph Auer
Date: 2024-12-17 17:32:24 +01:00
Committed by: GitHub
Parent: 00dec7a2f3
Commit: 60dc852f16
56 changed files with 1659 additions and 1718 deletions


@@ -22,9 +22,15 @@ from docling_core.types.legacy_doc.document import (
from docling_core.types.legacy_doc.document import CCSFileInfoObject as DsFileInfoObject
from docling_core.types.legacy_doc.document import ExportedCCSDocument as DsDocument
from PIL import ImageDraw
from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel, ConfigDict, TypeAdapter
from docling.datamodel.base_models import Cluster, FigureElement, Table, TextElement
from docling.datamodel.base_models import (
Cluster,
ContainerElement,
FigureElement,
Table,
TextElement,
)
from docling.datamodel.document import ConversionResult, layout_label_to_ds_type
from docling.datamodel.settings import settings
from docling.utils.glm_utils import to_docling_document
@@ -204,7 +210,31 @@ class GlmModel:
)
],
obj_type=layout_label_to_ds_type.get(element.label),
# data=[[]],
payload={
"children": TypeAdapter(List[Cluster]).dump_python(
element.cluster.children
)
}, # hack to channel child clusters through GLM
)
)
elif isinstance(element, ContainerElement):
main_text.append(
BaseText(
text="",
payload={
"children": TypeAdapter(List[Cluster]).dump_python(
element.cluster.children
)
}, # hack to channel child clusters through GLM
obj_type=layout_label_to_ds_type.get(element.label),
name=element.label,
prov=[
Prov(
bbox=target_bbox,
page=element.page_no + 1,
span=[0, 0],
)
],
)
)


@@ -118,6 +118,7 @@ class EasyOcrModel(BaseOcrModel):
),
)
for ix, line in enumerate(result)
if line[2] >= self.options.confidence_threshold
]
all_ocr_cells.extend(cells)


@@ -7,9 +7,8 @@ from typing import Iterable, List
from docling_core.types.doc import CoordOrigin, DocItemLabel
from docling_ibm_models.layoutmodel.layout_predictor import LayoutPredictor
from PIL import ImageDraw
from PIL import Image, ImageDraw, ImageFont
import docling.utils.layout_utils as lu
from docling.datamodel.base_models import (
BoundingBox,
Cell,
@@ -22,6 +21,7 @@ from docling.datamodel.pipeline_options import AcceleratorDevice, AcceleratorOpt
from docling.datamodel.settings import settings
from docling.models.base_model import BasePageModel
from docling.utils.accelerator_utils import decide_device
from docling.utils.layout_postprocessor import LayoutPostprocessor
from docling.utils.profiling import TimeRecorder
_log = logging.getLogger(__name__)
@@ -44,9 +44,10 @@ class LayoutModel(BasePageModel):
]
PAGE_HEADER_LABELS = [DocItemLabel.PAGE_HEADER, DocItemLabel.PAGE_FOOTER]
TABLE_LABEL = DocItemLabel.TABLE
TABLE_LABELS = [DocItemLabel.TABLE, DocItemLabel.DOCUMENT_INDEX]
FIGURE_LABEL = DocItemLabel.PICTURE
FORMULA_LABEL = DocItemLabel.FORMULA
CONTAINER_LABELS = [DocItemLabel.FORM, DocItemLabel.KEY_VALUE_REGION]
def __init__(self, artifacts_path: Path, accelerator_options: AcceleratorOptions):
device = decide_device(accelerator_options.device)
@@ -55,234 +56,127 @@ class LayoutModel(BasePageModel):
artifact_path=str(artifacts_path),
device=device,
num_threads=accelerator_options.num_threads,
base_threshold=0.6,
blacklist_classes={"Form", "Key-Value Region"},
)
def postprocess(self, clusters_in: List[Cluster], cells: List[Cell], page_height):
MIN_INTERSECTION = 0.2
CLASS_THRESHOLDS = {
DocItemLabel.CAPTION: 0.35,
DocItemLabel.FOOTNOTE: 0.35,
DocItemLabel.FORMULA: 0.35,
DocItemLabel.LIST_ITEM: 0.35,
DocItemLabel.PAGE_FOOTER: 0.35,
DocItemLabel.PAGE_HEADER: 0.35,
DocItemLabel.PICTURE: 0.2,  # low threshold, adjusted to capture e.g. chemical structures
DocItemLabel.SECTION_HEADER: 0.45,
DocItemLabel.TABLE: 0.35,
DocItemLabel.TEXT: 0.45,
DocItemLabel.TITLE: 0.45,
DocItemLabel.DOCUMENT_INDEX: 0.45,
DocItemLabel.CODE: 0.45,
DocItemLabel.CHECKBOX_SELECTED: 0.45,
DocItemLabel.CHECKBOX_UNSELECTED: 0.45,
DocItemLabel.FORM: 0.45,
DocItemLabel.KEY_VALUE_REGION: 0.45,
def draw_clusters_and_cells_side_by_side(
self, conv_res, page, clusters, mode_prefix: str, show: bool = False
):
"""
Draws a page image side by side with clusters filtered into two categories:
- Left: Clusters excluding FORM, KEY_VALUE_REGION, and PICTURE.
- Right: Clusters including FORM, KEY_VALUE_REGION, and PICTURE.
Includes label names and confidence scores for each cluster.
"""
label_to_color = {
DocItemLabel.TEXT: (255, 255, 153), # Light Yellow
DocItemLabel.CAPTION: (255, 204, 153), # Light Orange
DocItemLabel.LIST_ITEM: (153, 153, 255), # Light Purple
DocItemLabel.FORMULA: (192, 192, 192), # Gray
DocItemLabel.TABLE: (255, 204, 204), # Light Pink
DocItemLabel.PICTURE: (255, 204, 164), # Light Beige
DocItemLabel.SECTION_HEADER: (255, 153, 153), # Light Red
DocItemLabel.PAGE_HEADER: (204, 255, 204), # Light Green
DocItemLabel.PAGE_FOOTER: (
204,
255,
204,
), # Light Green (same as Page-Header)
DocItemLabel.TITLE: (255, 153, 153), # Light Red (same as Section-Header)
DocItemLabel.FOOTNOTE: (200, 200, 255), # Light Blue
DocItemLabel.DOCUMENT_INDEX: (220, 220, 220), # Light Gray
DocItemLabel.CODE: (125, 125, 125), # Gray
DocItemLabel.CHECKBOX_SELECTED: (255, 182, 193),  # Light Pink
DocItemLabel.CHECKBOX_UNSELECTED: (255, 182, 193), # Light Pink
DocItemLabel.FORM: (200, 255, 255), # Light Cyan
DocItemLabel.KEY_VALUE_REGION: (183, 65, 14), # Rusty orange
}
CLASS_REMAPPINGS = {
DocItemLabel.DOCUMENT_INDEX: DocItemLabel.TABLE,
DocItemLabel.TITLE: DocItemLabel.SECTION_HEADER,
# Filter clusters for left and right images
exclude_labels = {
DocItemLabel.FORM,
DocItemLabel.KEY_VALUE_REGION,
DocItemLabel.PICTURE,
}
left_clusters = [c for c in clusters if c.label not in exclude_labels]
right_clusters = [c for c in clusters if c.label in exclude_labels]
# Create a deep copy of the original image for both sides
left_image = copy.deepcopy(page.image)
right_image = copy.deepcopy(page.image)
_log.debug("================= Start postprocess function ====================")
start_time = time.time()
# Apply Confidence Threshold to cluster predictions
# confidence = self.conf_threshold
clusters_mod = []
# Function to draw clusters on an image
def draw_clusters(image, clusters):
draw = ImageDraw.Draw(image, "RGBA")
# Create a smaller font for the labels
try:
font = ImageFont.truetype("arial.ttf", 12)
except OSError:
# Fallback to default font if arial is not available
font = ImageFont.load_default()
for c_tl in clusters:
all_clusters = [c_tl, *c_tl.children]
for c in all_clusters:
# Draw cells first (underneath)
cell_color = (0, 0, 0, 40) # Transparent black for cells
for tc in c.cells:
cx0, cy0, cx1, cy1 = tc.bbox.as_tuple()
draw.rectangle(
[(cx0, cy0), (cx1, cy1)],
outline=None,
fill=cell_color,
)
# Draw cluster rectangle
x0, y0, x1, y1 = c.bbox.as_tuple()
cluster_fill_color = (*list(label_to_color.get(c.label)), 70)
cluster_outline_color = (*list(label_to_color.get(c.label)), 255)
draw.rectangle(
[(x0, y0), (x1, y1)],
outline=cluster_outline_color,
fill=cluster_fill_color,
)
# Add label name and confidence
label_text = f"{c.label.name} ({c.confidence:.2f})"
# Create semi-transparent background for text
text_bbox = draw.textbbox((x0, y0), label_text, font=font)
text_bg_padding = 2
draw.rectangle(
[
(
text_bbox[0] - text_bg_padding,
text_bbox[1] - text_bg_padding,
),
(
text_bbox[2] + text_bg_padding,
text_bbox[3] + text_bg_padding,
),
],
fill=(255, 255, 255, 180), # Semi-transparent white
)
# Draw text
draw.text(
(x0, y0),
label_text,
fill=(0, 0, 0, 255), # Solid black
font=font,
)
for cluster in clusters_in:
confidence = CLASS_THRESHOLDS[cluster.label]
if cluster.confidence >= confidence:
# annotation["created_by"] = "high_conf_pred"
# Remap class labels where needed.
if cluster.label in CLASS_REMAPPINGS.keys():
cluster.label = CLASS_REMAPPINGS[cluster.label]
clusters_mod.append(cluster)
# map to dictionary clusters and cells, with bottom left origin
clusters_orig = [
{
"id": c.id,
"bbox": list(
c.bbox.to_bottom_left_origin(page_height).as_tuple()
), # TODO
"confidence": c.confidence,
"cell_ids": [],
"type": c.label,
}
for c in clusters_in
]
clusters_out = [
{
"id": c.id,
"bbox": list(
c.bbox.to_bottom_left_origin(page_height).as_tuple()
), # TODO
"confidence": c.confidence,
"created_by": "high_conf_pred",
"cell_ids": [],
"type": c.label,
}
for c in clusters_mod
]
del clusters_mod
raw_cells = [
{
"id": c.id,
"bbox": list(
c.bbox.to_bottom_left_origin(page_height).as_tuple()
), # TODO
"text": c.text,
}
for c in cells
]
cell_count = len(raw_cells)
_log.debug("---- 0. Treat cluster overlaps ------")
clusters_out = lu.remove_cluster_duplicates_by_conf(clusters_out, 0.8)
_log.debug(
"---- 1. Initially assign cells to clusters based on minimum intersection ------"
)
## Check for cells included in or touched by clusters:
clusters_out = lu.assigning_cell_ids_to_clusters(
clusters_out, raw_cells, MIN_INTERSECTION
)
_log.debug("---- 2. Assign Orphans with Low Confidence Detections")
# Creates a map of cell_id->cluster_id
(
clusters_around_cells,
orphan_cell_indices,
ambiguous_cell_indices,
) = lu.cell_id_state_map(clusters_out, cell_count)
# Assign orphan cells with lower confidence predictions
clusters_out, orphan_cell_indices = lu.assign_orphans_with_low_conf_pred(
clusters_out, clusters_orig, raw_cells, orphan_cell_indices
)
# Refresh the cell_ids assignment, after creating new clusters using low conf predictions
clusters_out = lu.assigning_cell_ids_to_clusters(
clusters_out, raw_cells, MIN_INTERSECTION
)
_log.debug("---- 3. Settle Ambigous Cells")
# Creates an update map after assignment of cell_id->cluster_id
(
clusters_around_cells,
orphan_cell_indices,
ambiguous_cell_indices,
) = lu.cell_id_state_map(clusters_out, cell_count)
# Settle pdf cells that belong to multiple clusters
clusters_out, ambiguous_cell_indices = lu.remove_ambigous_pdf_cell_by_conf(
clusters_out, raw_cells, ambiguous_cell_indices
)
_log.debug("---- 4. Set Orphans as Text")
(
clusters_around_cells,
orphan_cell_indices,
ambiguous_cell_indices,
) = lu.cell_id_state_map(clusters_out, cell_count)
clusters_out, orphan_cell_indices = lu.set_orphan_as_text(
clusters_out, clusters_orig, raw_cells, orphan_cell_indices
)
_log.debug("---- 5. Merge Cells & and adapt the bounding boxes")
# Merge cells orphan cells
clusters_out = lu.merge_cells(clusters_out)
# Clean up clusters that remain from merged and unreasonable clusters
clusters_out = lu.clean_up_clusters(
clusters_out,
raw_cells,
merge_cells=True,
img_table=True,
one_cell_table=True,
)
new_clusters = lu.adapt_bboxes(raw_cells, clusters_out, orphan_cell_indices)
clusters_out = new_clusters
## First rebuild the mapping of where every cell is now:
## Write into a prediction cells list, not into the raw cells list.
## Since previous labels are not needed, it is best to overwrite any old list,
## because it might have been sorted differently.
(
clusters_around_cells,
orphan_cell_indices,
ambiguous_cell_indices,
) = lu.cell_id_state_map(clusters_out, cell_count)
target_cells = []
for ix, cell in enumerate(raw_cells):
new_cell = {
"id": ix,
"rawcell_id": ix,
"label": "None",
"bbox": cell["bbox"],
"text": cell["text"],
}
for cluster_index in clusters_around_cells[
ix
]: # By previous analysis, this is always 1 cluster.
new_cell["label"] = clusters_out[cluster_index]["type"]
target_cells.append(new_cell)
# _log.debug("New label of cell " + str(ix) + " is " + str(new_cell["label"]))
cells_out = target_cells
## -------------------------------
## Sort clusters into reasonable reading order, and sort the cells inside each cluster
_log.debug("---- 5. Sort clusters in reading order ------")
sorted_clusters = lu.produce_reading_order(
clusters_out, "raw_cell_ids", "raw_cell_ids", True
)
clusters_out = sorted_clusters
# end_time = timer()
_log.debug("---- End of postprocessing function ------")
end_time = time.time() - start_time
_log.debug(f"Finished post processing in seconds={end_time:.3f}")
cells_out_new = [
Cell(
id=c["id"], # type: ignore
bbox=BoundingBox.from_tuple(
coord=c["bbox"], origin=CoordOrigin.BOTTOMLEFT # type: ignore
).to_top_left_origin(page_height),
text=c["text"], # type: ignore
# Draw clusters on both images
draw_clusters(left_image, left_clusters)
draw_clusters(right_image, right_clusters)
# Combine the images side by side
combined_width = left_image.width * 2
combined_height = left_image.height
combined_image = Image.new("RGB", (combined_width, combined_height))
combined_image.paste(left_image, (0, 0))
combined_image.paste(right_image, (left_image.width, 0))
if show:
combined_image.show()
else:
out_path: Path = (
Path(settings.debug.debug_output_path)
/ f"debug_{conv_res.input.file.stem}"
)
for c in cells_out
]
del cells_out
clusters_out_new = []
for c in clusters_out:
cluster_cells = [
ccell for ccell in cells_out_new if ccell.id in c["cell_ids"] # type: ignore
]
c_new = Cluster(
id=c["id"], # type: ignore
bbox=BoundingBox.from_tuple(
coord=c["bbox"], origin=CoordOrigin.BOTTOMLEFT # type: ignore
).to_top_left_origin(page_height),
confidence=c["confidence"], # type: ignore
label=DocItemLabel(c["type"]),
cells=cluster_cells,
)
clusters_out_new.append(c_new)
return clusters_out_new, cells_out_new
out_path.mkdir(parents=True, exist_ok=True)
out_file = out_path / f"{mode_prefix}_layout_page_{page.page_no:05}.png"
combined_image.save(str(out_file), format="png")
def __call__(
self, conv_res: ConversionResult, page_batch: Iterable[Page]
@@ -315,66 +209,26 @@ class LayoutModel(BasePageModel):
)
clusters.append(cluster)
# Map cells to clusters
# TODO: Remove, postprocess should take care of it anyway.
for cell in page.cells:
for cluster in clusters:
if not cell.bbox.area() > 0:
overlap_frac = 0.0
else:
overlap_frac = (
cell.bbox.intersection_area_with(cluster.bbox)
/ cell.bbox.area()
)
if settings.debug.visualize_raw_layout:
self.draw_clusters_and_cells_side_by_side(
conv_res, page, clusters, mode_prefix="raw"
)
if overlap_frac > 0.5:
cluster.cells.append(cell)
# Apply postprocessing
# Pre-sort clusters
# clusters = self.sort_clusters_by_cell_order(clusters)
processed_clusters, processed_cells = LayoutPostprocessor(
page.cells, clusters, page.size
).postprocess()
# processed_clusters, processed_cells = clusters, page.cells
# DEBUG code:
def draw_clusters_and_cells(show: bool = False):
image = copy.deepcopy(page.image)
if image is not None:
draw = ImageDraw.Draw(image)
for c in clusters:
x0, y0, x1, y1 = c.bbox.as_tuple()
draw.rectangle([(x0, y0), (x1, y1)], outline="green")
cell_color = (
random.randint(30, 140),
random.randint(30, 140),
random.randint(30, 140),
)
for tc in c.cells: # [:1]:
x0, y0, x1, y1 = tc.bbox.as_tuple()
draw.rectangle(
[(x0, y0), (x1, y1)], outline=cell_color
)
if show:
image.show()
else:
out_path: Path = (
Path(settings.debug.debug_output_path)
/ f"debug_{conv_res.input.file.stem}"
)
out_path.mkdir(parents=True, exist_ok=True)
out_file = (
out_path / f"layout_page_{page.page_no:05}.png"
)
image.save(str(out_file), format="png")
# draw_clusters_and_cells()
clusters, page.cells = self.postprocess(
clusters, page.cells, page.size.height
page.cells = processed_cells
page.predictions.layout = LayoutPrediction(
clusters=processed_clusters
)
page.predictions.layout = LayoutPrediction(clusters=clusters)
if settings.debug.visualize_layout:
draw_clusters_and_cells()
self.draw_clusters_and_cells_side_by_side(
conv_res, page, processed_clusters, mode_prefix="postprocessed"
)
yield page


@@ -6,6 +6,7 @@ from pydantic import BaseModel
from docling.datamodel.base_models import (
AssembledUnit,
ContainerElement,
FigureElement,
Page,
PageElement,
@@ -94,7 +95,7 @@ class PageAssembleModel(BasePageModel):
headers.append(text_el)
else:
body.append(text_el)
elif cluster.label == LayoutModel.TABLE_LABEL:
elif cluster.label in LayoutModel.TABLE_LABELS:
tbl = None
if page.predictions.tablestructure:
tbl = page.predictions.tablestructure.table_map.get(
@@ -159,6 +160,15 @@ class PageAssembleModel(BasePageModel):
)
elements.append(equation)
body.append(equation)
elif cluster.label in LayoutModel.CONTAINER_LABELS:
container_el = ContainerElement(
label=cluster.label,
id=cluster.id,
page_no=page.page_no,
cluster=cluster,
)
elements.append(container_el)
body.append(container_el)
page.assembled = AssembledUnit(
elements=elements, headers=headers, body=body


@@ -76,6 +76,10 @@ class TableStructureModel(BasePageModel):
x0, y0, x1, y1 = table_element.cluster.bbox.as_tuple()
draw.rectangle([(x0, y0), (x1, y1)], outline="red")
for cell in table_element.cluster.cells:
x0, y0, x1, y1 = cell.bbox.as_tuple()
draw.rectangle([(x0, y0), (x1, y1)], outline="green")
for tc in table_element.table_cells:
if tc.bbox is not None:
x0, y0, x1, y1 = tc.bbox.as_tuple()
@@ -89,7 +93,6 @@ class TableStructureModel(BasePageModel):
text=f"{tc.start_row_offset_idx}, {tc.start_col_offset_idx}",
fill="black",
)
if show:
image.show()
else:
@@ -135,47 +138,40 @@ class TableStructureModel(BasePageModel):
],
)
for cluster in page.predictions.layout.clusters
if cluster.label == DocItemLabel.TABLE
if cluster.label
in [DocItemLabel.TABLE, DocItemLabel.DOCUMENT_INDEX]
]
if not len(in_tables):
yield page
continue
tokens = []
for c in page.cells:
for cluster, _ in in_tables:
if c.bbox.area() > 0:
if (
c.bbox.intersection_area_with(cluster.bbox)
/ c.bbox.area()
> 0.2
):
# Only allow non-empty strings (not just spaces) into the cells of a table
if len(c.text.strip()) > 0:
new_cell = copy.deepcopy(c)
new_cell.bbox = new_cell.bbox.scaled(
scale=self.scale
)
tokens.append(new_cell.model_dump())
page_input = {
"tokens": tokens,
"width": page.size.width * self.scale,
"height": page.size.height * self.scale,
"image": numpy.asarray(page.get_image(scale=self.scale)),
}
page_input["image"] = numpy.asarray(
page.get_image(scale=self.scale)
)
table_clusters, table_bboxes = zip(*in_tables)
if len(table_bboxes):
tf_output = self.tf_predictor.multi_table_predict(
page_input, table_bboxes, do_matching=self.do_cell_matching
)
for table_cluster, tbl_box in in_tables:
for table_cluster, table_out in zip(table_clusters, tf_output):
tokens = []
for c in table_cluster.cells:
# Only allow non-empty strings (not just spaces) into the cells of a table
if len(c.text.strip()) > 0:
new_cell = copy.deepcopy(c)
new_cell.bbox = new_cell.bbox.scaled(
scale=self.scale
)
tokens.append(new_cell.model_dump())
page_input["tokens"] = tokens
tf_output = self.tf_predictor.multi_table_predict(
page_input, [tbl_box], do_matching=self.do_cell_matching
)
table_out = tf_output[0]
table_cells = []
for element in table_out["tf_responses"]:
@@ -208,7 +204,7 @@ class TableStructureModel(BasePageModel):
id=table_cluster.id,
page_no=page.page_no,
cluster=table_cluster,
label=DocItemLabel.TABLE,
label=table_cluster.label,
)
page.predictions.tablestructure.table_map[