Optimizations for table extraction quality, configurable options for cell matching

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
This commit is contained in:
Christoph Auer 2024-07-17 15:21:13 +02:00
parent 78b154fde7
commit 6c01600194
4 changed files with 68 additions and 18 deletions

View File

@ -1,3 +1,4 @@
import copy
from enum import Enum, auto from enum import Enum, auto
from io import BytesIO from io import BytesIO
from typing import Any, Dict, List, Optional, Tuple, Union from typing import Any, Dict, List, Optional, Tuple, Union
@ -47,6 +48,15 @@ class BoundingBox(BaseModel):
def height(self): def height(self):
return abs(self.t - self.b) return abs(self.t - self.b)
def scaled(self, scale: float) -> "BoundingBox":
out_bbox = copy.deepcopy(self)
out_bbox.l *= scale
out_bbox.r *= scale
out_bbox.t *= scale
out_bbox.b *= scale
return out_bbox
def as_tuple(self): def as_tuple(self):
if self.coord_origin == CoordOrigin.TOPLEFT: if self.coord_origin == CoordOrigin.TOPLEFT:
return (self.l, self.t, self.r, self.b) return (self.l, self.t, self.r, self.b)
@ -180,8 +190,7 @@ class TableStructurePrediction(BaseModel):
table_map: Dict[int, TableElement] = {} table_map: Dict[int, TableElement] = {}
class TextElement(BasePageElement): class TextElement(BasePageElement): ...
...
class FigureData(BaseModel): class FigureData(BaseModel):
@ -242,6 +251,17 @@ class DocumentStream(BaseModel):
stream: BytesIO stream: BytesIO
class TableStructureOptions(BaseModel):
do_cell_matching: bool = (
True
# True: Matches predictions back to PDF cells. Can break table output if PDF cells
# are merged across table columns.
# False: Let table structure model define the text cells, ignore PDF cells.
)
class PipelineOptions(BaseModel): class PipelineOptions(BaseModel):
do_table_structure: bool = True do_table_structure: bool = True # True: perform table structure extraction
do_ocr: bool = False do_ocr: bool = False # True: perform OCR, replace programmatic PDF text
table_structure_options: TableStructureOptions = TableStructureOptions()

View File

