mirror of
https://github.com/DS4SD/docling.git
synced 2025-07-27 04:24:45 +00:00
feat(tests): Introduce fuzzy text comparison for OCR tests based on Levenshtein edit distance
Signed-off-by: Nikos Livathinos <nli@zurich.ibm.com>
This commit is contained in:
parent
544f298fb4
commit
49652eec54
@ -97,8 +97,5 @@ def test_e2e_conversions():
|
|||||||
doc_result=doc_result,
|
doc_result=doc_result,
|
||||||
generate=GENERATE,
|
generate=GENERATE,
|
||||||
ocr_engine=ocr_options.kind,
|
ocr_engine=ocr_options.kind,
|
||||||
|
fuzzy=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# if __name__ == "__main__":
|
|
||||||
# test_e2e_conversions()
|
|
||||||
|
@ -11,7 +11,43 @@ from docling.datamodel.base_models import ConversionStatus, Page
|
|||||||
from docling.datamodel.document import ConversionResult
|
from docling.datamodel.document import ConversionResult
|
||||||
|
|
||||||
|
|
||||||
def verify_cells(doc_pred_pages: List[Page], doc_true_pages: List[Page]):
|
def levenshtein(str1: str, str2: str) -> int:
|
||||||
|
|
||||||
|
# Ensure str1 is the shorter string to optimize memory usage
|
||||||
|
if len(str1) > len(str2):
|
||||||
|
str1, str2 = str2, str1
|
||||||
|
|
||||||
|
# Previous and current row buffers
|
||||||
|
previous_row = list(range(len(str2) + 1))
|
||||||
|
current_row = [0] * (len(str2) + 1)
|
||||||
|
|
||||||
|
# Compute the Levenshtein distance row by row
|
||||||
|
for i, c1 in enumerate(str1, start=1):
|
||||||
|
current_row[0] = i
|
||||||
|
for j, c2 in enumerate(str2, start=1):
|
||||||
|
insertions = previous_row[j] + 1
|
||||||
|
deletions = current_row[j - 1] + 1
|
||||||
|
substitutions = previous_row[j - 1] + (c1 != c2)
|
||||||
|
current_row[j] = min(insertions, deletions, substitutions)
|
||||||
|
# Swap rows for the next iteration
|
||||||
|
previous_row, current_row = current_row, previous_row
|
||||||
|
|
||||||
|
# The result is in the last element of the previous row
|
||||||
|
return previous_row[-1]
|
||||||
|
|
||||||
|
|
||||||
|
def verify_text(gt: str, pred: str, fuzzy: bool, fuzzy_threshold: float = 0.4):
|
||||||
|
|
||||||
|
if len(gt) == 0 or not fuzzy:
|
||||||
|
assert gt == pred, f"{gt}!={pred}"
|
||||||
|
else:
|
||||||
|
dist = levenshtein(gt, pred)
|
||||||
|
diff = dist / len(gt)
|
||||||
|
assert diff < fuzzy_threshold, f"{gt}!~{pred}"
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def verify_cells(doc_pred_pages: List[Page], doc_true_pages: List[Page], fuzzy: bool):
|
||||||
|
|
||||||
assert len(doc_pred_pages) == len(
|
assert len(doc_pred_pages) == len(
|
||||||
doc_true_pages
|
doc_true_pages
|
||||||
@ -32,8 +68,7 @@ def verify_cells(doc_pred_pages: List[Page], doc_true_pages: List[Page]):
|
|||||||
|
|
||||||
true_text = cell_true_item.text
|
true_text = cell_true_item.text
|
||||||
pred_text = cell_pred_item.text
|
pred_text = cell_pred_item.text
|
||||||
|
verify_text(true_text, pred_text, fuzzy)
|
||||||
assert true_text == pred_text, f"{true_text}!={pred_text}"
|
|
||||||
|
|
||||||
true_bbox = cell_true_item.bbox.as_tuple()
|
true_bbox = cell_true_item.bbox.as_tuple()
|
||||||
pred_bbox = cell_pred_item.bbox.as_tuple()
|
pred_bbox = cell_pred_item.bbox.as_tuple()
|
||||||
@ -69,7 +104,7 @@ def verify_maintext(doc_pred: DsDocument, doc_true: DsDocument):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def verify_tables(doc_pred: DsDocument, doc_true: DsDocument):
|
def verify_tables(doc_pred: DsDocument, doc_true: DsDocument, fuzzy: bool):
|
||||||
if doc_true.tables is None:
|
if doc_true.tables is None:
|
||||||
# No tables to check
|
# No tables to check
|
||||||
assert doc_pred.tables is None, "not expecting any table on this document"
|
assert doc_pred.tables is None, "not expecting any table on this document"
|
||||||
@ -102,9 +137,7 @@ def verify_tables(doc_pred: DsDocument, doc_true: DsDocument):
|
|||||||
# print("pred: ", pred_item.data[i][j].text)
|
# print("pred: ", pred_item.data[i][j].text)
|
||||||
# print("")
|
# print("")
|
||||||
|
|
||||||
assert (
|
verify_text(true_item.data[i][j].text, pred_item.data[i][j].text, fuzzy)
|
||||||
true_item.data[i][j].text == pred_item.data[i][j].text
|
|
||||||
), "table-cell does not have the same text"
|
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
true_item.data[i][j].obj_type == pred_item.data[i][j].obj_type
|
true_item.data[i][j].obj_type == pred_item.data[i][j].obj_type
|
||||||
@ -121,12 +154,12 @@ def verify_output(doc_pred: DsDocument, doc_true: DsDocument):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def verify_md(doc_pred_md, doc_true_md):
|
def verify_md(doc_pred_md: str, doc_true_md: str, fuzzy: bool):
|
||||||
return doc_pred_md == doc_true_md
|
return verify_text(doc_true_md, doc_pred_md, fuzzy)
|
||||||
|
|
||||||
|
|
||||||
def verify_dt(doc_pred_dt, doc_true_dt):
|
def verify_dt(doc_pred_dt: str, doc_true_dt: str, fuzzy: bool):
|
||||||
return doc_pred_dt == doc_true_dt
|
return verify_text(doc_true_dt, doc_pred_dt, fuzzy)
|
||||||
|
|
||||||
|
|
||||||
def verify_conversion_result(
|
def verify_conversion_result(
|
||||||
@ -134,6 +167,7 @@ def verify_conversion_result(
|
|||||||
doc_result: ConversionResult,
|
doc_result: ConversionResult,
|
||||||
generate=False,
|
generate=False,
|
||||||
ocr_engine=None,
|
ocr_engine=None,
|
||||||
|
fuzzy: bool = False,
|
||||||
):
|
):
|
||||||
PageList = TypeAdapter(List[Page])
|
PageList = TypeAdapter(List[Page])
|
||||||
|
|
||||||
@ -146,11 +180,6 @@ def verify_conversion_result(
|
|||||||
doc_pred_md = doc_result.render_as_markdown()
|
doc_pred_md = doc_result.render_as_markdown()
|
||||||
doc_pred_dt = doc_result.render_as_doctags()
|
doc_pred_dt = doc_result.render_as_doctags()
|
||||||
|
|
||||||
# pages_path = input_path.with_suffix(".pages.json")
|
|
||||||
# json_path = input_path.with_suffix(".json")
|
|
||||||
# md_path = input_path.with_suffix(".md")
|
|
||||||
# dt_path = input_path.with_suffix(".doctags.txt")
|
|
||||||
|
|
||||||
engine_suffix = "" if ocr_engine is None else f".{ocr_engine}"
|
engine_suffix = "" if ocr_engine is None else f".{ocr_engine}"
|
||||||
pages_path = input_path.with_suffix(f"{engine_suffix}.pages.json")
|
pages_path = input_path.with_suffix(f"{engine_suffix}.pages.json")
|
||||||
json_path = input_path.with_suffix(f"{engine_suffix}.json")
|
json_path = input_path.with_suffix(f"{engine_suffix}.json")
|
||||||
@ -183,7 +212,7 @@ def verify_conversion_result(
|
|||||||
doc_true_dt = fr.read()
|
doc_true_dt = fr.read()
|
||||||
|
|
||||||
assert verify_cells(
|
assert verify_cells(
|
||||||
doc_pred_pages, doc_true_pages
|
doc_pred_pages, doc_true_pages, fuzzy
|
||||||
), f"Mismatch in PDF cell prediction for {input_path}"
|
), f"Mismatch in PDF cell prediction for {input_path}"
|
||||||
|
|
||||||
# assert verify_output(
|
# assert verify_output(
|
||||||
@ -191,13 +220,13 @@ def verify_conversion_result(
|
|||||||
# ), f"Mismatch in JSON prediction for {input_path}"
|
# ), f"Mismatch in JSON prediction for {input_path}"
|
||||||
|
|
||||||
assert verify_tables(
|
assert verify_tables(
|
||||||
doc_pred, doc_true
|
doc_pred, doc_true, fuzzy
|
||||||
), f"verify_tables(doc_pred, doc_true) mismatch for {input_path}"
|
), f"verify_tables(doc_pred, doc_true) mismatch for {input_path}"
|
||||||
|
|
||||||
assert verify_md(
|
assert verify_md(
|
||||||
doc_pred_md, doc_true_md
|
doc_pred_md, doc_true_md, fuzzy
|
||||||
), f"Mismatch in Markdown prediction for {input_path}"
|
), f"Mismatch in Markdown prediction for {input_path}"
|
||||||
|
|
||||||
assert verify_dt(
|
assert verify_dt(
|
||||||
doc_pred_dt, doc_true_dt
|
doc_pred_dt, doc_true_dt, fuzzy
|
||||||
), f"Mismatch in DocTags prediction for {input_path}"
|
), f"Mismatch in DocTags prediction for {input_path}"
|
||||||
|
Loading…
Reference in New Issue
Block a user