mirror of
https://github.com/DS4SD/docling.git
synced 2025-12-18 09:31:02 +00:00
feat!: Docling v2 (#117)
--------- Signed-off-by: Christoph Auer <cau@zurich.ibm.com> Signed-off-by: Maxim Lysak <mly@zurich.ibm.com> Signed-off-by: Michele Dolfi <dol@zurich.ibm.com> Signed-off-by: Panos Vagenas <35837085+vagenas@users.noreply.github.com> Co-authored-by: Maxim Lysak <mly@zurich.ibm.com> Co-authored-by: Michele Dolfi <dol@zurich.ibm.com> Co-authored-by: Panos Vagenas <35837085+vagenas@users.noreply.github.com>
This commit is contained in:
25
docling/models/base_model.py
Normal file
25
docling/models/base_model.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Iterable
|
||||
|
||||
from docling_core.types.doc import DoclingDocument, NodeItem
|
||||
|
||||
from docling.datamodel.base_models import Page
|
||||
|
||||
|
||||
class BasePageModel(ABC):
|
||||
@abstractmethod
|
||||
def __call__(self, page_batch: Iterable[Page]) -> Iterable[Page]:
|
||||
pass
|
||||
|
||||
|
||||
class BaseEnrichmentModel(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def is_processable(self, doc: DoclingDocument, element: NodeItem) -> bool:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def __call__(
|
||||
self, doc: DoclingDocument, element_batch: Iterable[NodeItem]
|
||||
) -> Iterable[Any]:
|
||||
pass
|
||||
@@ -1,14 +1,15 @@
|
||||
import copy
|
||||
import logging
|
||||
from abc import abstractmethod
|
||||
from typing import Iterable, List, Tuple
|
||||
from typing import Iterable, List
|
||||
|
||||
import numpy as np
|
||||
from docling_core.types.doc import BoundingBox, CoordOrigin
|
||||
from PIL import Image, ImageDraw
|
||||
from rtree import index
|
||||
from scipy.ndimage import find_objects, label
|
||||
|
||||
from docling.datamodel.base_models import BoundingBox, CoordOrigin, OcrCell, Page
|
||||
from docling.datamodel.base_models import OcrCell, Page
|
||||
from docling.datamodel.pipeline_options import OcrOptions
|
||||
|
||||
_log = logging.getLogger(__name__)
|
||||
@@ -20,8 +21,9 @@ class BaseOcrModel:
|
||||
self.options = options
|
||||
|
||||
# Computes the optimum amount and coordinates of rectangles to OCR on a given page
|
||||
def get_ocr_rects(self, page: Page) -> Tuple[bool, List[BoundingBox]]:
|
||||
def get_ocr_rects(self, page: Page) -> List[BoundingBox]:
|
||||
BITMAP_COVERAGE_TRESHOLD = 0.75
|
||||
assert page.size is not None
|
||||
|
||||
def find_ocr_rects(size, bitmap_rects):
|
||||
image = Image.new(
|
||||
@@ -60,7 +62,10 @@ class BaseOcrModel:
|
||||
|
||||
return (area_frac, bounding_boxes) # fraction covered # boxes
|
||||
|
||||
bitmap_rects = page._backend.get_bitmap_rects()
|
||||
if page._backend is not None:
|
||||
bitmap_rects = page._backend.get_bitmap_rects()
|
||||
else:
|
||||
bitmap_rects = []
|
||||
coverage, ocr_rects = find_ocr_rects(page.size, bitmap_rects)
|
||||
|
||||
# return full-page rectangle if sufficiently covered with bitmaps
|
||||
@@ -75,7 +80,7 @@ class BaseOcrModel:
|
||||
)
|
||||
]
|
||||
# return individual rectangles if the bitmap coverage is smaller
|
||||
elif coverage < BITMAP_COVERAGE_TRESHOLD:
|
||||
else: # coverage <= BITMAP_COVERAGE_TRESHOLD:
|
||||
return ocr_rects
|
||||
|
||||
# Filters OCR cells by dropping any OCR cell that intersects with an existing programmatic cell.
|
||||
|
||||
@@ -1,39 +1,228 @@
|
||||
import copy
|
||||
import random
|
||||
from typing import List, Union
|
||||
|
||||
from deepsearch_glm.nlp_utils import init_nlp_model
|
||||
from deepsearch_glm.utils.doc_utils import to_legacy_document_format
|
||||
from deepsearch_glm.utils.doc_utils import to_docling_document
|
||||
from deepsearch_glm.utils.load_pretrained_models import load_pretrained_nlp_models
|
||||
from docling_core.types import BaseText
|
||||
from docling_core.types import Document as DsDocument
|
||||
from docling_core.types import Ref
|
||||
from docling_core.types import DocumentDescription as DsDocumentDescription
|
||||
from docling_core.types import FileInfoObject as DsFileInfoObject
|
||||
from docling_core.types import PageDimensions, PageReference, Prov, Ref
|
||||
from docling_core.types import Table as DsSchemaTable
|
||||
from docling_core.types.doc import BoundingBox, CoordOrigin, DoclingDocument
|
||||
from docling_core.types.legacy_doc.base import BoundingBox as DsBoundingBox
|
||||
from docling_core.types.legacy_doc.base import Figure, TableCell
|
||||
from PIL import ImageDraw
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from docling.datamodel.base_models import BoundingBox, Cluster, CoordOrigin
|
||||
from docling.datamodel.document import ConversionResult
|
||||
from docling.datamodel.base_models import Cluster, FigureElement, Table, TextElement
|
||||
from docling.datamodel.document import ConversionResult, layout_label_to_ds_type
|
||||
from docling.utils.utils import create_hash
|
||||
|
||||
|
||||
class GlmOptions(BaseModel):
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
model_names: str = "" # e.g. "language;term;reference"
|
||||
|
||||
|
||||
class GlmModel:
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self.model_names = self.config.get(
|
||||
"model_names", ""
|
||||
) # "language;term;reference"
|
||||
load_pretrained_nlp_models()
|
||||
# model = init_nlp_model(model_names="language;term;reference")
|
||||
model = init_nlp_model(model_names=self.model_names)
|
||||
self.model = model
|
||||
def __init__(self, options: GlmOptions):
|
||||
self.options = options
|
||||
|
||||
def __call__(self, conv_res: ConversionResult) -> DsDocument:
|
||||
ds_doc = conv_res._to_ds_document()
|
||||
load_pretrained_nlp_models()
|
||||
self.model = init_nlp_model(model_names=self.options.model_names)
|
||||
|
||||
def _to_legacy_document(self, conv_res) -> DsDocument:
|
||||
title = ""
|
||||
desc: DsDocumentDescription = DsDocumentDescription(logs=[])
|
||||
|
||||
page_hashes = [
|
||||
PageReference(
|
||||
hash=create_hash(conv_res.input.document_hash + ":" + str(p.page_no)),
|
||||
page=p.page_no + 1,
|
||||
model="default",
|
||||
)
|
||||
for p in conv_res.pages
|
||||
]
|
||||
|
||||
file_info = DsFileInfoObject(
|
||||
filename=conv_res.input.file.name,
|
||||
document_hash=conv_res.input.document_hash,
|
||||
num_pages=conv_res.input.page_count,
|
||||
page_hashes=page_hashes,
|
||||
)
|
||||
|
||||
main_text: List[Union[Ref, BaseText]] = []
|
||||
tables: List[DsSchemaTable] = []
|
||||
figures: List[Figure] = []
|
||||
|
||||
page_no_to_page = {p.page_no: p for p in conv_res.pages}
|
||||
|
||||
for element in conv_res.assembled.elements:
|
||||
# Convert bboxes to lower-left origin.
|
||||
target_bbox = DsBoundingBox(
|
||||
element.cluster.bbox.to_bottom_left_origin(
|
||||
page_no_to_page[element.page_no].size.height
|
||||
).as_tuple()
|
||||
)
|
||||
|
||||
if isinstance(element, TextElement):
|
||||
main_text.append(
|
||||
BaseText(
|
||||
text=element.text,
|
||||
obj_type=layout_label_to_ds_type.get(element.label),
|
||||
name=element.label,
|
||||
prov=[
|
||||
Prov(
|
||||
bbox=target_bbox,
|
||||
page=element.page_no + 1,
|
||||
span=[0, len(element.text)],
|
||||
)
|
||||
],
|
||||
)
|
||||
)
|
||||
elif isinstance(element, Table):
|
||||
index = len(tables)
|
||||
ref_str = f"#/tables/{index}"
|
||||
main_text.append(
|
||||
Ref(
|
||||
name=element.label,
|
||||
obj_type=layout_label_to_ds_type.get(element.label),
|
||||
ref=ref_str,
|
||||
),
|
||||
)
|
||||
|
||||
# Initialise empty table data grid (only empty cells)
|
||||
table_data = [
|
||||
[
|
||||
TableCell(
|
||||
text="",
|
||||
# bbox=[0,0,0,0],
|
||||
spans=[[i, j]],
|
||||
obj_type="body",
|
||||
)
|
||||
for j in range(element.num_cols)
|
||||
]
|
||||
for i in range(element.num_rows)
|
||||
]
|
||||
|
||||
# Overwrite cells in table data for which there is actual cell content.
|
||||
for cell in element.table_cells:
|
||||
for i in range(
|
||||
min(cell.start_row_offset_idx, element.num_rows),
|
||||
min(cell.end_row_offset_idx, element.num_rows),
|
||||
):
|
||||
for j in range(
|
||||
min(cell.start_col_offset_idx, element.num_cols),
|
||||
min(cell.end_col_offset_idx, element.num_cols),
|
||||
):
|
||||
celltype = "body"
|
||||
if cell.column_header:
|
||||
celltype = "col_header"
|
||||
elif cell.row_header:
|
||||
celltype = "row_header"
|
||||
elif cell.row_section:
|
||||
celltype = "row_section"
|
||||
|
||||
def make_spans(cell):
|
||||
for rspan in range(
|
||||
min(cell.start_row_offset_idx, element.num_rows),
|
||||
min(cell.end_row_offset_idx, element.num_rows),
|
||||
):
|
||||
for cspan in range(
|
||||
min(
|
||||
cell.start_col_offset_idx, element.num_cols
|
||||
),
|
||||
min(cell.end_col_offset_idx, element.num_cols),
|
||||
):
|
||||
yield [rspan, cspan]
|
||||
|
||||
spans = list(make_spans(cell))
|
||||
if cell.bbox is not None:
|
||||
bbox = cell.bbox.to_bottom_left_origin(
|
||||
page_no_to_page[element.page_no].size.height
|
||||
).as_tuple()
|
||||
else:
|
||||
bbox = None
|
||||
|
||||
table_data[i][j] = TableCell(
|
||||
text=cell.text,
|
||||
bbox=bbox,
|
||||
# col=j,
|
||||
# row=i,
|
||||
spans=spans,
|
||||
obj_type=celltype,
|
||||
# col_span=[cell.start_col_offset_idx, cell.end_col_offset_idx],
|
||||
# row_span=[cell.start_row_offset_idx, cell.end_row_offset_idx]
|
||||
)
|
||||
|
||||
tables.append(
|
||||
DsSchemaTable(
|
||||
num_cols=element.num_cols,
|
||||
num_rows=element.num_rows,
|
||||
obj_type=layout_label_to_ds_type.get(element.label),
|
||||
data=table_data,
|
||||
prov=[
|
||||
Prov(
|
||||
bbox=target_bbox,
|
||||
page=element.page_no + 1,
|
||||
span=[0, 0],
|
||||
)
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
elif isinstance(element, FigureElement):
|
||||
index = len(figures)
|
||||
ref_str = f"#/figures/{index}"
|
||||
main_text.append(
|
||||
Ref(
|
||||
name=element.label,
|
||||
obj_type=layout_label_to_ds_type.get(element.label),
|
||||
ref=ref_str,
|
||||
),
|
||||
)
|
||||
figures.append(
|
||||
Figure(
|
||||
prov=[
|
||||
Prov(
|
||||
bbox=target_bbox,
|
||||
page=element.page_no + 1,
|
||||
span=[0, 0],
|
||||
)
|
||||
],
|
||||
obj_type=layout_label_to_ds_type.get(element.label),
|
||||
# data=[[]],
|
||||
)
|
||||
)
|
||||
|
||||
page_dimensions = [
|
||||
PageDimensions(page=p.page_no + 1, height=p.size.height, width=p.size.width)
|
||||
for p in conv_res.pages
|
||||
]
|
||||
|
||||
ds_doc: DsDocument = DsDocument(
|
||||
name=title,
|
||||
description=desc,
|
||||
file_info=file_info,
|
||||
main_text=main_text,
|
||||
tables=tables,
|
||||
figures=figures,
|
||||
page_dimensions=page_dimensions,
|
||||
)
|
||||
|
||||
return ds_doc
|
||||
|
||||
def __call__(self, conv_res: ConversionResult) -> DoclingDocument:
|
||||
ds_doc = self._to_legacy_document(conv_res)
|
||||
ds_doc_dict = ds_doc.model_dump(by_alias=True)
|
||||
|
||||
glm_doc = self.model.apply_on_doc(ds_doc_dict)
|
||||
ds_doc_dict = to_legacy_document_format(
|
||||
glm_doc, ds_doc_dict, update_name_label=True
|
||||
)
|
||||
|
||||
exported_doc = DsDocument.model_validate(ds_doc_dict)
|
||||
docling_doc: DoclingDocument = to_docling_document(glm_doc) # Experimental
|
||||
|
||||
# DEBUG code:
|
||||
def draw_clusters_and_cells(ds_document, page_no):
|
||||
@@ -48,7 +237,7 @@ class GlmModel:
|
||||
if arr == "tables":
|
||||
prov = ds_document.tables[index].prov[0]
|
||||
elif arr == "figures":
|
||||
prov = ds_document.figures[index].prov[0]
|
||||
prov = ds_document.pictures[index].prov[0]
|
||||
else:
|
||||
prov = None
|
||||
|
||||
@@ -83,4 +272,4 @@ class GlmModel:
|
||||
# draw_clusters_and_cells(ds_doc, 0)
|
||||
# draw_clusters_and_cells(exported_doc, 0)
|
||||
|
||||
return exported_doc
|
||||
return docling_doc
|
||||
|
||||
@@ -2,8 +2,9 @@ import logging
|
||||
from typing import Iterable
|
||||
|
||||
import numpy
|
||||
from docling_core.types.doc import BoundingBox, CoordOrigin
|
||||
|
||||
from docling.datamodel.base_models import BoundingBox, CoordOrigin, OcrCell, Page
|
||||
from docling.datamodel.base_models import OcrCell, Page
|
||||
from docling.datamodel.pipeline_options import EasyOcrOptions
|
||||
from docling.models.base_ocr_model import BaseOcrModel
|
||||
|
||||
@@ -39,6 +40,8 @@ class EasyOcrModel(BaseOcrModel):
|
||||
return
|
||||
|
||||
for page in page_batch:
|
||||
assert page._backend is not None
|
||||
|
||||
ocr_rects = self.get_ocr_rects(page)
|
||||
|
||||
all_ocr_cells = []
|
||||
|
||||
@@ -2,8 +2,10 @@ import copy
|
||||
import logging
|
||||
import random
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Iterable, List
|
||||
|
||||
from docling_core.types.doc import CoordOrigin, DocItemLabel
|
||||
from docling_ibm_models.layoutmodel.layout_predictor import LayoutPredictor
|
||||
from PIL import ImageDraw
|
||||
|
||||
@@ -11,74 +13,73 @@ from docling.datamodel.base_models import (
|
||||
BoundingBox,
|
||||
Cell,
|
||||
Cluster,
|
||||
CoordOrigin,
|
||||
LayoutPrediction,
|
||||
Page,
|
||||
)
|
||||
from docling.models.base_model import BasePageModel
|
||||
from docling.utils import layout_utils as lu
|
||||
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LayoutModel:
|
||||
class LayoutModel(BasePageModel):
|
||||
|
||||
TEXT_ELEM_LABELS = [
|
||||
"Text",
|
||||
"Footnote",
|
||||
"Caption",
|
||||
"Checkbox-Unselected",
|
||||
"Checkbox-Selected",
|
||||
"Section-header",
|
||||
"Page-header",
|
||||
"Page-footer",
|
||||
"Code",
|
||||
"List-item",
|
||||
# "Title"
|
||||
DocItemLabel.TEXT,
|
||||
DocItemLabel.FOOTNOTE,
|
||||
DocItemLabel.CAPTION,
|
||||
DocItemLabel.CHECKBOX_UNSELECTED,
|
||||
DocItemLabel.CHECKBOX_SELECTED,
|
||||
DocItemLabel.SECTION_HEADER,
|
||||
DocItemLabel.PAGE_HEADER,
|
||||
DocItemLabel.PAGE_FOOTER,
|
||||
DocItemLabel.CODE,
|
||||
DocItemLabel.LIST_ITEM,
|
||||
# "Formula",
|
||||
]
|
||||
PAGE_HEADER_LABELS = ["Page-header", "Page-footer"]
|
||||
PAGE_HEADER_LABELS = [DocItemLabel.PAGE_HEADER, DocItemLabel.PAGE_FOOTER]
|
||||
|
||||
TABLE_LABEL = "Table"
|
||||
FIGURE_LABEL = "Picture"
|
||||
FORMULA_LABEL = "Formula"
|
||||
TABLE_LABEL = DocItemLabel.TABLE
|
||||
FIGURE_LABEL = DocItemLabel.PICTURE
|
||||
FORMULA_LABEL = DocItemLabel.FORMULA
|
||||
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self.layout_predictor = LayoutPredictor(
|
||||
config["artifacts_path"]
|
||||
) # TODO temporary
|
||||
def __init__(self, artifacts_path: Path):
|
||||
self.layout_predictor = LayoutPredictor(artifacts_path) # TODO temporary
|
||||
|
||||
def postprocess(self, clusters: List[Cluster], cells: List[Cell], page_height):
|
||||
def postprocess(self, clusters_in: List[Cluster], cells: List[Cell], page_height):
|
||||
MIN_INTERSECTION = 0.2
|
||||
CLASS_THRESHOLDS = {
|
||||
"Caption": 0.35,
|
||||
"Footnote": 0.35,
|
||||
"Formula": 0.35,
|
||||
"List-item": 0.35,
|
||||
"Page-footer": 0.35,
|
||||
"Page-header": 0.35,
|
||||
"Picture": 0.2, # low threshold adjust to capture chemical structures for examples.
|
||||
"Section-header": 0.45,
|
||||
"Table": 0.35,
|
||||
"Text": 0.45,
|
||||
"Title": 0.45,
|
||||
"Document Index": 0.45,
|
||||
"Code": 0.45,
|
||||
"Checkbox-Selected": 0.45,
|
||||
"Checkbox-Unselected": 0.45,
|
||||
"Form": 0.45,
|
||||
"Key-Value Region": 0.45,
|
||||
DocItemLabel.CAPTION: 0.35,
|
||||
DocItemLabel.FOOTNOTE: 0.35,
|
||||
DocItemLabel.FORMULA: 0.35,
|
||||
DocItemLabel.LIST_ITEM: 0.35,
|
||||
DocItemLabel.PAGE_FOOTER: 0.35,
|
||||
DocItemLabel.PAGE_HEADER: 0.35,
|
||||
DocItemLabel.PICTURE: 0.2, # low threshold adjust to capture chemical structures for examples.
|
||||
DocItemLabel.SECTION_HEADER: 0.45,
|
||||
DocItemLabel.TABLE: 0.35,
|
||||
DocItemLabel.TEXT: 0.45,
|
||||
DocItemLabel.TITLE: 0.45,
|
||||
DocItemLabel.DOCUMENT_INDEX: 0.45,
|
||||
DocItemLabel.CODE: 0.45,
|
||||
DocItemLabel.CHECKBOX_SELECTED: 0.45,
|
||||
DocItemLabel.CHECKBOX_UNSELECTED: 0.45,
|
||||
DocItemLabel.FORM: 0.45,
|
||||
DocItemLabel.KEY_VALUE_REGION: 0.45,
|
||||
}
|
||||
|
||||
CLASS_REMAPPINGS = {"Document Index": "Table", "Title": "Section-header"}
|
||||
CLASS_REMAPPINGS = {
|
||||
DocItemLabel.DOCUMENT_INDEX: DocItemLabel.TABLE,
|
||||
DocItemLabel.TITLE: DocItemLabel.SECTION_HEADER,
|
||||
}
|
||||
|
||||
_log.debug("================= Start postprocess function ====================")
|
||||
start_time = time.time()
|
||||
# Apply Confidence Threshold to cluster predictions
|
||||
# confidence = self.conf_threshold
|
||||
clusters_out = []
|
||||
clusters_mod = []
|
||||
|
||||
for cluster in clusters:
|
||||
for cluster in clusters_in:
|
||||
confidence = CLASS_THRESHOLDS[cluster.label]
|
||||
if cluster.confidence >= confidence:
|
||||
# annotation["created_by"] = "high_conf_pred"
|
||||
@@ -86,10 +87,10 @@ class LayoutModel:
|
||||
# Remap class labels where needed.
|
||||
if cluster.label in CLASS_REMAPPINGS.keys():
|
||||
cluster.label = CLASS_REMAPPINGS[cluster.label]
|
||||
clusters_out.append(cluster)
|
||||
clusters_mod.append(cluster)
|
||||
|
||||
# map to dictionary clusters and cells, with bottom left origin
|
||||
clusters = [
|
||||
clusters_orig = [
|
||||
{
|
||||
"id": c.id,
|
||||
"bbox": list(
|
||||
@@ -99,7 +100,7 @@ class LayoutModel:
|
||||
"cell_ids": [],
|
||||
"type": c.label,
|
||||
}
|
||||
for c in clusters
|
||||
for c in clusters_in
|
||||
]
|
||||
|
||||
clusters_out = [
|
||||
@@ -113,9 +114,11 @@ class LayoutModel:
|
||||
"cell_ids": [],
|
||||
"type": c.label,
|
||||
}
|
||||
for c in clusters_out
|
||||
for c in clusters_mod
|
||||
]
|
||||
|
||||
del clusters_mod
|
||||
|
||||
raw_cells = [
|
||||
{
|
||||
"id": c.id,
|
||||
@@ -149,7 +152,7 @@ class LayoutModel:
|
||||
|
||||
# Assign orphan cells with lower confidence predictions
|
||||
clusters_out, orphan_cell_indices = lu.assign_orphans_with_low_conf_pred(
|
||||
clusters_out, clusters, raw_cells, orphan_cell_indices
|
||||
clusters_out, clusters_orig, raw_cells, orphan_cell_indices
|
||||
)
|
||||
|
||||
# Refresh the cell_ids assignment, after creating new clusters using low conf predictions
|
||||
@@ -178,7 +181,7 @@ class LayoutModel:
|
||||
) = lu.cell_id_state_map(clusters_out, cell_count)
|
||||
|
||||
clusters_out, orphan_cell_indices = lu.set_orphan_as_text(
|
||||
clusters_out, clusters, raw_cells, orphan_cell_indices
|
||||
clusters_out, clusters_orig, raw_cells, orphan_cell_indices
|
||||
)
|
||||
|
||||
_log.debug("---- 5. Merge Cells & and adapt the bounding boxes")
|
||||
@@ -237,46 +240,55 @@ class LayoutModel:
|
||||
end_time = time.time() - start_time
|
||||
_log.debug(f"Finished post processing in seconds={end_time:.3f}")
|
||||
|
||||
cells_out = [
|
||||
cells_out_new = [
|
||||
Cell(
|
||||
id=c["id"],
|
||||
id=c["id"], # type: ignore
|
||||
bbox=BoundingBox.from_tuple(
|
||||
coord=c["bbox"], origin=CoordOrigin.BOTTOMLEFT
|
||||
coord=c["bbox"], origin=CoordOrigin.BOTTOMLEFT # type: ignore
|
||||
).to_top_left_origin(page_height),
|
||||
text=c["text"],
|
||||
text=c["text"], # type: ignore
|
||||
)
|
||||
for c in cells_out
|
||||
]
|
||||
|
||||
del cells_out
|
||||
|
||||
clusters_out_new = []
|
||||
for c in clusters_out:
|
||||
cluster_cells = [ccell for ccell in cells_out if ccell.id in c["cell_ids"]]
|
||||
cluster_cells = [
|
||||
ccell for ccell in cells_out_new if ccell.id in c["cell_ids"] # type: ignore
|
||||
]
|
||||
c_new = Cluster(
|
||||
id=c["id"],
|
||||
id=c["id"], # type: ignore
|
||||
bbox=BoundingBox.from_tuple(
|
||||
coord=c["bbox"], origin=CoordOrigin.BOTTOMLEFT
|
||||
coord=c["bbox"], origin=CoordOrigin.BOTTOMLEFT # type: ignore
|
||||
).to_top_left_origin(page_height),
|
||||
confidence=c["confidence"],
|
||||
label=c["type"],
|
||||
confidence=c["confidence"], # type: ignore
|
||||
label=DocItemLabel(c["type"]),
|
||||
cells=cluster_cells,
|
||||
)
|
||||
clusters_out_new.append(c_new)
|
||||
|
||||
return clusters_out_new, cells_out
|
||||
return clusters_out_new, cells_out_new
|
||||
|
||||
def __call__(self, page_batch: Iterable[Page]) -> Iterable[Page]:
|
||||
for page in page_batch:
|
||||
assert page.size is not None
|
||||
|
||||
clusters = []
|
||||
for ix, pred_item in enumerate(
|
||||
self.layout_predictor.predict(page.get_image(scale=1.0))
|
||||
):
|
||||
label = DocItemLabel(
|
||||
pred_item["label"].lower().replace(" ", "_").replace("-", "_")
|
||||
) # Temporary, until docling-ibm-model uses docling-core types
|
||||
cluster = Cluster(
|
||||
id=ix,
|
||||
label=pred_item["label"],
|
||||
label=label,
|
||||
confidence=pred_item["confidence"],
|
||||
bbox=BoundingBox.model_validate(pred_item),
|
||||
cells=[],
|
||||
)
|
||||
|
||||
clusters.append(cluster)
|
||||
|
||||
# Map cells to clusters
|
||||
|
||||
@@ -2,22 +2,29 @@ import logging
|
||||
import re
|
||||
from typing import Iterable, List
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from docling.datamodel.base_models import (
|
||||
AssembledUnit,
|
||||
FigureElement,
|
||||
Page,
|
||||
PageElement,
|
||||
TableElement,
|
||||
Table,
|
||||
TextElement,
|
||||
)
|
||||
from docling.models.base_model import BasePageModel
|
||||
from docling.models.layout_model import LayoutModel
|
||||
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PageAssembleModel:
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
class PageAssembleOptions(BaseModel):
|
||||
keep_images: bool = False
|
||||
|
||||
|
||||
class PageAssembleModel(BasePageModel):
|
||||
def __init__(self, options: PageAssembleOptions):
|
||||
self.options = options
|
||||
|
||||
def sanitize_text(self, lines):
|
||||
if len(lines) <= 1:
|
||||
@@ -46,6 +53,8 @@ class PageAssembleModel:
|
||||
|
||||
def __call__(self, page_batch: Iterable[Page]) -> Iterable[Page]:
|
||||
for page in page_batch:
|
||||
assert page._backend is not None
|
||||
assert page.predictions.layout is not None
|
||||
# assembles some JSON output page by page.
|
||||
|
||||
elements: List[PageElement] = []
|
||||
@@ -84,7 +93,7 @@ class PageAssembleModel:
|
||||
if (
|
||||
not tbl
|
||||
): # fallback: add table without structure, if it isn't present
|
||||
tbl = TableElement(
|
||||
tbl = Table(
|
||||
label=cluster.label,
|
||||
id=cluster.id,
|
||||
text="",
|
||||
@@ -145,4 +154,11 @@ class PageAssembleModel:
|
||||
elements=elements, headers=headers, body=body
|
||||
)
|
||||
|
||||
# Remove page images (can be disabled)
|
||||
if not self.options.keep_images:
|
||||
page._image_cache = {}
|
||||
|
||||
# Unload backend
|
||||
page._backend.unload()
|
||||
|
||||
yield page
|
||||
|
||||
57
docling/models/page_preprocessing_model.py
Normal file
57
docling/models/page_preprocessing_model.py
Normal file
@@ -0,0 +1,57 @@
|
||||
from typing import Iterable, Optional
|
||||
|
||||
from PIL import ImageDraw
|
||||
from pydantic import BaseModel
|
||||
|
||||
from docling.datamodel.base_models import Page
|
||||
from docling.models.base_model import BasePageModel
|
||||
|
||||
|
||||
class PagePreprocessingOptions(BaseModel):
|
||||
images_scale: Optional[float]
|
||||
|
||||
|
||||
class PagePreprocessingModel(BasePageModel):
|
||||
def __init__(self, options: PagePreprocessingOptions):
|
||||
self.options = options
|
||||
|
||||
def __call__(self, page_batch: Iterable[Page]) -> Iterable[Page]:
|
||||
for page in page_batch:
|
||||
page = self._populate_page_images(page)
|
||||
page = self._parse_page_cells(page)
|
||||
yield page
|
||||
|
||||
# Generate the page image and store it in the page object
|
||||
def _populate_page_images(self, page: Page) -> Page:
|
||||
# default scale
|
||||
page.get_image(
|
||||
scale=1.0
|
||||
) # puts the page image on the image cache at default scale
|
||||
|
||||
images_scale = self.options.images_scale
|
||||
# user requested scales
|
||||
if images_scale is not None:
|
||||
page._default_image_scale = images_scale
|
||||
page.get_image(
|
||||
scale=images_scale
|
||||
) # this will trigger storing the image in the internal cache
|
||||
|
||||
return page
|
||||
|
||||
# Extract and populate the page cells and store it in the page object
|
||||
def _parse_page_cells(self, page: Page) -> Page:
|
||||
assert page._backend is not None
|
||||
|
||||
page.cells = list(page._backend.get_text_cells())
|
||||
|
||||
# DEBUG code:
|
||||
def draw_text_boxes(image, cells):
|
||||
draw = ImageDraw.Draw(image)
|
||||
for c in cells:
|
||||
x0, y0, x1, y1 = c.bbox.as_tuple()
|
||||
draw.rectangle([(x0, y0), (x1, y1)], outline="red")
|
||||
image.show()
|
||||
|
||||
# draw_text_boxes(page.get_image(scale=1.0), cells)
|
||||
|
||||
return page
|
||||
@@ -3,29 +3,25 @@ from pathlib import Path
|
||||
from typing import Iterable, List
|
||||
|
||||
import numpy
|
||||
from docling_core.types.doc import BoundingBox, DocItemLabel, TableCell
|
||||
from docling_ibm_models.tableformer.data_management.tf_predictor import TFPredictor
|
||||
from PIL import ImageDraw
|
||||
|
||||
from docling.datamodel.base_models import (
|
||||
BoundingBox,
|
||||
Page,
|
||||
TableCell,
|
||||
TableElement,
|
||||
TableStructurePrediction,
|
||||
)
|
||||
from docling.datamodel.pipeline_options import TableFormerMode
|
||||
from docling.datamodel.base_models import Page, Table, TableStructurePrediction
|
||||
from docling.datamodel.pipeline_options import TableFormerMode, TableStructureOptions
|
||||
from docling.models.base_model import BasePageModel
|
||||
|
||||
|
||||
class TableStructureModel:
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self.do_cell_matching = config["do_cell_matching"]
|
||||
self.mode = config["mode"]
|
||||
class TableStructureModel(BasePageModel):
|
||||
def __init__(
|
||||
self, enabled: bool, artifacts_path: Path, options: TableStructureOptions
|
||||
):
|
||||
self.options = options
|
||||
self.do_cell_matching = self.options.do_cell_matching
|
||||
self.mode = self.options.mode
|
||||
|
||||
self.enabled = config["enabled"]
|
||||
self.enabled = enabled
|
||||
if self.enabled:
|
||||
artifacts_path: Path = config["artifacts_path"]
|
||||
|
||||
if self.mode == TableFormerMode.ACCURATE:
|
||||
artifacts_path = artifacts_path / "fat"
|
||||
|
||||
@@ -39,7 +35,9 @@ class TableStructureModel:
|
||||
self.tf_predictor = TFPredictor(self.tm_config)
|
||||
self.scale = 2.0 # Scale up table input images to 144 dpi
|
||||
|
||||
def draw_table_and_cells(self, page: Page, tbl_list: List[TableElement]):
|
||||
def draw_table_and_cells(self, page: Page, tbl_list: List[Table]):
|
||||
assert page._backend is not None
|
||||
|
||||
image = (
|
||||
page._backend.get_page_image()
|
||||
) # make new image to avoid drawing on the saved ones
|
||||
@@ -50,17 +48,18 @@ class TableStructureModel:
|
||||
draw.rectangle([(x0, y0), (x1, y1)], outline="red")
|
||||
|
||||
for tc in table_element.table_cells:
|
||||
x0, y0, x1, y1 = tc.bbox.as_tuple()
|
||||
if tc.column_header:
|
||||
width = 3
|
||||
else:
|
||||
width = 1
|
||||
draw.rectangle([(x0, y0), (x1, y1)], outline="blue", width=width)
|
||||
draw.text(
|
||||
(x0 + 3, y0 + 3),
|
||||
text=f"{tc.start_row_offset_idx}, {tc.start_col_offset_idx}",
|
||||
fill="black",
|
||||
)
|
||||
if tc.bbox is not None:
|
||||
x0, y0, x1, y1 = tc.bbox.as_tuple()
|
||||
if tc.column_header:
|
||||
width = 3
|
||||
else:
|
||||
width = 1
|
||||
draw.rectangle([(x0, y0), (x1, y1)], outline="blue", width=width)
|
||||
draw.text(
|
||||
(x0 + 3, y0 + 3),
|
||||
text=f"{tc.start_row_offset_idx}, {tc.start_col_offset_idx}",
|
||||
fill="black",
|
||||
)
|
||||
|
||||
image.show()
|
||||
|
||||
@@ -71,6 +70,9 @@ class TableStructureModel:
|
||||
return
|
||||
|
||||
for page in page_batch:
|
||||
assert page._backend is not None
|
||||
assert page.predictions.layout is not None
|
||||
assert page.size is not None
|
||||
|
||||
page.predictions.tablestructure = TableStructurePrediction() # dummy
|
||||
|
||||
@@ -85,7 +87,7 @@ class TableStructureModel:
|
||||
],
|
||||
)
|
||||
for cluster in page.predictions.layout.clusters
|
||||
if cluster.label == "Table"
|
||||
if cluster.label == DocItemLabel.TABLE
|
||||
]
|
||||
if not len(in_tables):
|
||||
yield page
|
||||
@@ -132,7 +134,7 @@ class TableStructureModel:
|
||||
element["bbox"]["token"] = text_piece
|
||||
|
||||
tc = TableCell.model_validate(element)
|
||||
if self.do_cell_matching:
|
||||
if self.do_cell_matching and tc.bbox is not None:
|
||||
tc.bbox = tc.bbox.scaled(1 / self.scale)
|
||||
table_cells.append(tc)
|
||||
|
||||
@@ -141,7 +143,7 @@ class TableStructureModel:
|
||||
num_cols = table_out["predict_details"]["num_cols"]
|
||||
otsl_seq = table_out["predict_details"]["prediction"]["rs_seq"]
|
||||
|
||||
tbl = TableElement(
|
||||
tbl = Table(
|
||||
otsl_seq=otsl_seq,
|
||||
table_cells=table_cells,
|
||||
num_rows=num_rows,
|
||||
@@ -149,7 +151,7 @@ class TableStructureModel:
|
||||
id=table_cluster.id,
|
||||
page_no=page.page_no,
|
||||
cluster=table_cluster,
|
||||
label="Table",
|
||||
label=DocItemLabel.TABLE,
|
||||
)
|
||||
|
||||
page.predictions.tablestructure.table_map[table_cluster.id] = tbl
|
||||
|
||||
@@ -2,11 +2,12 @@ import io
|
||||
import logging
|
||||
import tempfile
|
||||
from subprocess import DEVNULL, PIPE, Popen
|
||||
from typing import Iterable, Tuple
|
||||
from typing import Iterable, Optional, Tuple
|
||||
|
||||
import pandas as pd
|
||||
from docling_core.types.doc import BoundingBox, CoordOrigin
|
||||
|
||||
from docling.datamodel.base_models import BoundingBox, CoordOrigin, OcrCell, Page
|
||||
from docling.datamodel.base_models import OcrCell, Page
|
||||
from docling.datamodel.pipeline_options import TesseractCliOcrOptions
|
||||
from docling.models.base_ocr_model import BaseOcrModel
|
||||
|
||||
@@ -21,8 +22,8 @@ class TesseractOcrCliModel(BaseOcrModel):
|
||||
|
||||
self.scale = 3 # multiplier for 72 dpi == 216 dpi.
|
||||
|
||||
self._name = None
|
||||
self._version = None
|
||||
self._name: Optional[str] = None
|
||||
self._version: Optional[str] = None
|
||||
|
||||
if self.enabled:
|
||||
try:
|
||||
@@ -39,7 +40,7 @@ class TesseractOcrCliModel(BaseOcrModel):
|
||||
def _get_name_and_version(self) -> Tuple[str, str]:
|
||||
|
||||
if self._name != None and self._version != None:
|
||||
return self._name, self._version
|
||||
return self._name, self._version # type: ignore
|
||||
|
||||
cmd = [self.options.tesseract_cmd, "--version"]
|
||||
|
||||
@@ -108,6 +109,8 @@ class TesseractOcrCliModel(BaseOcrModel):
|
||||
return
|
||||
|
||||
for page in page_batch:
|
||||
assert page._backend is not None
|
||||
|
||||
ocr_rects = self.get_ocr_rects(page)
|
||||
|
||||
all_ocr_cells = []
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import logging
|
||||
from typing import Iterable
|
||||
|
||||
import numpy
|
||||
from docling_core.types.doc import BoundingBox, CoordOrigin
|
||||
|
||||
from docling.datamodel.base_models import BoundingBox, CoordOrigin, OcrCell, Page
|
||||
from docling.datamodel.base_models import OcrCell, Page
|
||||
from docling.datamodel.pipeline_options import TesseractOcrOptions
|
||||
from docling.models.base_ocr_model import BaseOcrModel
|
||||
|
||||
@@ -68,6 +68,9 @@ class TesseractOcrModel(BaseOcrModel):
|
||||
return
|
||||
|
||||
for page in page_batch:
|
||||
assert page._backend is not None
|
||||
assert self.reader is not None
|
||||
|
||||
ocr_rects = self.get_ocr_rects(page)
|
||||
|
||||
all_ocr_cells = []
|
||||
|
||||
Reference in New Issue
Block a user