@ -1,7 +1,10 @@
from typing import Iterable import copy
import random
from typing import Iterable, List
import numpy import numpy
from docling_ibm_models.tableformer.data_management.tf_predictor import TFPredictor from docling_ibm_models.tableformer.data_management.tf_predictor import TFPredictor
from PIL import ImageDraw
from docling.datamodel.base_models import ( from docling.datamodel.base_models import (
BoundingBox, BoundingBox,
@ -28,6 +31,21 @@ class TableStructureModel:
self.tm_model_type = self.tm_config["model"]["type"] self.tm_model_type = self.tm_config["model"]["type"]
self.tf_predictor = TFPredictor(self.tm_config) self.tf_predictor = TFPredictor(self.tm_config)
self.scale = 2.0 # Scale up table input images to 144 dpi
def draw_table_and_cells(self, page: Page, tbl_list: List[TableElement]):
image = page._backend.get_page_image()
draw = ImageDraw.Draw(image)
for table_element in tbl_list:
x0, y0, x1, y1 = table_element.cluster.bbox.as_tuple()
draw.rectangle([(x0, y0), (x1, y1)], outline="red")
for tc in table_element.table_cells:
x0, y0, x1, y1 = tc.bbox.as_tuple()
draw.rectangle([(x0, y0), (x1, y1)], outline="blue")
image.show()
def __call__(self, page_batch: Iterable[Page]) -> Iterable[Page]: def __call__(self, page_batch: Iterable[Page]) -> Iterable[Page]:
@ -36,16 +54,17 @@ class TableStructureModel:
return return
for page in page_batch: for page in page_batch:
page.predictions.tablestructure = TableStructurePrediction() # dummy page.predictions.tablestructure = TableStructurePrediction() # dummy
in_tables = [ in_tables = [
( (
cluster, cluster,
[ [
round(cluster.bbox.l), round(cluster.bbox.l) * self.scale,
round(cluster.bbox.t), round(cluster.bbox.t) * self.scale,
round(cluster.bbox.r), round(cluster.bbox.r) * self.scale,
round(cluster.bbox.b), round(cluster.bbox.b) * self.scale,
], ],
) )
for cluster in page.predictions.layout.clusters for cluster in page.predictions.layout.clusters
@ -65,20 +84,29 @@ class TableStructureModel:
): ):
# Only allow non empty stings (spaces) into the cells of a table # Only allow non empty stings (spaces) into the cells of a table
if len(c.text.strip()) > 0: if len(c.text.strip()) > 0:
tokens.append(c.model_dump()) new_cell = copy.deepcopy(c)
new_cell.bbox = new_cell.bbox.scaled(scale=self.scale)
iocr_page = { tokens.append(new_cell.model_dump())
"image": numpy.asarray(page.image),
page_input = {
"tokens": tokens, "tokens": tokens,
"width": page.size.width, "width": page.size.width * self.scale,
"height": page.size.height, "height": page.size.height * self.scale,
} }
# add image to page input.
if self.scale == 1.0:
page_input["image"] = numpy.asarray(page.image)
else: # render new page image on the fly at desired scale
page_input["image"] = numpy.asarray(
page._backend.get_page_image(scale=self.scale)
)
table_clusters, table_bboxes = zip(*in_tables) table_clusters, table_bboxes = zip(*in_tables)
if len(table_bboxes): if len(table_bboxes):
tf_output = self.tf_predictor.multi_table_predict( tf_output = self.tf_predictor.multi_table_predict(
iocr_page, table_bboxes, do_matching=self.do_cell_matching page_input, table_bboxes, do_matching=self.do_cell_matching
) )
for table_cluster, table_out in zip(table_clusters, tf_output): for table_cluster, table_out in zip(table_clusters, tf_output):
@ -91,6 +119,7 @@ class TableStructureModel:
element["bbox"]["token"] = text_piece element["bbox"]["token"] = text_piece
tc = TableCell.model_validate(element) tc = TableCell.model_validate(element)
tc.bbox = tc.bbox.scaled(1 / self.scale)
table_cells.append(tc) table_cells.append(tc)
# Retrieving cols/rows, after post processing: # Retrieving cols/rows, after post processing:
@ -111,4 +140,7 @@ class TableStructureModel:
page.predictions.tablestructure.table_map[table_cluster.id] = tbl page.predictions.tablestructure.table_map[table_cluster.id] = tbl
# For debugging purposes:
# self.draw_table_and_cells(page, page.predictions.tablestructure.table_map.values())
yield page yield page

View File

@ -34,7 +34,7 @@ class StandardModelPipeline(BaseModelPipeline):
"artifacts_path": artifacts_path "artifacts_path": artifacts_path
/ StandardModelPipeline._table_model_path, / StandardModelPipeline._table_model_path,
"enabled": pipeline_options.do_table_structure, "enabled": pipeline_options.do_table_structure,
"do_cell_matching": False, "do_cell_matching": pipeline_options.table_structure_options.do_cell_matching,
} }
), ),
] ]

View File

@ -46,8 +46,6 @@ def main():
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
input_doc_paths = [ input_doc_paths = [
# Path("/Users/cau/Downloads/Issue-36122.pdf"),
# Path("/Users/cau/Downloads/IBM_Storage_Insights_Fact_Sheet.pdf"),
Path("./test/data/2206.01062.pdf"), Path("./test/data/2206.01062.pdf"),
Path("./test/data/2203.01017v2.pdf"), Path("./test/data/2203.01017v2.pdf"),
Path("./test/data/2305.03393v1.pdf"), Path("./test/data/2305.03393v1.pdf"),