mirror of
https://github.com/DS4SD/docling.git
synced 2025-07-27 12:34:22 +00:00
Pass nested cluster processing through full pipeline
Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
This commit is contained in:
parent
7245cc6080
commit
4dcc738b6d
@ -132,6 +132,12 @@ class LayoutPrediction(BaseModel):
|
||||
clusters: List[Cluster] = []
|
||||
|
||||
|
||||
# Marker element for Form and Key-Value-Region clusters; it adds no fields
# of its own and exists only for typing.
class ContainerElement(BasePageElement):
    pass
|
||||
|
||||
|
||||
class Table(BasePageElement):
|
||||
otsl_seq: List[str]
|
||||
num_rows: int = 0
|
||||
@ -171,7 +177,7 @@ class PagePredictions(BaseModel):
|
||||
equations_prediction: Optional[EquationPrediction] = None
|
||||
|
||||
|
||||
PageElement = Union[TextElement, Table, FigureElement]
|
||||
PageElement = Union[TextElement, Table, FigureElement, ContainerElement]
|
||||
|
||||
|
||||
class AssembledUnit(BaseModel):
|
||||
|
@ -77,6 +77,8 @@ layout_label_to_ds_type = {
|
||||
DocItemLabel.PICTURE: "figure",
|
||||
DocItemLabel.TEXT: "paragraph",
|
||||
DocItemLabel.PARAGRAPH: "paragraph",
|
||||
DocItemLabel.FORM: DocItemLabel.FORM.value,
|
||||
DocItemLabel.KEY_VALUE_REGION: DocItemLabel.KEY_VALUE_REGION.value,
|
||||
}
|
||||
|
||||
_EMPTY_DOCLING_DOC = DoclingDocument(name="dummy")
|
||||
|
@ -24,9 +24,15 @@ from docling_core.types.legacy_doc.document import (
|
||||
from docling_core.types.legacy_doc.document import CCSFileInfoObject as DsFileInfoObject
|
||||
from docling_core.types.legacy_doc.document import ExportedCCSDocument as DsDocument
|
||||
from PIL import ImageDraw
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from pydantic import BaseModel, ConfigDict, TypeAdapter
|
||||
|
||||
from docling.datamodel.base_models import Cluster, FigureElement, Table, TextElement
|
||||
from docling.datamodel.base_models import (
|
||||
Cluster,
|
||||
ContainerElement,
|
||||
FigureElement,
|
||||
Table,
|
||||
TextElement,
|
||||
)
|
||||
from docling.datamodel.document import ConversionResult, layout_label_to_ds_type
|
||||
from docling.datamodel.settings import settings
|
||||
from docling.utils.profiling import ProfilingScope, TimeRecorder
|
||||
@ -45,7 +51,9 @@ class GlmModel:
|
||||
|
||||
if self.options.model_names != "":
|
||||
load_pretrained_nlp_models()
|
||||
self.model = init_nlp_model(model_names=self.options.model_names)
|
||||
self.model = init_nlp_model(
|
||||
model_names=self.options.model_names, loglevel="ERROR"
|
||||
)
|
||||
|
||||
def _to_legacy_document(self, conv_res) -> DsDocument:
|
||||
title = ""
|
||||
@ -207,7 +215,26 @@ class GlmModel:
|
||||
)
|
||||
],
|
||||
obj_type=layout_label_to_ds_type.get(element.label),
|
||||
# data=[[]],
|
||||
payload=TypeAdapter(List[Cluster]).dump_python(
|
||||
element.cluster.children
|
||||
), # hack to channel child clusters through GLM
|
||||
)
|
||||
)
|
||||
elif isinstance(element, ContainerElement):
|
||||
main_text.append(
|
||||
BaseText(
|
||||
payload=TypeAdapter(List[Cluster]).dump_python(
|
||||
element.cluster.children
|
||||
), # hack to channel child clusters through GLM
|
||||
obj_type=layout_label_to_ds_type.get(element.label),
|
||||
name=element.label,
|
||||
prov=[
|
||||
Prov(
|
||||
bbox=target_bbox,
|
||||
page=element.page_no + 1,
|
||||
span=[0, 0],
|
||||
)
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -46,6 +46,8 @@ class LayoutModel(BasePageModel):
|
||||
FIGURE_LABEL = DocItemLabel.PICTURE
|
||||
FORMULA_LABEL = DocItemLabel.FORMULA
|
||||
|
||||
CONTAINER_LABELS = [DocItemLabel.FORM, DocItemLabel.KEY_VALUE_REGION]
|
||||
|
||||
def __init__(self, artifacts_path: Path):
|
||||
self.layout_predictor = LayoutPredictor(artifacts_path) # TODO temporary
|
||||
|
||||
@ -177,24 +179,6 @@ class LayoutModel(BasePageModel):
|
||||
cells=[],
|
||||
)
|
||||
clusters.append(cluster)
|
||||
#
|
||||
# # Map cells to clusters
|
||||
# # TODO: Remove, postprocess should take care of it anyway.
|
||||
# for cell in page.cells:
|
||||
# for cluster in clusters:
|
||||
# if not cell.bbox.area() > 0:
|
||||
# overlap_frac = 0.0
|
||||
# else:
|
||||
# overlap_frac = (
|
||||
# cell.bbox.intersection_area_with(cluster.bbox)
|
||||
# / cell.bbox.area()
|
||||
# )
|
||||
#
|
||||
# if overlap_frac > 0.5:
|
||||
# cluster.cells.append(cell)
|
||||
|
||||
# Pre-sort clusters
|
||||
# clusters = self.sort_clusters_by_cell_order(clusters)
|
||||
|
||||
# DEBUG code:
|
||||
def draw_clusters_and_cells(
|
||||
@ -299,18 +283,6 @@ class LayoutModel(BasePageModel):
|
||||
clusters=processed_clusters
|
||||
)
|
||||
|
||||
# def check_for_overlaps(clusters):
|
||||
# for i, cluster in enumerate(clusters):
|
||||
# for j, other_cluster in enumerate(clusters):
|
||||
# if i >= j or cluster.label in [DocItemLabel.FORM, DocItemLabel.KEY_VALUE_REGION, DocItemLabel.PICTURE]:
|
||||
# continue
|
||||
#
|
||||
# overlap_area = cluster.bbox.intersection_area_with(other_cluster.bbox)
|
||||
# if overlap_area > 0:
|
||||
# print(f"Overlap detected between cluster {i} and cluster {j}")
|
||||
# print(f"Cluster {i} bbox: {cluster.bbox}, Cluster {j} bbox: {other_cluster.bbox}")
|
||||
# check_for_overlaps(processed_clusters)
|
||||
|
||||
if settings.debug.visualize_layout:
|
||||
self.draw_clusters_and_cells_side_by_side(
|
||||
conv_res, page, processed_clusters, mode_prefix="postprocessed"
|
||||
|
@ -6,6 +6,7 @@ from pydantic import BaseModel
|
||||
|
||||
from docling.datamodel.base_models import (
|
||||
AssembledUnit,
|
||||
ContainerElement,
|
||||
FigureElement,
|
||||
Page,
|
||||
PageElement,
|
||||
@ -159,6 +160,15 @@ class PageAssembleModel(BasePageModel):
|
||||
)
|
||||
elements.append(equation)
|
||||
body.append(equation)
|
||||
elif cluster.label in LayoutModel.CONTAINER_LABELS:
|
||||
container_el = ContainerElement(
|
||||
label=cluster.label,
|
||||
id=cluster.id,
|
||||
page_no=page.page_no,
|
||||
cluster=cluster,
|
||||
)
|
||||
elements.append(container_el)
|
||||
body.append(container_el)
|
||||
|
||||
page.assembled = AssembledUnit(
|
||||
elements=elements, headers=headers, body=body
|
||||
|
@ -1,5 +1,6 @@
|
||||
import bisect
|
||||
import logging
|
||||
import sys
|
||||
from collections import defaultdict
|
||||
from typing import Dict, List, Set, Tuple
|
||||
|
||||
@ -279,7 +280,16 @@ class LayoutPostprocessor:
|
||||
contained.append(cluster)
|
||||
|
||||
if contained:
|
||||
# Sort contained clusters by minimum cell ID
|
||||
contained.sort(
|
||||
key=lambda cluster: (
|
||||
min(cell.id for cell in cluster.cells)
|
||||
if cluster.cells
|
||||
else sys.maxsize
|
||||
)
|
||||
)
|
||||
special.children = contained
|
||||
|
||||
# Adjust bbox only for wrapper types
|
||||
if special.label in self.WRAPPER_TYPES:
|
||||
special.bbox = BoundingBox(
|
||||
|
Loading…
Reference in New Issue
Block a user