Switch everything to use label enum, and more

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
This commit is contained in:
Christoph Auer 2024-09-24 16:00:39 +02:00
parent 867e06f9f2
commit 33373ac0dd
7 changed files with 88 additions and 70 deletions

View File

@ -6,6 +6,7 @@ from typing import Annotated, Any, Dict, List, Optional, Tuple, Union
from docling_core.types.experimental.base import BoundingBox, Size from docling_core.types.experimental.base import BoundingBox, Size
from docling_core.types.experimental.document import BaseFigureData, TableCell from docling_core.types.experimental.document import BaseFigureData, TableCell
from docling_core.types.experimental.labels import PageLabel
from PIL.Image import Image from PIL.Image import Image
from pydantic import BaseModel, ConfigDict, Field, model_validator from pydantic import BaseModel, ConfigDict, Field, model_validator
from typing_extensions import Self from typing_extensions import Self
@ -50,14 +51,14 @@ class OcrCell(Cell):
class Cluster(BaseModel): class Cluster(BaseModel):
id: int id: int
label: str label: PageLabel
bbox: BoundingBox bbox: BoundingBox
confidence: float = 1.0 confidence: float = 1.0
cells: List[Cell] = [] cells: List[Cell] = []
class BasePageElement(BaseModel): class BasePageElement(BaseModel):
label: str label: PageLabel
id: int id: int
page_no: int page_no: int
cluster: Cluster cluster: Cluster

View File

