diff --git a/tests/test_toplevel_functions.py b/tests/test_toplevel_functions.py index 0196227b..8254914a 100644 --- a/tests/test_toplevel_functions.py +++ b/tests/test_toplevel_functions.py @@ -19,13 +19,31 @@ def get_pdf_paths(): return pdf_files -def verify_json(doc_pred_json, doc_true_json): +def verify_cells(doc_pred_json, doc_true_json): - if doc_pred_json.keys() != doc_true_json.keys(): - return False + assert len(doc_pred_json["input"]["pages"])==len(doc_true_json["input"]["pages"]), \ + "pred- and true-doc do not have the same number of pages" + + for pid, page_true_item in enumerate(doc_true_json["input"]["pages"]): + + for cid, cell_true_item in enumerate(page_true_json["cells"]): - if doc_pred_json["output"].keys() != doc_true_json["output"].keys(): - return False + cell_pred_item = doc_pred_json["input"]["pages"][pid]["cells"][cid] + + true_text = cell_true_item["text"] + pred_text = cell_pred_item["text"] + + assert true_text==pred_text, f"{true_text}!={pred_text}" + + for _ in ["t", "b", "l", "r"]: + true_val = np.round(cell_true_item["bbox"][_]) + pred_val = np.round(cell_pred_item["bbox"][_]) + + assert pred_val==true_val, f"bbox for {_} is not the same: {true_val} != {pred_val}" + + return True + +def verify_maintext(doc_pred_json, doc_true_json): for l, true_item in enumerate(doc_true_json["output"]["main_text"]): if "text" in true_item: @@ -35,6 +53,8 @@ def verify_json(doc_pred_json, doc_true_json): assert "text" in pred_item, f"`text` is in {pred_item}" assert true_item["text"] == pred_item["text"] +def verify_tables(doc_pred_json, doc_true_json): + for l, true_item in enumerate(doc_true_json["output"]["tables"]): if "data" in true_item: @@ -62,6 +82,20 @@ def verify_json(doc_pred_json, doc_true_json): return True + +def verify_json(doc_pred_json, doc_true_json): + + if doc_pred_json.keys() != doc_true_json.keys(): + return False + + if doc_pred_json["output"].keys() != doc_true_json["output"].keys(): + return False + + assert verify_maintext(doc_pred_json, doc_true_json), "verify_maintext(doc_pred_json, doc_true_json)" + + assert verify_tables(doc_pred_json, doc_true_json), "verify_tables(doc_pred_json, doc_true_json)" + + return True def verify_md(doc_pred_md, doc_true_md): return doc_pred_md == doc_true_md