added verification of input cells

Signed-off-by: Peter Staar <taa@zurich.ibm.com>
This commit is contained in:
Peter Staar 2024-08-28 09:09:06 +02:00
parent 3dbd6781df
commit 0d4fd90036

View File

@ -19,13 +19,31 @@ def get_pdf_paths():
return pdf_files
def verify_json(doc_pred_json, doc_true_json):
def verify_cells(doc_pred_json, doc_true_json):
    """Check that the input cells of a predicted document match the reference.

    Cells are compared page by page and position by position: the ``text``
    values must be equal, and each bounding-box side (``t``, ``b``, ``l``,
    ``r``) must agree after rounding with ``np.round``.

    Args:
        doc_pred_json: Predicted document as a dict. Read at
            ``["input"]["pages"][i]["cells"][j]`` and at ``["output"]``.
        doc_true_json: Reference document, read at the same locations.

    Returns:
        ``False`` when the top-level keys or the keys under ``"output"``
        differ between the two documents; ``True`` when every cell matches.

    Raises:
        AssertionError: If the number of pages, the number of cells on a
            page, a cell text, or a rounded bbox coordinate differs.
    """
    if doc_pred_json.keys() != doc_true_json.keys():
        return False

    pred_pages = doc_pred_json["input"]["pages"]
    true_pages = doc_true_json["input"]["pages"]
    assert len(pred_pages) == len(true_pages), \
        "pred- and true-doc do not have the same number of pages"

    if doc_pred_json["output"].keys() != doc_true_json["output"].keys():
        return False

    for pid, (page_pred_item, page_true_item) in enumerate(zip(pred_pages, true_pages)):
        pred_cells = page_pred_item["cells"]
        true_cells = page_true_item["cells"]

        # Without this check a short prediction fails with an opaque
        # IndexError and surplus predicted cells are never looked at.
        assert len(pred_cells) == len(true_cells), \
            f"page {pid}: pred- and true-doc do not have the same number of cells"

        for cell_pred_item, cell_true_item in zip(pred_cells, true_cells):
            true_text = cell_true_item["text"]
            pred_text = cell_pred_item["text"]
            assert true_text == pred_text, f"{true_text}!={pred_text}"

            # Coordinates are compared after rounding, so differences that
            # vanish under np.round do not count as a mismatch.
            for side in ("t", "b", "l", "r"):
                true_val = np.round(cell_true_item["bbox"][side])
                pred_val = np.round(cell_pred_item["bbox"][side])
                assert pred_val == true_val, \
                    f"bbox for {side} is not the same: {true_val} != {pred_val}"

    return True
def verify_maintext(doc_pred_json, doc_true_json):
for l, true_item in enumerate(doc_true_json["output"]["main_text"]):
if "text" in true_item:
@ -35,6 +53,8 @@ def verify_json(doc_pred_json, doc_true_json):
assert "text" in pred_item, f"`text` is in {pred_item}"
assert true_item["text"] == pred_item["text"]
def verify_tables(doc_pred_json, doc_true_json):
for l, true_item in enumerate(doc_true_json["output"]["tables"]):
if "data" in true_item:
@ -63,6 +83,20 @@ def verify_json(doc_pred_json, doc_true_json):
return True
def verify_json(doc_pred_json, doc_true_json):
    """Compare a predicted document dict against its reference dict.

    Returns ``False`` as soon as the top-level keys, or the keys found under
    ``"output"``, differ between the two documents. Otherwise the results of
    ``verify_maintext`` and ``verify_tables`` are asserted in that order and
    ``True`` is returned.
    """
    top_level_differs = doc_pred_json.keys() != doc_true_json.keys()
    if top_level_differs:
        return False

    pred_output_keys = doc_pred_json["output"].keys()
    true_output_keys = doc_true_json["output"].keys()
    if pred_output_keys != true_output_keys:
        return False

    maintext_ok = verify_maintext(doc_pred_json, doc_true_json)
    assert maintext_ok, "verify_maintext(doc_pred_json, doc_true_json)"

    tables_ok = verify_tables(doc_pred_json, doc_true_json)
    assert tables_ok, "verify_tables(doc_pred_json, doc_true_json)"

    return True
def verify_md(doc_pred_md, doc_true_md):
    """Return the result of comparing the predicted and reference markdown with ``==``."""
    is_identical = doc_pred_md == doc_true_md
    return is_identical