@ -12,6 +12,7 @@ from docling_core.types import Table as DsSchemaTable
from docling_core.types.doc.base import BoundingBox as DsBoundingBox from docling_core.types.doc.base import BoundingBox as DsBoundingBox
from docling_core.types.doc.base import Figure, TableCell from docling_core.types.doc.base import Figure, TableCell
from docling_core.types.experimental.document import DoclingDocument, FileInfo from docling_core.types.experimental.document import DoclingDocument, FileInfo
from docling_core.types.experimental.labels import PageLabel
from pydantic import BaseModel from pydantic import BaseModel
from typing_extensions import deprecated from typing_extensions import deprecated
@ -34,21 +35,21 @@ from docling.utils.utils import create_file_hash
_log = logging.getLogger(__name__) _log = logging.getLogger(__name__)
layout_label_to_ds_type = { layout_label_to_ds_type = {
"Title": "title", PageLabel.TITLE: "title",
"Document Index": "table-of-path_or_stream", PageLabel.DOCUMENT_INDEX: "table-of-contents",
"Section-header": "subtitle-level-1", PageLabel.SECTION_HEADER: "subtitle-level-1",
"Checkbox-Selected": "checkbox-selected", PageLabel.CHECKBOX_SELECTED: "checkbox-selected",
"Checkbox-Unselected": "checkbox-unselected", PageLabel.CHECKBOX_UNSELECTED: "checkbox-unselected",
"Caption": "caption", PageLabel.CAPTION: "caption",
"Page-header": "page-header", PageLabel.PAGE_HEADER: "page-header",
"Page-footer": "page-footer", PageLabel.PAGE_FOOTER: "page-footer",
"Footnote": "footnote", PageLabel.FOOTNOTE: "footnote",
"Table": "table", PageLabel.TABLE: "table",
"Formula": "equation", PageLabel.FORMULA: "equation",
"List-item": "paragraph", PageLabel.LIST_ITEM: "paragraph",
"Code": "paragraph", PageLabel.CODE: "paragraph",
"Picture": "figure", PageLabel.PICTURE: "figure",
"Text": "paragraph", PageLabel.TEXT: "paragraph",
} }
_EMPTY_DOC = DsDocument( _EMPTY_DOC = DsDocument(

View File

@ -5,6 +5,7 @@ import time
from typing import Iterable, List from typing import Iterable, List
from docling_core.types.experimental.base import CoordOrigin from docling_core.types.experimental.base import CoordOrigin
from docling_core.types.experimental.labels import PageLabel
from docling_ibm_models.layoutmodel.layout_predictor import LayoutPredictor from docling_ibm_models.layoutmodel.layout_predictor import LayoutPredictor
from PIL import ImageDraw from PIL import ImageDraw
@ -23,23 +24,23 @@ _log = logging.getLogger(__name__)
class LayoutModel: class LayoutModel:
TEXT_ELEM_LABELS = [ TEXT_ELEM_LABELS = [
"Text", PageLabel.TEXT,
"Footnote", PageLabel.FOOTNOTE,
"Caption", PageLabel.CAPTION,
"Checkbox-Unselected", PageLabel.CHECKBOX_UNSELECTED,
"Checkbox-Selected", PageLabel.CHECKBOX_SELECTED,
"Section-header", PageLabel.SECTION_HEADER,
"Page-header", PageLabel.PAGE_HEADER,
"Page-footer", PageLabel.PAGE_FOOTER,
"Code", PageLabel.CODE,
"List-item", PageLabel.LIST_ITEM,
# "Formula", # "Formula",
] ]
PAGE_HEADER_LABELS = ["Page-header", "Page-footer"] PAGE_HEADER_LABELS = [PageLabel.PAGE_HEADER, PageLabel.PAGE_FOOTER]
TABLE_LABEL = "Table" TABLE_LABEL = PageLabel.TABLE
FIGURE_LABEL = "Picture" FIGURE_LABEL = PageLabel.PICTURE
FORMULA_LABEL = "Formula" FORMULA_LABEL = PageLabel.FORMULA
def __init__(self, config): def __init__(self, config):
self.config = config self.config = config
@ -50,27 +51,27 @@ class LayoutModel:
def postprocess(self, clusters: List[Cluster], cells: List[Cell], page_height): def postprocess(self, clusters: List[Cluster], cells: List[Cell], page_height):
MIN_INTERSECTION = 0.2 MIN_INTERSECTION = 0.2
CLASS_THRESHOLDS = { CLASS_THRESHOLDS = {
"Caption": 0.35, PageLabel.CAPTION: 0.35,
"Footnote": 0.35, PageLabel.FOOTNOTE: 0.35,
"Formula": 0.35, PageLabel.FORMULA: 0.35,
"List-item": 0.35, PageLabel.LIST_ITEM: 0.35,
"Page-footer": 0.35, PageLabel.PAGE_FOOTER: 0.35,
"Page-header": 0.35, PageLabel.PAGE_HEADER: 0.35,
"Picture": 0.2, # low threshold adjust to capture chemical structures for examples. PageLabel.PICTURE: 0.2, # low threshold adjust to capture chemical structures for examples.
"Section-header": 0.45, PageLabel.SECTION_HEADER: 0.45,
"Table": 0.35, PageLabel.TABLE: 0.35,
"Text": 0.45, PageLabel.TEXT: 0.45,
"Title": 0.45, PageLabel.TITLE: 0.45,
"Document Index": 0.45, PageLabel.DOCUMENT_INDEX: 0.45,
"Code": 0.45, PageLabel.CODE: 0.45,
"Checkbox-Selected": 0.45, PageLabel.CHECKBOX_SELECTED: 0.45,
"Checkbox-Unselected": 0.45, PageLabel.CHECKBOX_UNSELECTED: 0.45,
"Form": 0.45, PageLabel.FORM: 0.45,
"Key-Value Region": 0.45, PageLabel.KEY_VALUE_REGION: 0.45,
} }
CLASS_REMAPPINGS = { CLASS_REMAPPINGS = {
"Document Index": "Table", PageLabel.DOCUMENT_INDEX: PageLabel.TABLE,
} }
_log.debug("================= Start postprocess function ====================") _log.debug("================= Start postprocess function ====================")
@ -257,7 +258,7 @@ class LayoutModel:
coord=c["bbox"], origin=CoordOrigin.BOTTOMLEFT coord=c["bbox"], origin=CoordOrigin.BOTTOMLEFT
).to_top_left_origin(page_height), ).to_top_left_origin(page_height),
confidence=c["confidence"], confidence=c["confidence"],
label=c["type"], label=PageLabel(c["type"]),
cells=cluster_cells, cells=cluster_cells,
) )
clusters_out_new.append(c_new) clusters_out_new.append(c_new)
@ -270,9 +271,12 @@ class LayoutModel:
for ix, pred_item in enumerate( for ix, pred_item in enumerate(
self.layout_predictor.predict(page.get_image(scale=1.0)) self.layout_predictor.predict(page.get_image(scale=1.0))
): ):
label = PageLabel(
pred_item["label"].lower().replace(" ", "_").replace("-", "_")
) # Temporary, until docling-ibm-model uses docling-core types
cluster = Cluster( cluster = Cluster(
id=ix, id=ix,
label=pred_item["label"], label=label,
confidence=pred_item["confidence"], confidence=pred_item["confidence"],
bbox=BoundingBox.model_validate(pred_item), bbox=BoundingBox.model_validate(pred_item),
cells=[], cells=[],

View File

@ -2,6 +2,7 @@ import copy
import logging import logging
import networkx as nx import networkx as nx
from docling_core.types.experimental.labels import PageLabel
logger = logging.getLogger("layout_utils") logger = logging.getLogger("layout_utils")
@ -370,7 +371,7 @@ def adapt_bboxes(raw_cells, clusters, orphan_cell_indices):
"Treating cluster " + str(ix) + ", type " + str(new_cluster["type"]) "Treating cluster " + str(ix) + ", type " + str(new_cluster["type"])
) )
logger.debug(" with cells: " + str(new_cluster["cell_ids"])) logger.debug(" with cells: " + str(new_cluster["cell_ids"]))
if len(cluster["cell_ids"]) == 0 and cluster["type"] != "Picture": if len(cluster["cell_ids"]) == 0 and cluster["type"] != PageLabel.PICTURE:
logger.debug(" Empty non-picture, removed") logger.debug(" Empty non-picture, removed")
continue ## Skip this former cluster, now without cells. continue ## Skip this former cluster, now without cells.
new_bbox = adapt_bbox(raw_cells, new_cluster, orphan_cell_indices) new_bbox = adapt_bbox(raw_cells, new_cluster, orphan_cell_indices)
@ -380,14 +381,14 @@ def adapt_bboxes(raw_cells, clusters, orphan_cell_indices):
def adapt_bbox(raw_cells, cluster, orphan_cell_indices): def adapt_bbox(raw_cells, cluster, orphan_cell_indices):
if not (cluster["type"] in ["Table", "Picture"]): if not (cluster["type"] in [PageLabel.TABLE, PageLabel.PICTURE]):
## A text-like cluster. The bbox only needs to be around the text cells: ## A text-like cluster. The bbox only needs to be around the text cells:
logger.debug(" Initial bbox: " + str(cluster["bbox"])) logger.debug(" Initial bbox: " + str(cluster["bbox"]))
new_bbox = surrounding_list( new_bbox = surrounding_list(
[raw_cells[cid]["bbox"] for cid in cluster["cell_ids"]] [raw_cells[cid]["bbox"] for cid in cluster["cell_ids"]]
) )
logger.debug(" New bounding box:" + str(new_bbox)) logger.debug(" New bounding box:" + str(new_bbox))
if cluster["type"] == "Picture": if cluster["type"] == PageLabel.PICTURE:
## We only make the bbox completely comprise included text cells: ## We only make the bbox completely comprise included text cells:
logger.debug(" Picture") logger.debug(" Picture")
if len(cluster["cell_ids"]) != 0: if len(cluster["cell_ids"]) != 0:
@ -587,7 +588,7 @@ def set_orphan_as_text(
max_id = -1 max_id = -1
figures = [] figures = []
for cluster in cluster_predictions: for cluster in cluster_predictions:
if cluster["type"] == "Picture": if cluster["type"] == PageLabel.PICTURE:
figures.append(cluster) figures.append(cluster)
if cluster["id"] > max_id: if cluster["id"] > max_id:
@ -638,13 +639,13 @@ def set_orphan_as_text(
# if fig_flag == False and raw_cells[orph_id]["text"] not in line_orphans: # if fig_flag == False and raw_cells[orph_id]["text"] not in line_orphans:
if fig_flag == False and lines_detector == False: if fig_flag == False and lines_detector == False:
# get class from low confidence detections if not set as text: # get class from low confidence detections if not set as text:
class_type = "Text" class_type = PageLabel.TEXT
for cluster in cluster_predictions_low: for cluster in cluster_predictions_low:
intersection = compute_intersection( intersection = compute_intersection(
orph_cell["bbox"], cluster["bbox"] orph_cell["bbox"], cluster["bbox"]
) )
class_type = "Text" class_type = PageLabel.TEXT
if ( if (
cluster["confidence"] > 0.1 cluster["confidence"] > 0.1
and bb_iou(cluster["bbox"], orph_cell["bbox"]) > 0.4 and bb_iou(cluster["bbox"], orph_cell["bbox"]) > 0.4
@ -718,7 +719,7 @@ def merge_cells(cluster_predictions):
if cluster["id"] == node: if cluster["id"] == node:
lines.append(cluster) lines.append(cluster)
cluster_predictions.remove(cluster) cluster_predictions.remove(cluster)
new_merged_cluster = build_cluster_from_lines(lines, "Text", max_id) new_merged_cluster = build_cluster_from_lines(lines, PageLabel.TEXT, max_id)
cluster_predictions.append(new_merged_cluster) cluster_predictions.append(new_merged_cluster)
return cluster_predictions return cluster_predictions
@ -753,9 +754,9 @@ def clean_up_clusters(
# remove clusters that might appear inside tables, or images (such as pdf cells in graphs) # remove clusters that might appear inside tables, or images (such as pdf cells in graphs)
elif img_table == True: elif img_table == True:
if ( if (
cluster_1["type"] == "Text" cluster_1["type"] == PageLabel.TEXT
and cluster_2["type"] == "Picture" and cluster_2["type"] == PageLabel.PICTURE
or cluster_2["type"] == "Table" or cluster_2["type"] == PageLabel.TABLE
): ):
if bb_iou(cluster_1["bbox"], cluster_2["bbox"]) > 0.5: if bb_iou(cluster_1["bbox"], cluster_2["bbox"]) > 0.5:
DuplicateDeletedClusterIDs.append(cluster_1["id"]) DuplicateDeletedClusterIDs.append(cluster_1["id"])
@ -771,7 +772,10 @@ def clean_up_clusters(
DuplicateDeletedClusterIDs.append(cluster_1["id"]) DuplicateDeletedClusterIDs.append(cluster_1["id"])
# remove tables that have one pdf cell # remove tables that have one pdf cell
if one_cell_table == True: if one_cell_table == True:
if cluster_1["type"] == "Table" and len(cluster_1["cell_ids"]) < 2: if (
cluster_1["type"] == PageLabel.TABLE
and len(cluster_1["cell_ids"]) < 2
):
DuplicateDeletedClusterIDs.append(cluster_1["id"]) DuplicateDeletedClusterIDs.append(cluster_1["id"])
DuplicateDeletedClusterIDs = list(set(DuplicateDeletedClusterIDs)) DuplicateDeletedClusterIDs = list(set(DuplicateDeletedClusterIDs))

View File

@ -48,6 +48,14 @@ def export_documents(
) )
) )
# Export Docling document format to doctags (experimental):
with (output_dir / f"{doc_filename}.experimental.doctags").open("w") as fp:
fp.write(conv_res.experimental.export_to_document_tokens())
# Export Docling document format to markdown (experimental):
with (output_dir / f"{doc_filename}.experimental.md").open("w") as fp:
fp.write(conv_res.experimental.export_to_markdown())
# Export Text format: # Export Text format:
with (output_dir / f"{doc_filename}.txt").open("w") as fp: with (output_dir / f"{doc_filename}.txt").open("w") as fp:
fp.write(conv_res.render_as_text()) fp.write(conv_res.render_as_text())

14
poetry.lock generated
View File

@ -862,7 +862,7 @@ files = []
develop = false develop = false
[package.dependencies] [package.dependencies]
docling-core = {git = "ssh://git@github.com/DS4SD/docling-core.git", branch = "cau/new-format-dev"} docling-core = {git = "ssh://git@github.com/DS4SD/docling-core.git", rev = "a83ff0056138d83ac2cb52bfb2ab1728ff86972f"}
docutils = "!=0.21" docutils = "!=0.21"
matplotlib = "^3.7.1" matplotlib = "^3.7.1"
networkx = "^3.1" networkx = "^3.1"
@ -882,7 +882,7 @@ toolkit = ["deepsearch-toolkit (>=0.31.0)"]
type = "git" type = "git"
url = "ssh://git@github.com/DS4SD/deepsearch-glm.git" url = "ssh://git@github.com/DS4SD/deepsearch-glm.git"
reference = "cau/new-format-dev" reference = "cau/new-format-dev"
resolved_reference = "60e4bda21fbe7ee8849d27a9321ba37cca04e7aa" resolved_reference = "c26b52e8faf789cb31fcbed816d25e775391832f"
[[package]] [[package]]
name = "deprecated" name = "deprecated"
@ -960,8 +960,8 @@ tabulate = "^0.9.0"
[package.source] [package.source]
type = "git" type = "git"
url = "ssh://git@github.com/DS4SD/docling-core.git" url = "ssh://git@github.com/DS4SD/docling-core.git"
reference = "cau/new-format-dev" reference = "a83ff0056138d83ac2cb52bfb2ab1728ff86972f"
resolved_reference = "0a1e6ce9559ffccf50c5e63c33962ac8fde35648" resolved_reference = "a83ff0056138d83ac2cb52bfb2ab1728ff86972f"
[[package]] [[package]]
name = "docling-ibm-models" name = "docling-ibm-models"
@ -1059,12 +1059,12 @@ files = [
[[package]] [[package]]
name = "easyocr" name = "easyocr"
version = "1.7.1" version = "1.7.2"
description = "End-to-End Multi-Lingual Optical Character Recognition (OCR) Solution" description = "End-to-End Multi-Lingual Optical Character Recognition (OCR) Solution"
optional = false optional = false
python-versions = "*" python-versions = "*"
files = [ files = [
{file = "easyocr-1.7.1-py3-none-any.whl", hash = "sha256:5b0a2e7cfdfc6c1ec99d9583663e570e4189dca6fbf373f074b21b8809e44d2b"}, {file = "easyocr-1.7.2-py3-none-any.whl", hash = "sha256:5be12f9b0e595d443c9c3d10b0542074b50f0ec2d98b141a109cd961fd1c177c"},
] ]
[package.dependencies] [package.dependencies]
@ -7314,4 +7314,4 @@ examples = ["langchain-huggingface", "langchain-milvus", "langchain-text-splitte
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = "^3.10" python-versions = "^3.10"
content-hash = "1b908180d822d74ae8033e8b6c650b8d00b4365fc7dd36cea6505305651b79b6" content-hash = "325aebca1bdc6e0cfeb8fc59a84102a804d750211fe8e59cd4cb15876c1ca12e"

View File

@ -23,7 +23,7 @@ packages = [{include = "docling"}]
[tool.poetry.dependencies] [tool.poetry.dependencies]
python = "^3.10" python = "^3.10"
pydantic = "^2.0.0" pydantic = "^2.0.0"
docling-core = {git = "ssh://git@github.com/DS4SD/docling-core.git", branch = "cau/new-format-dev"} docling-core = {git = "ssh://git@github.com/DS4SD/docling-core.git", rev = "a83ff0056138d83ac2cb52bfb2ab1728ff86972f"}
docling-ibm-models = "^1.2.0" docling-ibm-models = "^1.2.0"
deepsearch-glm = {git = "ssh://git@github.com/DS4SD/deepsearch-glm.git", branch = "cau/new-format-dev"} deepsearch-glm = {git = "ssh://git@github.com/DS4SD/deepsearch-glm.git", branch = "cau/new-format-dev"}