feat: Optimize table extraction quality, add configuration options (#11)

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
Signed-off-by: Christoph Auer <60343111+cau-git@users.noreply.github.com>
Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>
Signed-off-by: Michele Dolfi <97102151+dolfim-ibm@users.noreply.github.com>
Signed-off-by: Panos Vagenas <35837085+vagenas@users.noreply.github.com>
Co-authored-by: Christoph Auer <cau@zurich.ibm.com>
Co-authored-by: Michele Dolfi <97102151+dolfim-ibm@users.noreply.github.com>
Co-authored-by: Panos Vagenas <35837085+vagenas@users.noreply.github.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
This commit is contained in:
Christoph Auer
2024-07-17 16:13:21 +02:00
committed by GitHub
parent 3e2ede8107
commit e9526bb11e
5 changed files with 87 additions and 27 deletions

View File

@@ -19,18 +19,6 @@ class PageAssembleModel:
def __init__(self, config):
self.config = config
# self.line_wrap_pattern = re.compile(r'(?<=[^\W_])- \n(?=\w)')
# def sanitize_text_poor(self, lines):
# text = '\n'.join(lines)
#
# # treat line wraps.
# sanitized_text = self.line_wrap_pattern.sub('', text)
#
# sanitized_text = sanitized_text.replace('\n', ' ')
#
# return sanitized_text
def sanitize_text(self, lines):
if len(lines) <= 1:
return " ".join(lines)

View File

@@ -1,7 +1,10 @@
from typing import Iterable
import copy
import random
from typing import Iterable, List
import numpy
from docling_ibm_models.tableformer.data_management.tf_predictor import TFPredictor
from PIL import ImageDraw
from docling.datamodel.base_models import (
BoundingBox,
@@ -28,6 +31,21 @@ class TableStructureModel:
self.tm_model_type = self.tm_config["model"]["type"]
self.tf_predictor = TFPredictor(self.tm_config)
self.scale = 2.0 # Scale up table input images to 144 dpi
def draw_table_and_cells(self, page: Page, tbl_list: List[TableElement]):
image = page._backend.get_page_image()
draw = ImageDraw.Draw(image)
for table_element in tbl_list:
x0, y0, x1, y1 = table_element.cluster.bbox.as_tuple()
draw.rectangle([(x0, y0), (x1, y1)], outline="red")
for tc in table_element.table_cells:
x0, y0, x1, y1 = tc.bbox.as_tuple()
draw.rectangle([(x0, y0), (x1, y1)], outline="blue")
image.show()
def __call__(self, page_batch: Iterable[Page]) -> Iterable[Page]:
@@ -36,16 +54,17 @@ class TableStructureModel:
return
for page in page_batch:
page.predictions.tablestructure = TableStructurePrediction() # dummy
in_tables = [
(
cluster,
[
round(cluster.bbox.l),
round(cluster.bbox.t),
round(cluster.bbox.r),
round(cluster.bbox.b),
round(cluster.bbox.l) * self.scale,
round(cluster.bbox.t) * self.scale,
round(cluster.bbox.r) * self.scale,
round(cluster.bbox.b) * self.scale,
],
)
for cluster in page.predictions.layout.clusters
@@ -65,20 +84,29 @@ class TableStructureModel:
):
# Only allow non empty stings (spaces) into the cells of a table
if len(c.text.strip()) > 0:
tokens.append(c.model_dump())
new_cell = copy.deepcopy(c)
new_cell.bbox = new_cell.bbox.scaled(scale=self.scale)
iocr_page = {
"image": numpy.asarray(page.image),
tokens.append(new_cell.model_dump())
page_input = {
"tokens": tokens,
"width": page.size.width,
"height": page.size.height,
"width": page.size.width * self.scale,
"height": page.size.height * self.scale,
}
# add image to page input.
if self.scale == 1.0:
page_input["image"] = numpy.asarray(page.image)
else: # render new page image on the fly at desired scale
page_input["image"] = numpy.asarray(
page._backend.get_page_image(scale=self.scale)
)
table_clusters, table_bboxes = zip(*in_tables)
if len(table_bboxes):
tf_output = self.tf_predictor.multi_table_predict(
iocr_page, table_bboxes, do_matching=self.do_cell_matching
page_input, table_bboxes, do_matching=self.do_cell_matching
)
for table_cluster, table_out in zip(table_clusters, tf_output):
@@ -91,6 +119,7 @@ class TableStructureModel:
element["bbox"]["token"] = text_piece
tc = TableCell.model_validate(element)
tc.bbox = tc.bbox.scaled(1 / self.scale)
table_cells.append(tc)
# Retrieving cols/rows, after post processing:
@@ -111,4 +140,7 @@ class TableStructureModel:
page.predictions.tablestructure.table_map[table_cluster.id] = tbl
# For debugging purposes:
# self.draw_table_and_cells(page, page.predictions.tablestructure.table_map.values())
yield page