mirror of
https://github.com/DS4SD/docling.git
synced 2025-12-08 12:48:28 +00:00
Add monkey patch to fix KeyError in reading order model
Co-authored-by: cau-git <60343111+cau-git@users.noreply.github.com>
This commit is contained in:
112
docling/models/_reading_order_patch.py
Normal file
112
docling/models/_reading_order_patch.py
Normal file
@@ -0,0 +1,112 @@
|
|||||||
|
"""
|
||||||
|
Monkey patch for docling_ibm_models.reading_order.reading_order_rb module.
|
||||||
|
|
||||||
|
This module patches the _init_ud_maps method to add defensive checks
|
||||||
|
and prevent KeyError when accessing dn_map and up_map dictionaries.
|
||||||
|
|
||||||
|
The issue occurs when following the l2r_map chain results in an index
|
||||||
|
that doesn't exist in the dn_map dictionary. This can happen when
|
||||||
|
maps are reinitialized with different element lists.
|
||||||
|
|
||||||
|
This patch will be applied when the readingorder_model is imported.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, List
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from docling_ibm_models.reading_order.reading_order_rb import (
|
||||||
|
PageElement,
|
||||||
|
ReadingOrderPredictor,
|
||||||
|
_ReadingOrderPredictorState,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _patched_init_ud_maps(
|
||||||
|
self: "ReadingOrderPredictor",
|
||||||
|
page_elems: List["PageElement"],
|
||||||
|
state: "_ReadingOrderPredictorState",
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Patched version of _init_ud_maps with defensive checks.
|
||||||
|
|
||||||
|
Initialize up/down maps for reading order prediction using R-tree spatial indexing.
|
||||||
|
|
||||||
|
Uses R-tree for spatial queries.
|
||||||
|
Determines linear reading sequence by finding preceding/following elements.
|
||||||
|
"""
|
||||||
|
from rtree import index as rtree_index
|
||||||
|
|
||||||
|
state.up_map = {}
|
||||||
|
state.dn_map = {}
|
||||||
|
|
||||||
|
for i, pelem_i in enumerate(page_elems):
|
||||||
|
state.up_map[i] = []
|
||||||
|
state.dn_map[i] = []
|
||||||
|
|
||||||
|
# Build R-tree spatial index
|
||||||
|
spatial_idx = rtree_index.Index()
|
||||||
|
for i, pelem in enumerate(page_elems):
|
||||||
|
spatial_idx.insert(i, (pelem.l, pelem.b, pelem.r, pelem.t))
|
||||||
|
|
||||||
|
for j, pelem_j in enumerate(page_elems):
|
||||||
|
if j in state.r2l_map:
|
||||||
|
i = state.r2l_map[j]
|
||||||
|
# Defensive check: ensure i exists in dn_map
|
||||||
|
if i in state.dn_map and j in state.up_map:
|
||||||
|
state.dn_map[i] = [j]
|
||||||
|
state.up_map[j] = [i]
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Find elements above current that might precede it in reading order
|
||||||
|
query_bbox = (pelem_j.l - 0.1, pelem_j.t, pelem_j.r + 0.1, float("inf"))
|
||||||
|
candidates = list(spatial_idx.intersection(query_bbox))
|
||||||
|
|
||||||
|
for i in candidates:
|
||||||
|
if i == j:
|
||||||
|
continue
|
||||||
|
|
||||||
|
pelem_i = page_elems[i]
|
||||||
|
|
||||||
|
# Check spatial relationship
|
||||||
|
if not (
|
||||||
|
pelem_i.is_strictly_above(pelem_j)
|
||||||
|
and pelem_i.overlaps_horizontally(pelem_j)
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Check for interrupting elements
|
||||||
|
if not self._has_sequence_interruption(
|
||||||
|
spatial_idx, page_elems, i, j, pelem_i, pelem_j
|
||||||
|
):
|
||||||
|
# Follow left-to-right mapping
|
||||||
|
original_i = i
|
||||||
|
while i in state.l2r_map:
|
||||||
|
i = state.l2r_map[i]
|
||||||
|
# Defensive check: prevent infinite loops
|
||||||
|
if i == original_i:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Defensive check: ensure i and j exist in the maps before accessing
|
||||||
|
if i in state.dn_map and j in state.up_map:
|
||||||
|
state.dn_map[i].append(j)
|
||||||
|
state.up_map[j].append(i)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_patch() -> None:
|
||||||
|
"""Apply the monkey patch to ReadingOrderPredictor._init_ud_maps."""
|
||||||
|
try:
|
||||||
|
from docling_ibm_models.reading_order.reading_order_rb import (
|
||||||
|
ReadingOrderPredictor,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Store original method for reference
|
||||||
|
if not hasattr(ReadingOrderPredictor, "_original_init_ud_maps"):
|
||||||
|
ReadingOrderPredictor._original_init_ud_maps = ( # type: ignore
|
||||||
|
ReadingOrderPredictor._init_ud_maps
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply the patch
|
||||||
|
ReadingOrderPredictor._init_ud_maps = _patched_init_ud_maps # type: ignore
|
||||||
|
except ImportError:
|
||||||
|
# If docling_ibm_models is not installed, silently skip the patch
|
||||||
|
pass
|
||||||
@@ -30,8 +30,12 @@ from docling.datamodel.base_models import (
|
|||||||
TextElement,
|
TextElement,
|
||||||
)
|
)
|
||||||
from docling.datamodel.document import ConversionResult
|
from docling.datamodel.document import ConversionResult
|
||||||
|
from docling.models import _reading_order_patch
|
||||||
from docling.utils.profiling import ProfilingScope, TimeRecorder
|
from docling.utils.profiling import ProfilingScope, TimeRecorder
|
||||||
|
|
||||||
|
# Apply patch to fix KeyError in reading order prediction
|
||||||
|
_reading_order_patch.apply_patch()
|
||||||
|
|
||||||
|
|
||||||
class ReadingOrderOptions(BaseModel):
|
class ReadingOrderOptions(BaseModel):
|
||||||
model_config = ConfigDict(protected_namespaces=())
|
model_config = ConfigDict(protected_namespaces=())
|
||||||
|
|||||||
236
tests/test_reading_order_patch.py
Normal file
236
tests/test_reading_order_patch.py
Normal file
@@ -0,0 +1,236 @@
|
|||||||
|
"""
|
||||||
|
Unit tests for the reading order patch.
|
||||||
|
|
||||||
|
Tests that the monkey patch for docling_ibm_models.reading_order.reading_order_rb
|
||||||
|
is correctly applied and handles edge cases that could cause KeyError.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from docling_core.types.doc import DocItemLabel
|
||||||
|
from docling_ibm_models.reading_order.reading_order_rb import ReadingOrderPredictor
|
||||||
|
|
||||||
|
|
||||||
|
def test_reading_order_patch_applied():
|
||||||
|
"""Test that the monkey patch was successfully applied."""
|
||||||
|
# Import the readingorder_model to trigger the patch
|
||||||
|
from docling.models.readingorder_model import ReadingOrderModel # noqa: F401
|
||||||
|
|
||||||
|
# Verify the patch was applied
|
||||||
|
assert hasattr(
|
||||||
|
ReadingOrderPredictor, "_original_init_ud_maps"
|
||||||
|
), "Patch was not applied"
|
||||||
|
assert (
|
||||||
|
ReadingOrderPredictor._init_ud_maps.__name__ == "_patched_init_ud_maps"
|
||||||
|
), "Patched method name doesn't match"
|
||||||
|
|
||||||
|
|
||||||
|
def test_reading_order_model_init():
|
||||||
|
"""Test that ReadingOrderModel can be initialized with the patch."""
|
||||||
|
from docling.models.readingorder_model import ReadingOrderModel, ReadingOrderOptions
|
||||||
|
|
||||||
|
options = ReadingOrderOptions()
|
||||||
|
model = ReadingOrderModel(options)
|
||||||
|
assert model is not None
|
||||||
|
assert model.ro_model is not None
|
||||||
|
assert isinstance(model.ro_model, ReadingOrderPredictor)
|
||||||
|
|
||||||
|
|
||||||
|
def test_patched_method_defensive_checks():
|
||||||
|
"""Test that the patched method handles edge cases gracefully."""
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
|
from docling.models.readingorder_model import ReadingOrderModel, ReadingOrderOptions
|
||||||
|
from docling_core.types.doc import Size
|
||||||
|
from docling_core.types.doc.base import CoordOrigin
|
||||||
|
from docling_ibm_models.reading_order.reading_order_rb import PageElement
|
||||||
|
|
||||||
|
options = ReadingOrderOptions()
|
||||||
|
model = ReadingOrderModel(options)
|
||||||
|
|
||||||
|
# Create a simple test case with page elements
|
||||||
|
@dataclass
|
||||||
|
class MockState:
|
||||||
|
l2r_map: Dict[int, int] = field(default_factory=dict)
|
||||||
|
r2l_map: Dict[int, int] = field(default_factory=dict)
|
||||||
|
up_map: Dict[int, List[int]] = field(default_factory=dict)
|
||||||
|
dn_map: Dict[int, List[int]] = field(default_factory=dict)
|
||||||
|
|
||||||
|
state = MockState()
|
||||||
|
|
||||||
|
# Create some page elements
|
||||||
|
page_elements = [
|
||||||
|
PageElement(
|
||||||
|
cid=0,
|
||||||
|
text="Element 0",
|
||||||
|
page_no=0,
|
||||||
|
page_size=Size(width=612, height=792),
|
||||||
|
l=100,
|
||||||
|
r=200,
|
||||||
|
b=600,
|
||||||
|
t=700,
|
||||||
|
coord_origin=CoordOrigin.BOTTOMLEFT,
|
||||||
|
label=DocItemLabel.TEXT,
|
||||||
|
),
|
||||||
|
PageElement(
|
||||||
|
cid=1,
|
||||||
|
text="Element 1",
|
||||||
|
page_no=0,
|
||||||
|
page_size=Size(width=612, height=792),
|
||||||
|
l=100,
|
||||||
|
r=200,
|
||||||
|
b=500,
|
||||||
|
t=600,
|
||||||
|
coord_origin=CoordOrigin.BOTTOMLEFT,
|
||||||
|
label=DocItemLabel.TEXT,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
# Test that the patched method can handle the page elements
|
||||||
|
# without raising KeyError
|
||||||
|
try:
|
||||||
|
model.ro_model._init_ud_maps(page_elements, state)
|
||||||
|
except KeyError as e:
|
||||||
|
pytest.fail(f"Patched method raised KeyError: {e}")
|
||||||
|
|
||||||
|
# Verify that the maps were initialized
|
||||||
|
assert len(state.up_map) == 2, "up_map should have 2 entries"
|
||||||
|
assert len(state.dn_map) == 2, "dn_map should have 2 entries"
|
||||||
|
assert 0 in state.up_map, "Element 0 should be in up_map"
|
||||||
|
assert 1 in state.up_map, "Element 1 should be in up_map"
|
||||||
|
assert 0 in state.dn_map, "Element 0 should be in dn_map"
|
||||||
|
assert 1 in state.dn_map, "Element 1 should be in dn_map"
|
||||||
|
|
||||||
|
|
||||||
|
def test_patched_method_with_l2r_map():
|
||||||
|
"""Test that the patched method handles l2r_map chains correctly."""
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
|
from docling.models.readingorder_model import ReadingOrderModel, ReadingOrderOptions
|
||||||
|
from docling_core.types.doc import Size
|
||||||
|
from docling_core.types.doc.base import CoordOrigin
|
||||||
|
from docling_ibm_models.reading_order.reading_order_rb import PageElement
|
||||||
|
|
||||||
|
options = ReadingOrderOptions()
|
||||||
|
model = ReadingOrderModel(options)
|
||||||
|
|
||||||
|
# Create a simple test case with page elements
|
||||||
|
@dataclass
|
||||||
|
class MockState:
|
||||||
|
l2r_map: Dict[int, int] = field(default_factory=dict)
|
||||||
|
r2l_map: Dict[int, int] = field(default_factory=dict)
|
||||||
|
up_map: Dict[int, List[int]] = field(default_factory=dict)
|
||||||
|
dn_map: Dict[int, List[int]] = field(default_factory=dict)
|
||||||
|
|
||||||
|
state = MockState()
|
||||||
|
|
||||||
|
# Create page elements
|
||||||
|
page_elements = [
|
||||||
|
PageElement(
|
||||||
|
cid=0,
|
||||||
|
text="Element 0",
|
||||||
|
page_no=0,
|
||||||
|
page_size=Size(width=612, height=792),
|
||||||
|
l=100,
|
||||||
|
r=200,
|
||||||
|
b=600,
|
||||||
|
t=700,
|
||||||
|
coord_origin=CoordOrigin.BOTTOMLEFT,
|
||||||
|
label=DocItemLabel.TEXT,
|
||||||
|
),
|
||||||
|
PageElement(
|
||||||
|
cid=1,
|
||||||
|
text="Element 1",
|
||||||
|
page_no=0,
|
||||||
|
page_size=Size(width=612, height=792),
|
||||||
|
l=250,
|
||||||
|
r=350,
|
||||||
|
b=600,
|
||||||
|
t=700,
|
||||||
|
coord_origin=CoordOrigin.BOTTOMLEFT,
|
||||||
|
label=DocItemLabel.TEXT,
|
||||||
|
),
|
||||||
|
PageElement(
|
||||||
|
cid=2,
|
||||||
|
text="Element 2",
|
||||||
|
page_no=0,
|
||||||
|
page_size=Size(width=612, height=792),
|
||||||
|
l=150,
|
||||||
|
r=250,
|
||||||
|
b=500,
|
||||||
|
t=600,
|
||||||
|
coord_origin=CoordOrigin.BOTTOMLEFT,
|
||||||
|
label=DocItemLabel.TEXT,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
# Set up l2r_map with a chain: 0 -> 1
|
||||||
|
# Note: This is normally set by _init_l2r_map but we set it manually for testing
|
||||||
|
state.l2r_map = {0: 1}
|
||||||
|
|
||||||
|
# Test that the patched method can handle l2r_map chains
|
||||||
|
# without raising KeyError even if the chain points to indices
|
||||||
|
try:
|
||||||
|
model.ro_model._init_ud_maps(page_elements, state)
|
||||||
|
except KeyError as e:
|
||||||
|
pytest.fail(f"Patched method raised KeyError with l2r_map: {e}")
|
||||||
|
|
||||||
|
# Verify maps were initialized
|
||||||
|
assert len(state.up_map) == 3, "up_map should have 3 entries"
|
||||||
|
assert len(state.dn_map) == 3, "dn_map should have 3 entries"
|
||||||
|
|
||||||
|
|
||||||
|
def test_patched_method_with_invalid_r2l_map():
|
||||||
|
"""Test that the patched method handles invalid r2l_map gracefully."""
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
|
from docling.models.readingorder_model import ReadingOrderModel, ReadingOrderOptions
|
||||||
|
from docling_core.types.doc import Size
|
||||||
|
from docling_core.types.doc.base import CoordOrigin
|
||||||
|
from docling_ibm_models.reading_order.reading_order_rb import PageElement
|
||||||
|
|
||||||
|
options = ReadingOrderOptions()
|
||||||
|
model = ReadingOrderModel(options)
|
||||||
|
|
||||||
|
# Create a simple test case with page elements
|
||||||
|
@dataclass
|
||||||
|
class MockState:
|
||||||
|
l2r_map: Dict[int, int] = field(default_factory=dict)
|
||||||
|
r2l_map: Dict[int, int] = field(default_factory=dict)
|
||||||
|
up_map: Dict[int, List[int]] = field(default_factory=dict)
|
||||||
|
dn_map: Dict[int, List[int]] = field(default_factory=dict)
|
||||||
|
|
||||||
|
state = MockState()
|
||||||
|
|
||||||
|
# Create page elements
|
||||||
|
page_elements = [
|
||||||
|
PageElement(
|
||||||
|
cid=0,
|
||||||
|
text="Element 0",
|
||||||
|
page_no=0,
|
||||||
|
page_size=Size(width=612, height=792),
|
||||||
|
l=100,
|
||||||
|
r=200,
|
||||||
|
b=600,
|
||||||
|
t=700,
|
||||||
|
coord_origin=CoordOrigin.BOTTOMLEFT,
|
||||||
|
label=DocItemLabel.TEXT,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
# Set up r2l_map with an invalid reference (index 99 doesn't exist)
|
||||||
|
# This simulates the edge case that could cause KeyError
|
||||||
|
state.r2l_map = {0: 99}
|
||||||
|
|
||||||
|
# Test that the patched method handles invalid r2l_map gracefully
|
||||||
|
try:
|
||||||
|
model.ro_model._init_ud_maps(page_elements, state)
|
||||||
|
except KeyError as e:
|
||||||
|
pytest.fail(f"Patched method raised KeyError with invalid r2l_map: {e}")
|
||||||
|
|
||||||
|
# The patched method should skip the invalid mapping
|
||||||
|
# and still initialize the maps correctly
|
||||||
|
assert len(state.up_map) == 1, "up_map should have 1 entry"
|
||||||
|
assert len(state.dn_map) == 1, "dn_map should have 1 entry"
|
||||||
Reference in New Issue
Block a user