mirror of
https://github.com/DS4SD/docling.git
synced 2025-07-27 12:34:22 +00:00
Pass nested cluster processing through full pipeline
Signed-off-by: Christoph Auer <cau@zurich.ibm.com>
This commit is contained in:
parent
7245cc6080
commit
4dcc738b6d
@ -132,6 +132,12 @@ class LayoutPrediction(BaseModel):
|
|||||||
clusters: List[Cluster] = []
|
clusters: List[Cluster] = []
|
||||||
|
|
||||||
|
|
||||||
|
class ContainerElement(
|
||||||
|
BasePageElement
|
||||||
|
): # Used for Form and Key-Value-Regions, only for typing.
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class Table(BasePageElement):
|
class Table(BasePageElement):
|
||||||
otsl_seq: List[str]
|
otsl_seq: List[str]
|
||||||
num_rows: int = 0
|
num_rows: int = 0
|
||||||
@ -171,7 +177,7 @@ class PagePredictions(BaseModel):
|
|||||||
equations_prediction: Optional[EquationPrediction] = None
|
equations_prediction: Optional[EquationPrediction] = None
|
||||||
|
|
||||||
|
|
||||||
PageElement = Union[TextElement, Table, FigureElement]
|
PageElement = Union[TextElement, Table, FigureElement, ContainerElement]
|
||||||
|
|
||||||
|
|
||||||
class AssembledUnit(BaseModel):
|
class AssembledUnit(BaseModel):
|
||||||
|
@ -77,6 +77,8 @@ layout_label_to_ds_type = {
|
|||||||
DocItemLabel.PICTURE: "figure",
|
DocItemLabel.PICTURE: "figure",
|
||||||
DocItemLabel.TEXT: "paragraph",
|
DocItemLabel.TEXT: "paragraph",
|
||||||
DocItemLabel.PARAGRAPH: "paragraph",
|
DocItemLabel.PARAGRAPH: "paragraph",
|
||||||
|
DocItemLabel.FORM: DocItemLabel.FORM.value,
|
||||||
|
DocItemLabel.KEY_VALUE_REGION: DocItemLabel.KEY_VALUE_REGION.value,
|
||||||
}
|
}
|
||||||
|
|
||||||
_EMPTY_DOCLING_DOC = DoclingDocument(name="dummy")
|
_EMPTY_DOCLING_DOC = DoclingDocument(name="dummy")
|
||||||
|
@ -24,9 +24,15 @@ from docling_core.types.legacy_doc.document import (
|
|||||||
from docling_core.types.legacy_doc.document import CCSFileInfoObject as DsFileInfoObject
|
from docling_core.types.legacy_doc.document import CCSFileInfoObject as DsFileInfoObject
|
||||||
from docling_core.types.legacy_doc.document import ExportedCCSDocument as DsDocument
|
from docling_core.types.legacy_doc.document import ExportedCCSDocument as DsDocument
|
||||||
from PIL import ImageDraw
|
from PIL import ImageDraw
|
||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, ConfigDict, TypeAdapter
|
||||||
|
|
||||||
from docling.datamodel.base_models import Cluster, FigureElement, Table, TextElement
|
from docling.datamodel.base_models import (
|
||||||
|
Cluster,
|
||||||
|
ContainerElement,
|
||||||
|
FigureElement,
|
||||||
|
Table,
|
||||||
|
TextElement,
|
||||||
|
)
|
||||||
from docling.datamodel.document import ConversionResult, layout_label_to_ds_type
|
from docling.datamodel.document import ConversionResult, layout_label_to_ds_type
|
||||||
from docling.datamodel.settings import settings
|
from docling.datamodel.settings import settings
|
||||||
from docling.utils.profiling import ProfilingScope, TimeRecorder
|
from docling.utils.profiling import ProfilingScope, TimeRecorder
|
||||||
@ -45,7 +51,9 @@ class GlmModel:
|
|||||||
|
|
||||||
if self.options.model_names != "":
|
if self.options.model_names != "":
|
||||||
load_pretrained_nlp_models()
|
load_pretrained_nlp_models()
|
||||||
self.model = init_nlp_model(model_names=self.options.model_names)
|
self.model = init_nlp_model(
|
||||||
|
model_names=self.options.model_names, loglevel="ERROR"
|
||||||
|
)
|
||||||
|
|
||||||
def _to_legacy_document(self, conv_res) -> DsDocument:
|
def _to_legacy_document(self, conv_res) -> DsDocument:
|
||||||
title = ""
|
title = ""
|
||||||
@ -207,7 +215,26 @@ class GlmModel:
|
|||||||
)
|
)
|
||||||
],
|
],
|
||||||
obj_type=layout_label_to_ds_type.get(element.label),
|
obj_type=layout_label_to_ds_type.get(element.label),
|
||||||
# data=[[]],
|
payload=TypeAdapter(List[Cluster]).dump_python(
|
||||||
|
element.cluster.children
|
||||||
|
), # hack to channel child clusters through GLM
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif isinstance(element, ContainerElement):
|
||||||
|
main_text.append(
|
||||||
|
BaseText(
|
||||||
|
payload=TypeAdapter(List[Cluster]).dump_python(
|
||||||
|
element.cluster.children
|
||||||
|
), # hack to channel child clusters through GLM
|
||||||
|
obj_type=layout_label_to_ds_type.get(element.label),
|
||||||
|
name=element.label,
|
||||||
|
prov=[
|
||||||
|
Prov(
|
||||||
|
bbox=target_bbox,
|
||||||
|
page=element.page_no + 1,
|
||||||
|
span=[0, 0],
|
||||||
|
)
|
||||||
|
],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -46,6 +46,8 @@ class LayoutModel(BasePageModel):
|
|||||||
FIGURE_LABEL = DocItemLabel.PICTURE
|
FIGURE_LABEL = DocItemLabel.PICTURE
|
||||||
FORMULA_LABEL = DocItemLabel.FORMULA
|
FORMULA_LABEL = DocItemLabel.FORMULA
|
||||||
|
|
||||||
|
CONTAINER_LABELS = [DocItemLabel.FORM, DocItemLabel.KEY_VALUE_REGION]
|
||||||
|
|
||||||
def __init__(self, artifacts_path: Path):
|
def __init__(self, artifacts_path: Path):
|
||||||
self.layout_predictor = LayoutPredictor(artifacts_path) # TODO temporary
|
self.layout_predictor = LayoutPredictor(artifacts_path) # TODO temporary
|
||||||
|
|
||||||
@ -177,24 +179,6 @@ class LayoutModel(BasePageModel):
|
|||||||
cells=[],
|
cells=[],
|
||||||
)
|
)
|
||||||
clusters.append(cluster)
|
clusters.append(cluster)
|
||||||
#
|
|
||||||
# # Map cells to clusters
|
|
||||||
# # TODO: Remove, postprocess should take care of it anyway.
|
|
||||||
# for cell in page.cells:
|
|
||||||
# for cluster in clusters:
|
|
||||||
# if not cell.bbox.area() > 0:
|
|
||||||
# overlap_frac = 0.0
|
|
||||||
# else:
|
|
||||||
# overlap_frac = (
|
|
||||||
# cell.bbox.intersection_area_with(cluster.bbox)
|
|
||||||
# / cell.bbox.area()
|
|
||||||
# )
|
|
||||||
#
|
|
||||||
# if overlap_frac > 0.5:
|
|
||||||
# cluster.cells.append(cell)
|
|
||||||
|
|
||||||
# Pre-sort clusters
|
|
||||||
# clusters = self.sort_clusters_by_cell_order(clusters)
|
|
||||||
|
|
||||||
# DEBUG code:
|
# DEBUG code:
|
||||||
def draw_clusters_and_cells(
|
def draw_clusters_and_cells(
|
||||||
@ -299,18 +283,6 @@ class LayoutModel(BasePageModel):
|
|||||||
clusters=processed_clusters
|
clusters=processed_clusters
|
||||||
)
|
)
|
||||||
|
|
||||||
# def check_for_overlaps(clusters):
|
|
||||||
# for i, cluster in enumerate(clusters):
|
|
||||||
# for j, other_cluster in enumerate(clusters):
|
|
||||||
# if i >= j or cluster.label in [DocItemLabel.FORM, DocItemLabel.KEY_VALUE_REGION, DocItemLabel.PICTURE]:
|
|
||||||
# continue
|
|
||||||
#
|
|
||||||
# overlap_area = cluster.bbox.intersection_area_with(other_cluster.bbox)
|
|
||||||
# if overlap_area > 0:
|
|
||||||
# print(f"Overlap detected between cluster {i} and cluster {j}")
|
|
||||||
# print(f"Cluster {i} bbox: {cluster.bbox}, Cluster {j} bbox: {other_cluster.bbox}")
|
|
||||||
# check_for_overlaps(processed_clusters)
|
|
||||||
|
|
||||||
if settings.debug.visualize_layout:
|
if settings.debug.visualize_layout:
|
||||||
self.draw_clusters_and_cells_side_by_side(
|
self.draw_clusters_and_cells_side_by_side(
|
||||||
conv_res, page, processed_clusters, mode_prefix="postprocessed"
|
conv_res, page, processed_clusters, mode_prefix="postprocessed"
|
||||||
|
@ -6,6 +6,7 @@ from pydantic import BaseModel
|
|||||||
|
|
||||||
from docling.datamodel.base_models import (
|
from docling.datamodel.base_models import (
|
||||||
AssembledUnit,
|
AssembledUnit,
|
||||||
|
ContainerElement,
|
||||||
FigureElement,
|
FigureElement,
|
||||||
Page,
|
Page,
|
||||||
PageElement,
|
PageElement,
|
||||||
@ -159,6 +160,15 @@ class PageAssembleModel(BasePageModel):
|
|||||||
)
|
)
|
||||||
elements.append(equation)
|
elements.append(equation)
|
||||||
body.append(equation)
|
body.append(equation)
|
||||||
|
elif cluster.label in LayoutModel.CONTAINER_LABELS:
|
||||||
|
container_el = ContainerElement(
|
||||||
|
label=cluster.label,
|
||||||
|
id=cluster.id,
|
||||||
|
page_no=page.page_no,
|
||||||
|
cluster=cluster,
|
||||||
|
)
|
||||||
|
elements.append(container_el)
|
||||||
|
body.append(container_el)
|
||||||
|
|
||||||
page.assembled = AssembledUnit(
|
page.assembled = AssembledUnit(
|
||||||
elements=elements, headers=headers, body=body
|
elements=elements, headers=headers, body=body
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import bisect
|
import bisect
|
||||||
import logging
|
import logging
|
||||||
|
import sys
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import Dict, List, Set, Tuple
|
from typing import Dict, List, Set, Tuple
|
||||||
|
|
||||||
@ -279,7 +280,16 @@ class LayoutPostprocessor:
|
|||||||
contained.append(cluster)
|
contained.append(cluster)
|
||||||
|
|
||||||
if contained:
|
if contained:
|
||||||
|
# Sort contained clusters by minimum cell ID
|
||||||
|
contained.sort(
|
||||||
|
key=lambda cluster: (
|
||||||
|
min(cell.id for cell in cluster.cells)
|
||||||
|
if cluster.cells
|
||||||
|
else sys.maxsize
|
||||||
|
)
|
||||||
|
)
|
||||||
special.children = contained
|
special.children = contained
|
||||||
|
|
||||||
# Adjust bbox only for wrapper types
|
# Adjust bbox only for wrapper types
|
||||||
if special.label in self.WRAPPER_TYPES:
|
if special.label in self.WRAPPER_TYPES:
|
||||||
special.bbox = BoundingBox(
|
special.bbox = BoundingBox(
|
||||||
|
Loading…
Reference in New Issue
Block a user