Mirror of https://github.com/DS4SD/docling.git (synced 2025-12-16 16:48:21 +00:00)
feat: Updated Layout processing with forms and key-value areas (#530)
* Upgraded layout postprocessing, sending old code back to ERZ
* Implement hierarchical cluster layout processing
* Pass nested cluster processing through full pipeline
* Pass nested clusters through GLM as payload
* Move to_docling_document from ds-glm to this repo
* Clean up imports again
* feat(Accelerator): Introduce options to control the num_threads and device from API, env vars, CLI.
  - Introduce AcceleratorOptions and AcceleratorDevice and use them to set the device where the models run.
  - Introduce accelerator_utils with a function to decide the device and resolve the AUTO setting.
  - Refactor how the docling-ibm-models are called to match the new init signature of the models.
  - Translate the accelerator options to the specific inputs for third-party models.
  - Extend the docling CLI with parameters to set the num_threads and device.
  - Add new unit tests.
  - Write a new example of how to use the accelerator options.
* fix: Improve the pydantic objects in the pipeline_options and imports.
* fix: TableStructureModel: Refactor the artifacts path to use the new structure for the fast/accurate model
* Updated test ground-truth
* Updated test ground-truth (again), bugfix for empty layout
* fix: Do a proper check to set the device in EasyOCR, RapidOCR.
* fix: Correct the way to set the GPU for EasyOCR, RapidOCR
* fix: OCR AcceleratorDevice
* Merge pull request #556 from DS4SD/cau/layout-processing-improvement (feat: layout processing improvements and bugfixes)
* Update lockfile
* Update tests
* Update HF model ref, reset test generate
* Repin to release package versions
* Many layout processing improvements, add document index type
* Update pinnings to docling-core
* Update test GT
* Fix table box snapping
* Fixes for cluster pre-ordering
* Introduce OCR confidence, propagate to orphans in post-processing
* Fix form and key-value area groups
* Adjust confidence in EasyOCR
* Roll back CLI changes from main
* Update test GT
* Update docling-core pinning
* Annoying fixes for historical Python versions
* Updated test GT for legacy
* Comment cleanup

---------

Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
Signed-off-by: Nikos Livathinos <nli@zurich.ibm.com>
Co-authored-by: Nikos Livathinos <nli@zurich.ibm.com>
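For reference, a minimal sketch of how the accelerator options introduced in this PR can be wired into a conversion. The import path for AcceleratorOptions/AcceleratorDevice follows the diff below; the pipeline and converter names are the usual docling API, but treat the exact field names and defaults as assumptions for your docling version, and "example.pdf" is a placeholder input.

from docling.datamodel.base_models import InputFormat
from docling.datamodel.pipeline_options import (
    AcceleratorDevice,
    AcceleratorOptions,
    PdfPipelineOptions,
)
from docling.document_converter import DocumentConverter, PdfFormatOption

# Pick the device explicitly, or keep AUTO and let docling resolve it.
accelerator_options = AcceleratorOptions(num_threads=8, device=AcceleratorDevice.AUTO)

pipeline_options = PdfPipelineOptions()
pipeline_options.accelerator_options = accelerator_options

converter = DocumentConverter(
    format_options={InputFormat.PDF: PdfFormatOption(pipeline_options=pipeline_options)}
)
# result = converter.convert("example.pdf")  # placeholder document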
@@ -22,9 +22,15 @@ from docling_core.types.legacy_doc.document import (
from docling_core.types.legacy_doc.document import CCSFileInfoObject as DsFileInfoObject
from docling_core.types.legacy_doc.document import ExportedCCSDocument as DsDocument
from PIL import ImageDraw
from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel, ConfigDict, TypeAdapter

from docling.datamodel.base_models import Cluster, FigureElement, Table, TextElement
from docling.datamodel.base_models import (
    Cluster,
    ContainerElement,
    FigureElement,
    Table,
    TextElement,
)
from docling.datamodel.document import ConversionResult, layout_label_to_ds_type
from docling.datamodel.settings import settings
from docling.utils.glm_utils import to_docling_document
@@ -204,7 +210,31 @@ class GlmModel:
                        )
                    ],
                    obj_type=layout_label_to_ds_type.get(element.label),
                    # data=[[]],
                    payload={
                        "children": TypeAdapter(List[Cluster]).dump_python(
                            element.cluster.children
                        )
                    }, # hack to channel child clusters through GLM
                )
            )
        elif isinstance(element, ContainerElement):
            main_text.append(
                BaseText(
                    text="",
                    payload={
                        "children": TypeAdapter(List[Cluster]).dump_python(
                            element.cluster.children
                        )
                    }, # hack to channel child clusters through GLM
                    obj_type=layout_label_to_ds_type.get(element.label),
                    name=element.label,
                    prov=[
                        Prov(
                            bbox=target_bbox,
                            page=element.page_no + 1,
                            span=[0, 0],
                        )
                    ],
                )
            )
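The payload field above smuggles the nested child clusters through the legacy GLM document as plain Python objects. A minimal sketch of the pydantic round trip this relies on; the Cluster import and payload shape come from the diff, while the validate step and the element variable are illustrative:

from typing import List
from pydantic import TypeAdapter
from docling.datamodel.base_models import Cluster

adapter = TypeAdapter(List[Cluster])

# Serialize child clusters into plain data for the payload
# (element stands for the loop variable in the surrounding code).
payload = {"children": adapter.dump_python(element.cluster.children)}

# Rebuild typed Cluster objects later, when converting back to a DoclingDocument.
children: List[Cluster] = adapter.validate_python(payload["children"])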
@@ -118,6 +118,7 @@ class EasyOcrModel(BaseOcrModel):
                        ),
                    )
                    for ix, line in enumerate(result)
                    if line[2] >= self.options.confidence_threshold
                ]
                all_ocr_cells.extend(cells)
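For context, EasyOCR's readtext returns (bbox, text, confidence) triples, so line[2] above is the recognition confidence. A rough, self-contained sketch of that filtering step, with result and confidence_threshold assumed to match the surrounding code:

# result: output of easyocr.Reader.readtext(), a list of (bbox, text, confidence) entries.
confidence_threshold = 0.5  # assumed value; the real one comes from the OCR options

kept = [
    (bbox, text, conf)
    for (bbox, text, conf) in result
    if conf >= confidence_threshold
]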
@@ -7,9 +7,8 @@ from typing import Iterable, List

from docling_core.types.doc import CoordOrigin, DocItemLabel
from docling_ibm_models.layoutmodel.layout_predictor import LayoutPredictor
from PIL import ImageDraw
from PIL import Image, ImageDraw, ImageFont

import docling.utils.layout_utils as lu
from docling.datamodel.base_models import (
    BoundingBox,
    Cell,
@@ -22,6 +21,7 @@ from docling.datamodel.pipeline_options import AcceleratorDevice, AcceleratorOpt
from docling.datamodel.settings import settings
from docling.models.base_model import BasePageModel
from docling.utils.accelerator_utils import decide_device
from docling.utils.layout_postprocessor import LayoutPostprocessor
from docling.utils.profiling import TimeRecorder

_log = logging.getLogger(__name__)
@@ -44,9 +44,10 @@ class LayoutModel(BasePageModel):
    ]
    PAGE_HEADER_LABELS = [DocItemLabel.PAGE_HEADER, DocItemLabel.PAGE_FOOTER]

    TABLE_LABEL = DocItemLabel.TABLE
    TABLE_LABELS = [DocItemLabel.TABLE, DocItemLabel.DOCUMENT_INDEX]
    FIGURE_LABEL = DocItemLabel.PICTURE
    FORMULA_LABEL = DocItemLabel.FORMULA
    CONTAINER_LABELS = [DocItemLabel.FORM, DocItemLabel.KEY_VALUE_REGION]

    def __init__(self, artifacts_path: Path, accelerator_options: AcceleratorOptions):
        device = decide_device(accelerator_options.device)
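The decide_device helper used here resolves the AUTO setting to a concrete device before the layout predictor is created. Its real implementation lives in docling.utils.accelerator_utils; the sketch below only illustrates the idea and is not the actual code:

import torch

def decide_device_sketch(requested: str) -> str:
    # Honor an explicit choice; otherwise pick the best available backend.
    if requested != "auto":
        return requested
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"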
@@ -55,234 +56,127 @@ class LayoutModel(BasePageModel):
            artifact_path=str(artifacts_path),
            device=device,
            num_threads=accelerator_options.num_threads,
            base_threshold=0.6,
            blacklist_classes={"Form", "Key-Value Region"},
        )

    def postprocess(self, clusters_in: List[Cluster], cells: List[Cell], page_height):
        MIN_INTERSECTION = 0.2
        CLASS_THRESHOLDS = {
            DocItemLabel.CAPTION: 0.35,
            DocItemLabel.FOOTNOTE: 0.35,
            DocItemLabel.FORMULA: 0.35,
            DocItemLabel.LIST_ITEM: 0.35,
            DocItemLabel.PAGE_FOOTER: 0.35,
            DocItemLabel.PAGE_HEADER: 0.35,
            DocItemLabel.PICTURE: 0.2, # low threshold adjust to capture chemical structures for examples.
            DocItemLabel.SECTION_HEADER: 0.45,
            DocItemLabel.TABLE: 0.35,
            DocItemLabel.TEXT: 0.45,
            DocItemLabel.TITLE: 0.45,
            DocItemLabel.DOCUMENT_INDEX: 0.45,
            DocItemLabel.CODE: 0.45,
            DocItemLabel.CHECKBOX_SELECTED: 0.45,
            DocItemLabel.CHECKBOX_UNSELECTED: 0.45,
            DocItemLabel.FORM: 0.45,
            DocItemLabel.KEY_VALUE_REGION: 0.45,
    def draw_clusters_and_cells_side_by_side(
        self, conv_res, page, clusters, mode_prefix: str, show: bool = False
    ):
        """
        Draws a page image side by side with clusters filtered into two categories:
        - Left: Clusters excluding FORM, KEY_VALUE_REGION, and PICTURE.
        - Right: Clusters including FORM, KEY_VALUE_REGION, and PICTURE.
        Includes label names and confidence scores for each cluster.
        """
        label_to_color = {
            DocItemLabel.TEXT: (255, 255, 153), # Light Yellow
            DocItemLabel.CAPTION: (255, 204, 153), # Light Orange
            DocItemLabel.LIST_ITEM: (153, 153, 255), # Light Purple
            DocItemLabel.FORMULA: (192, 192, 192), # Gray
            DocItemLabel.TABLE: (255, 204, 204), # Light Pink
            DocItemLabel.PICTURE: (255, 204, 164), # Light Beige
            DocItemLabel.SECTION_HEADER: (255, 153, 153), # Light Red
            DocItemLabel.PAGE_HEADER: (204, 255, 204), # Light Green
            DocItemLabel.PAGE_FOOTER: (
                204,
                255,
                204,
            ), # Light Green (same as Page-Header)
            DocItemLabel.TITLE: (255, 153, 153), # Light Red (same as Section-Header)
            DocItemLabel.FOOTNOTE: (200, 200, 255), # Light Blue
            DocItemLabel.DOCUMENT_INDEX: (220, 220, 220), # Light Gray
            DocItemLabel.CODE: (125, 125, 125), # Gray
            DocItemLabel.CHECKBOX_SELECTED: (255, 182, 193), # Pale Green
            DocItemLabel.CHECKBOX_UNSELECTED: (255, 182, 193), # Light Pink
            DocItemLabel.FORM: (200, 255, 255), # Light Cyan
            DocItemLabel.KEY_VALUE_REGION: (183, 65, 14), # Rusty orange
        }

        CLASS_REMAPPINGS = {
            DocItemLabel.DOCUMENT_INDEX: DocItemLabel.TABLE,
            DocItemLabel.TITLE: DocItemLabel.SECTION_HEADER,
        # Filter clusters for left and right images
        exclude_labels = {
            DocItemLabel.FORM,
            DocItemLabel.KEY_VALUE_REGION,
            DocItemLabel.PICTURE,
        }
        left_clusters = [c for c in clusters if c.label not in exclude_labels]
        right_clusters = [c for c in clusters if c.label in exclude_labels]
        # Create a deep copy of the original image for both sides
        left_image = copy.deepcopy(page.image)
        right_image = copy.deepcopy(page.image)

        _log.debug("================= Start postprocess function ====================")
        start_time = time.time()
        # Apply Confidence Threshold to cluster predictions
        # confidence = self.conf_threshold
        clusters_mod = []
        # Function to draw clusters on an image
        def draw_clusters(image, clusters):
            draw = ImageDraw.Draw(image, "RGBA")
            # Create a smaller font for the labels
            try:
                font = ImageFont.truetype("arial.ttf", 12)
            except OSError:
                # Fallback to default font if arial is not available
                font = ImageFont.load_default()
            for c_tl in clusters:
                all_clusters = [c_tl, *c_tl.children]
                for c in all_clusters:
                    # Draw cells first (underneath)
                    cell_color = (0, 0, 0, 40) # Transparent black for cells
                    for tc in c.cells:
                        cx0, cy0, cx1, cy1 = tc.bbox.as_tuple()
                        draw.rectangle(
                            [(cx0, cy0), (cx1, cy1)],
                            outline=None,
                            fill=cell_color,
                        )
                    # Draw cluster rectangle
                    x0, y0, x1, y1 = c.bbox.as_tuple()
                    cluster_fill_color = (*list(label_to_color.get(c.label)), 70)
                    cluster_outline_color = (*list(label_to_color.get(c.label)), 255)
                    draw.rectangle(
                        [(x0, y0), (x1, y1)],
                        outline=cluster_outline_color,
                        fill=cluster_fill_color,
                    )
                    # Add label name and confidence
                    label_text = f"{c.label.name} ({c.confidence:.2f})"
                    # Create semi-transparent background for text
                    text_bbox = draw.textbbox((x0, y0), label_text, font=font)
                    text_bg_padding = 2
                    draw.rectangle(
                        [
                            (
                                text_bbox[0] - text_bg_padding,
                                text_bbox[1] - text_bg_padding,
                            ),
                            (
                                text_bbox[2] + text_bg_padding,
                                text_bbox[3] + text_bg_padding,
                            ),
                        ],
                        fill=(255, 255, 255, 180), # Semi-transparent white
                    )
                    # Draw text
                    draw.text(
                        (x0, y0),
                        label_text,
                        fill=(0, 0, 0, 255), # Solid black
                        font=font,
                    )

        for cluster in clusters_in:
            confidence = CLASS_THRESHOLDS[cluster.label]
            if cluster.confidence >= confidence:
                # annotation["created_by"] = "high_conf_pred"

                # Remap class labels where needed.
                if cluster.label in CLASS_REMAPPINGS.keys():
                    cluster.label = CLASS_REMAPPINGS[cluster.label]
                clusters_mod.append(cluster)

        # map to dictionary clusters and cells, with bottom left origin
        clusters_orig = [
            {
                "id": c.id,
                "bbox": list(
                    c.bbox.to_bottom_left_origin(page_height).as_tuple()
                ), # TODO
                "confidence": c.confidence,
                "cell_ids": [],
                "type": c.label,
            }
            for c in clusters_in
        ]

        clusters_out = [
            {
                "id": c.id,
                "bbox": list(
                    c.bbox.to_bottom_left_origin(page_height).as_tuple()
                ), # TODO
                "confidence": c.confidence,
                "created_by": "high_conf_pred",
                "cell_ids": [],
                "type": c.label,
            }
            for c in clusters_mod
        ]

        del clusters_mod

        raw_cells = [
            {
                "id": c.id,
                "bbox": list(
                    c.bbox.to_bottom_left_origin(page_height).as_tuple()
                ), # TODO
                "text": c.text,
            }
            for c in cells
        ]
        cell_count = len(raw_cells)

        _log.debug("---- 0. Treat cluster overlaps ------")
        clusters_out = lu.remove_cluster_duplicates_by_conf(clusters_out, 0.8)

        _log.debug(
            "---- 1. Initially assign cells to clusters based on minimum intersection ------"
        )
        ## Check for cells included in or touched by clusters:
        clusters_out = lu.assigning_cell_ids_to_clusters(
            clusters_out, raw_cells, MIN_INTERSECTION
        )

        _log.debug("---- 2. Assign Orphans with Low Confidence Detections")
        # Creates a map of cell_id->cluster_id
        (
            clusters_around_cells,
            orphan_cell_indices,
            ambiguous_cell_indices,
        ) = lu.cell_id_state_map(clusters_out, cell_count)

        # Assign orphan cells with lower confidence predictions
        clusters_out, orphan_cell_indices = lu.assign_orphans_with_low_conf_pred(
            clusters_out, clusters_orig, raw_cells, orphan_cell_indices
        )

        # Refresh the cell_ids assignment, after creating new clusters using low conf predictions
        clusters_out = lu.assigning_cell_ids_to_clusters(
            clusters_out, raw_cells, MIN_INTERSECTION
        )

        _log.debug("---- 3. Settle Ambigous Cells")
        # Creates an update map after assignment of cell_id->cluster_id
        (
            clusters_around_cells,
            orphan_cell_indices,
            ambiguous_cell_indices,
        ) = lu.cell_id_state_map(clusters_out, cell_count)

        # Settle pdf cells that belong to multiple clusters
        clusters_out, ambiguous_cell_indices = lu.remove_ambigous_pdf_cell_by_conf(
            clusters_out, raw_cells, ambiguous_cell_indices
        )

        _log.debug("---- 4. Set Orphans as Text")
        (
            clusters_around_cells,
            orphan_cell_indices,
            ambiguous_cell_indices,
        ) = lu.cell_id_state_map(clusters_out, cell_count)

        clusters_out, orphan_cell_indices = lu.set_orphan_as_text(
            clusters_out, clusters_orig, raw_cells, orphan_cell_indices
        )

        _log.debug("---- 5. Merge Cells & and adapt the bounding boxes")
        # Merge cells orphan cells
        clusters_out = lu.merge_cells(clusters_out)

        # Clean up clusters that remain from merged and unreasonable clusters
        clusters_out = lu.clean_up_clusters(
            clusters_out,
            raw_cells,
            merge_cells=True,
            img_table=True,
            one_cell_table=True,
        )

        new_clusters = lu.adapt_bboxes(raw_cells, clusters_out, orphan_cell_indices)
        clusters_out = new_clusters

        ## We first rebuild where every cell is now:
        ## Now we write into a prediction cells list, not into the raw cells list.
        ## As we don't need previous labels, we best overwrite any old list, because that might
        ## have been sorted differently.
        (
            clusters_around_cells,
            orphan_cell_indices,
            ambiguous_cell_indices,
        ) = lu.cell_id_state_map(clusters_out, cell_count)

        target_cells = []
        for ix, cell in enumerate(raw_cells):
            new_cell = {
                "id": ix,
                "rawcell_id": ix,
                "label": "None",
                "bbox": cell["bbox"],
                "text": cell["text"],
            }
            for cluster_index in clusters_around_cells[
                ix
            ]: # By previous analysis, this is always 1 cluster.
                new_cell["label"] = clusters_out[cluster_index]["type"]
            target_cells.append(new_cell)
            # _log.debug("New label of cell " + str(ix) + " is " + str(new_cell["label"]))
        cells_out = target_cells

        ## -------------------------------
        ## Sort clusters into reasonable reading order, and sort the cells inside each cluster
        _log.debug("---- 5. Sort clusters in reading order ------")
        sorted_clusters = lu.produce_reading_order(
            clusters_out, "raw_cell_ids", "raw_cell_ids", True
        )
        clusters_out = sorted_clusters

        # end_time = timer()
        _log.debug("---- End of postprocessing function ------")
        end_time = time.time() - start_time
        _log.debug(f"Finished post processing in seconds={end_time:.3f}")

        cells_out_new = [
            Cell(
                id=c["id"], # type: ignore
                bbox=BoundingBox.from_tuple(
                    coord=c["bbox"], origin=CoordOrigin.BOTTOMLEFT # type: ignore
                ).to_top_left_origin(page_height),
                text=c["text"], # type: ignore
        # Draw clusters on both images
        draw_clusters(left_image, left_clusters)
        draw_clusters(right_image, right_clusters)
        # Combine the images side by side
        combined_width = left_image.width * 2
        combined_height = left_image.height
        combined_image = Image.new("RGB", (combined_width, combined_height))
        combined_image.paste(left_image, (0, 0))
        combined_image.paste(right_image, (left_image.width, 0))
        if show:
            combined_image.show()
        else:
            out_path: Path = (
                Path(settings.debug.debug_output_path)
                / f"debug_{conv_res.input.file.stem}"
            )
            for c in cells_out
        ]

        del cells_out

        clusters_out_new = []
        for c in clusters_out:
            cluster_cells = [
                ccell for ccell in cells_out_new if ccell.id in c["cell_ids"] # type: ignore
            ]
            c_new = Cluster(
                id=c["id"], # type: ignore
                bbox=BoundingBox.from_tuple(
                    coord=c["bbox"], origin=CoordOrigin.BOTTOMLEFT # type: ignore
                ).to_top_left_origin(page_height),
                confidence=c["confidence"], # type: ignore
                label=DocItemLabel(c["type"]),
                cells=cluster_cells,
            )
            clusters_out_new.append(c_new)

        return clusters_out_new, cells_out_new
            out_path.mkdir(parents=True, exist_ok=True)
            out_file = out_path / f"{mode_prefix}_layout_page_{page.page_no:05}.png"
            combined_image.save(str(out_file), format="png")

    def __call__(
        self, conv_res: ConversionResult, page_batch: Iterable[Page]
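The docstring in this hunk describes the new side-by-side debug rendering. Stripped of the docling-specific cluster drawing, the image composition itself reduces to the following PIL sketch; the function name and image sources are illustrative only:

from PIL import Image

def combine_side_by_side(left: Image.Image, right: Image.Image) -> Image.Image:
    # Paste the two renderings next to each other on one canvas.
    combined = Image.new("RGB", (left.width + right.width, max(left.height, right.height)))
    combined.paste(left, (0, 0))
    combined.paste(right, (left.width, 0))
    return combined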
@@ -315,66 +209,26 @@ class LayoutModel(BasePageModel):
                    )
                    clusters.append(cluster)

                # Map cells to clusters
                # TODO: Remove, postprocess should take care of it anyway.
                for cell in page.cells:
                    for cluster in clusters:
                        if not cell.bbox.area() > 0:
                            overlap_frac = 0.0
                        else:
                            overlap_frac = (
                                cell.bbox.intersection_area_with(cluster.bbox)
                                / cell.bbox.area()
                            )
                if settings.debug.visualize_raw_layout:
                    self.draw_clusters_and_cells_side_by_side(
                        conv_res, page, clusters, mode_prefix="raw"
                    )

                        if overlap_frac > 0.5:
                            cluster.cells.append(cell)
                # Apply postprocessing

                # Pre-sort clusters
                # clusters = self.sort_clusters_by_cell_order(clusters)
                processed_clusters, processed_cells = LayoutPostprocessor(
                    page.cells, clusters, page.size
                ).postprocess()
                # processed_clusters, processed_cells = clusters, page.cells

                # DEBUG code:
                def draw_clusters_and_cells(show: bool = False):
                    image = copy.deepcopy(page.image)
                    if image is not None:
                        draw = ImageDraw.Draw(image)
                        for c in clusters:
                            x0, y0, x1, y1 = c.bbox.as_tuple()
                            draw.rectangle([(x0, y0), (x1, y1)], outline="green")

                            cell_color = (
                                random.randint(30, 140),
                                random.randint(30, 140),
                                random.randint(30, 140),
                            )
                            for tc in c.cells: # [:1]:
                                x0, y0, x1, y1 = tc.bbox.as_tuple()
                                draw.rectangle(
                                    [(x0, y0), (x1, y1)], outline=cell_color
                                )
                    if show:
                        image.show()
                    else:
                        out_path: Path = (
                            Path(settings.debug.debug_output_path)
                            / f"debug_{conv_res.input.file.stem}"
                        )
                        out_path.mkdir(parents=True, exist_ok=True)

                        out_file = (
                            out_path / f"layout_page_{page.page_no:05}.png"
                        )
                        image.save(str(out_file), format="png")

                # draw_clusters_and_cells()

                clusters, page.cells = self.postprocess(
                    clusters, page.cells, page.size.height
                page.cells = processed_cells
                page.predictions.layout = LayoutPrediction(
                    clusters=processed_clusters
                )

                page.predictions.layout = LayoutPrediction(clusters=clusters)

                if settings.debug.visualize_layout:
                    draw_clusters_and_cells()
                    self.draw_clusters_and_cells_side_by_side(
                        conv_res, page, processed_clusters, mode_prefix="postprocessed"
                    )

            yield page
@@ -6,6 +6,7 @@ from pydantic import BaseModel

from docling.datamodel.base_models import (
    AssembledUnit,
    ContainerElement,
    FigureElement,
    Page,
    PageElement,
@@ -94,7 +95,7 @@ class PageAssembleModel(BasePageModel):
                            headers.append(text_el)
                        else:
                            body.append(text_el)
                    elif cluster.label == LayoutModel.TABLE_LABEL:
                    elif cluster.label in LayoutModel.TABLE_LABELS:
                        tbl = None
                        if page.predictions.tablestructure:
                            tbl = page.predictions.tablestructure.table_map.get(
@@ -159,6 +160,15 @@ class PageAssembleModel(BasePageModel):
                        )
                        elements.append(equation)
                        body.append(equation)
                    elif cluster.label in LayoutModel.CONTAINER_LABELS:
                        container_el = ContainerElement(
                            label=cluster.label,
                            id=cluster.id,
                            page_no=page.page_no,
                            cluster=cluster,
                        )
                        elements.append(container_el)
                        body.append(container_el)

                page.assembled = AssembledUnit(
                    elements=elements, headers=headers, body=body
@@ -76,6 +76,10 @@ class TableStructureModel(BasePageModel):
            x0, y0, x1, y1 = table_element.cluster.bbox.as_tuple()
            draw.rectangle([(x0, y0), (x1, y1)], outline="red")

            for cell in table_element.cluster.cells:
                x0, y0, x1, y1 = cell.bbox.as_tuple()
                draw.rectangle([(x0, y0), (x1, y1)], outline="green")

            for tc in table_element.table_cells:
                if tc.bbox is not None:
                    x0, y0, x1, y1 = tc.bbox.as_tuple()
@@ -89,7 +93,6 @@ class TableStructureModel(BasePageModel):
                        text=f"{tc.start_row_offset_idx}, {tc.start_col_offset_idx}",
                        fill="black",
                    )

        if show:
            image.show()
        else:
@@ -135,47 +138,40 @@ class TableStructureModel(BasePageModel):
                        ],
                    )
                    for cluster in page.predictions.layout.clusters
                    if cluster.label == DocItemLabel.TABLE
                    if cluster.label
                    in [DocItemLabel.TABLE, DocItemLabel.DOCUMENT_INDEX]
                ]
                if not len(in_tables):
                    yield page
                    continue

                tokens = []
                for c in page.cells:
                    for cluster, _ in in_tables:
                        if c.bbox.area() > 0:
                            if (
                                c.bbox.intersection_area_with(cluster.bbox)
                                / c.bbox.area()
                                > 0.2
                            ):
                                # Only allow non empty stings (spaces) into the cells of a table
                                if len(c.text.strip()) > 0:
                                    new_cell = copy.deepcopy(c)
                                    new_cell.bbox = new_cell.bbox.scaled(
                                        scale=self.scale
                                    )

                                    tokens.append(new_cell.model_dump())

                page_input = {
                    "tokens": tokens,
                    "width": page.size.width * self.scale,
                    "height": page.size.height * self.scale,
                    "image": numpy.asarray(page.get_image(scale=self.scale)),
                }
                page_input["image"] = numpy.asarray(
                    page.get_image(scale=self.scale)
                )

                table_clusters, table_bboxes = zip(*in_tables)

                if len(table_bboxes):
                    tf_output = self.tf_predictor.multi_table_predict(
                        page_input, table_bboxes, do_matching=self.do_cell_matching
                    )
                    for table_cluster, tbl_box in in_tables:

                    for table_cluster, table_out in zip(table_clusters, tf_output):
                        tokens = []
                        for c in table_cluster.cells:
                            # Only allow non empty stings (spaces) into the cells of a table
                            if len(c.text.strip()) > 0:
                                new_cell = copy.deepcopy(c)
                                new_cell.bbox = new_cell.bbox.scaled(
                                    scale=self.scale
                                )

                                tokens.append(new_cell.model_dump())
                        page_input["tokens"] = tokens

                        tf_output = self.tf_predictor.multi_table_predict(
                            page_input, [tbl_box], do_matching=self.do_cell_matching
                        )
                        table_out = tf_output[0]
                        table_cells = []
                        for element in table_out["tf_responses"]:
@@ -208,7 +204,7 @@ class TableStructureModel(BasePageModel):
                            id=table_cluster.id,
                            page_no=page.page_no,
                            cluster=table_cluster,
                            label=DocItemLabel.TABLE,
                            label=table_cluster.label,
                        )

                        page.predictions.tablestructure.table_map[
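As in the layout step, the table hunks above assign PDF cells to a table region by overlap fraction: intersection area divided by the cell's own area, thresholded at 0.2. A small self-contained sketch of that check, with a hypothetical rectangle type standing in for docling's BoundingBox:

from dataclasses import dataclass

@dataclass
class Rect:
    # Hypothetical stand-in for docling's BoundingBox (top-left origin).
    l: float
    t: float
    r: float
    b: float

    def area(self) -> float:
        return max(0.0, self.r - self.l) * max(0.0, self.b - self.t)

    def intersection_area_with(self, other: "Rect") -> float:
        w = min(self.r, other.r) - max(self.l, other.l)
        h = min(self.b, other.b) - max(self.t, other.t)
        return max(0.0, w) * max(0.0, h)

def belongs_to_table(cell: Rect, table: Rect, min_overlap: float = 0.2) -> bool:
    # A cell is assigned to the table when enough of its own area lies inside it.
    return cell.area() > 0 and cell.intersection_area_with(table) / cell.area() > min_overlap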