Support tableformer model choice

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
This commit is contained in:
Christoph Auer 2024-09-19 15:25:37 +02:00
parent 6dd1e91c4a
commit 14dcba11c0
3 changed files with 15 additions and 1 deletions

View File

@ -24,6 +24,11 @@ class DocInputType(str, Enum):
STREAM = auto() STREAM = auto()
class TableFormerMode(str, Enum):
FAST = auto()
ACCURATE = auto()
class CoordOrigin(str, Enum): class CoordOrigin(str, Enum):
TOPLEFT = auto() TOPLEFT = auto()
BOTTOMLEFT = auto() BOTTOMLEFT = auto()
@ -305,6 +310,7 @@ class TableStructureOptions(BaseModel):
# are merged across table columns. # are merged across table columns.
# False: Let table structure model define the text cells, ignore PDF cells. # False: Let table structure model define the text cells, ignore PDF cells.
) )
mode: TableFormerMode = TableFormerMode.FAST
class PipelineOptions(BaseModel): class PipelineOptions(BaseModel):

View File

@ -1,4 +1,5 @@
import copy import copy
from pathlib import Path
from typing import Iterable, List from typing import Iterable, List
import numpy import numpy
@ -10,6 +11,7 @@ from docling.datamodel.base_models import (
Page, Page,
TableCell, TableCell,
TableElement, TableElement,
TableFormerMode,
TableStructurePrediction, TableStructurePrediction,
) )
@ -18,10 +20,15 @@ class TableStructureModel:
def __init__(self, config): def __init__(self, config):
self.config = config self.config = config
self.do_cell_matching = config["do_cell_matching"] self.do_cell_matching = config["do_cell_matching"]
self.mode = config["mode"]
self.enabled = config["enabled"] self.enabled = config["enabled"]
if self.enabled: if self.enabled:
artifacts_path = config["artifacts_path"] artifacts_path: Path = config["artifacts_path"]
if self.mode == TableFormerMode.ACCURATE:
artifacts_path = artifacts_path / "fat"
# Third Party # Third Party
import docling_ibm_models.tableformer.common as c import docling_ibm_models.tableformer.common as c

View File

@ -32,6 +32,7 @@ class StandardModelPipeline(BaseModelPipeline):
"artifacts_path": artifacts_path "artifacts_path": artifacts_path
/ StandardModelPipeline._table_model_path, / StandardModelPipeline._table_model_path,
"enabled": pipeline_options.do_table_structure, "enabled": pipeline_options.do_table_structure,
"mode": pipeline_options.table_structure_options.mode,
"do_cell_matching": pipeline_options.table_structure_options.do_cell_matching, "do_cell_matching": pipeline_options.table_structure_options.do_cell_matching,
} }
), ),