diff --git a/docling/datamodel/base_models.py b/docling/datamodel/base_models.py index 281d6735..2b09aafd 100644 --- a/docling/datamodel/base_models.py +++ b/docling/datamodel/base_models.py @@ -132,6 +132,12 @@ class LayoutPrediction(BaseModel): clusters: List[Cluster] = [] +class ContainerElement( + BasePageElement +): # Used for Form and Key-Value-Regions, only for typing. + pass + + class Table(BasePageElement): otsl_seq: List[str] num_rows: int = 0 @@ -171,7 +177,7 @@ class PagePredictions(BaseModel): equations_prediction: Optional[EquationPrediction] = None -PageElement = Union[TextElement, Table, FigureElement] +PageElement = Union[TextElement, Table, FigureElement, ContainerElement] class AssembledUnit(BaseModel): diff --git a/docling/datamodel/document.py b/docling/datamodel/document.py index be4e9a12..3be66ade 100644 --- a/docling/datamodel/document.py +++ b/docling/datamodel/document.py @@ -77,6 +77,8 @@ layout_label_to_ds_type = { DocItemLabel.PICTURE: "figure", DocItemLabel.TEXT: "paragraph", DocItemLabel.PARAGRAPH: "paragraph", + DocItemLabel.FORM: DocItemLabel.FORM.value, + DocItemLabel.KEY_VALUE_REGION: DocItemLabel.KEY_VALUE_REGION.value, } _EMPTY_DOCLING_DOC = DoclingDocument(name="dummy") diff --git a/docling/models/ds_glm_model.py b/docling/models/ds_glm_model.py index 0a066bfa..c17f0b17 100644 --- a/docling/models/ds_glm_model.py +++ b/docling/models/ds_glm_model.py @@ -24,9 +24,15 @@ from docling_core.types.legacy_doc.document import ( from docling_core.types.legacy_doc.document import CCSFileInfoObject as DsFileInfoObject from docling_core.types.legacy_doc.document import ExportedCCSDocument as DsDocument from PIL import ImageDraw -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, TypeAdapter -from docling.datamodel.base_models import Cluster, FigureElement, Table, TextElement +from docling.datamodel.base_models import ( + Cluster, + ContainerElement, + FigureElement, + Table, + TextElement, +) from docling.datamodel.document import ConversionResult, layout_label_to_ds_type from docling.datamodel.settings import settings from docling.utils.profiling import ProfilingScope, TimeRecorder @@ -45,7 +51,9 @@ class GlmModel: if self.options.model_names != "": load_pretrained_nlp_models() - self.model = init_nlp_model(model_names=self.options.model_names) + self.model = init_nlp_model( + model_names=self.options.model_names, loglevel="ERROR" + ) def _to_legacy_document(self, conv_res) -> DsDocument: title = "" @@ -207,7 +215,26 @@ class GlmModel: ) ], obj_type=layout_label_to_ds_type.get(element.label), - # data=[[]], + payload=TypeAdapter(List[Cluster]).dump_python( + element.cluster.children + ), # hack to channel child clusters through GLM + ) + ) + elif isinstance(element, ContainerElement): + main_text.append( + BaseText( + payload=TypeAdapter(List[Cluster]).dump_python( + element.cluster.children + ), # hack to channel child clusters through GLM + obj_type=layout_label_to_ds_type.get(element.label), + name=element.label, + prov=[ + Prov( + bbox=target_bbox, + page=element.page_no + 1, + span=[0, 0], + ) + ], ) ) diff --git a/docling/models/layout_model.py b/docling/models/layout_model.py index b0536b6d..02e2b85f 100644 --- a/docling/models/layout_model.py +++ b/docling/models/layout_model.py @@ -46,6 +46,8 @@ class LayoutModel(BasePageModel): FIGURE_LABEL = DocItemLabel.PICTURE FORMULA_LABEL = DocItemLabel.FORMULA + CONTAINER_LABELS = [DocItemLabel.FORM, DocItemLabel.KEY_VALUE_REGION] + def __init__(self, artifacts_path: Path): self.layout_predictor = LayoutPredictor(artifacts_path) # TODO temporary @@ -177,24 +179,6 @@ class LayoutModel(BasePageModel): cells=[], ) clusters.append(cluster) - # - # # Map cells to clusters - # # TODO: Remove, postprocess should take care of it anyway. - # for cell in page.cells: - # for cluster in clusters: - # if not cell.bbox.area() > 0: - # overlap_frac = 0.0 - # else: - # overlap_frac = ( - # cell.bbox.intersection_area_with(cluster.bbox) - # / cell.bbox.area() - # ) - # - # if overlap_frac > 0.5: - # cluster.cells.append(cell) - - # Pre-sort clusters - # clusters = self.sort_clusters_by_cell_order(clusters) # DEBUG code: def draw_clusters_and_cells( @@ -299,18 +283,6 @@ class LayoutModel(BasePageModel): clusters=processed_clusters ) - # def check_for_overlaps(clusters): - # for i, cluster in enumerate(clusters): - # for j, other_cluster in enumerate(clusters): - # if i >= j or cluster.label in [DocItemLabel.FORM, DocItemLabel.KEY_VALUE_REGION, DocItemLabel.PICTURE]: - # continue - # - # overlap_area = cluster.bbox.intersection_area_with(other_cluster.bbox) - # if overlap_area > 0: - # print(f"Overlap detected between cluster {i} and cluster {j}") - # print(f"Cluster {i} bbox: {cluster.bbox}, Cluster {j} bbox: {other_cluster.bbox}") - # check_for_overlaps(processed_clusters) - if settings.debug.visualize_layout: self.draw_clusters_and_cells_side_by_side( conv_res, page, processed_clusters, mode_prefix="postprocessed" diff --git a/docling/models/page_assemble_model.py b/docling/models/page_assemble_model.py index 9b064ead..4c27400f 100644 --- a/docling/models/page_assemble_model.py +++ b/docling/models/page_assemble_model.py @@ -6,6 +6,7 @@ from pydantic import BaseModel from docling.datamodel.base_models import ( AssembledUnit, + ContainerElement, FigureElement, Page, PageElement, @@ -159,6 +160,15 @@ class PageAssembleModel(BasePageModel): ) elements.append(equation) body.append(equation) + elif cluster.label in LayoutModel.CONTAINER_LABELS: + container_el = ContainerElement( + label=cluster.label, + id=cluster.id, + page_no=page.page_no, + cluster=cluster, + ) + elements.append(container_el) + body.append(container_el) page.assembled = AssembledUnit( elements=elements, headers=headers, body=body diff --git a/docling/utils/layout_postprocessor.py b/docling/utils/layout_postprocessor.py index daa47469..802e7f82 100644 --- a/docling/utils/layout_postprocessor.py +++ b/docling/utils/layout_postprocessor.py @@ -1,5 +1,6 @@ import bisect import logging +import sys from collections import defaultdict from typing import Dict, List, Set, Tuple @@ -279,7 +280,16 @@ class LayoutPostprocessor: contained.append(cluster) if contained: + # Sort contained clusters by minimum cell ID + contained.sort( + key=lambda cluster: ( + min(cell.id for cell in cluster.cells) + if cluster.cells + else sys.maxsize + ) + ) special.children = contained + # Adjust bbox only for wrapper types if special.label in self.WRAPPER_TYPES: special.bbox = BoundingBox(