# OCRFlux / eval / eval_element_merge_detect.py
# mirnaresearch: Initial commit for HF Space (no images)
# ca5b08e
import os
import json
import argparse
import nltk
from tqdm import tqdm
from eval.parallel import parallel_process
def evaluate(pred, gt):
    """Score one sample: 1 if the predicted items match the ground truth, else 0.

    Both ``pred`` and ``gt`` are sequences of indexable items. Each sequence
    is ordered by the first two elements of its items, and the two ordered
    lists are then compared for equality, so the order in which the items
    were originally listed does not affect the result.
    """

    def first_two(item):
        # Sort key: the leading two elements of an item.
        return item[0], item[1]

    ordered_pred = sorted(pred, key=first_two)
    ordered_gt = sorted(gt, key=first_two)
    return int(ordered_pred == ordered_gt)
def _safe_div(numerator, denominator):
    """Return ``numerator / denominator``, or 0.0 when the denominator is zero."""
    return numerator / denominator if denominator else 0.0


def _summarize(tp, fp, fn, correct, total):
    """Return ``(precision, recall, f1, accuracy)`` for one group of samples."""
    precision = _safe_div(tp, tp + fp)
    recall = _safe_div(tp, tp + fn)
    f1 = _safe_div(2 * precision * recall, precision + recall)
    accuracy = _safe_div(correct, total)
    return precision, recall, f1, accuracy


def main():
    """Evaluate element-merge-detection predictions against a ground-truth file.

    Command line:
        workspace   directory whose ``results`` sub-folder holds ``*.jsonl``
                    prediction files; each line must provide ``orig_path``
                    and ``merge_pairs``.
        --gt_file   JSONL ground truth; each line must provide ``pdf_name_1``,
                    ``pdf_name_2``, ``merging_idx_pairs`` and ``language``.
        --n_jobs    1 scores samples in-process, any other value delegates to
                    ``parallel_process``.

    For every ground-truth sample the file name and its 0/1 score are printed,
    followed by precision / recall / F1 / accuracy for English samples,
    for all other samples (reported as ZH), and for both combined.
    A ratio whose denominator is zero is reported as 0.0.
    """
    parser = argparse.ArgumentParser(description="Evaluate element_merge_detect task")
    parser.add_argument(
        "workspace",
        help="The filesystem path where work will be stored, can be a local folder",
    )
    parser.add_argument(
        "--gt_file",
        required=True,  # without it open() below would fail with an opaque TypeError
        help="Ground truth file",
    )
    parser.add_argument("--n_jobs", type=int, default=40, help="Number of jobs to run in parallel")
    args = parser.parse_args()

    # Predictions, keyed by the base name of the file they were produced for.
    pred_data = {}
    root_dir = os.path.join(args.workspace, "results")
    # Sorted so that, if two files carry the same key, the winner is deterministic.
    for jsonl_file in sorted(os.listdir(root_dir)):
        if not jsonl_file.endswith(".jsonl"):
            continue
        with open(os.path.join(root_dir, jsonl_file), "r", encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue  # tolerate blank lines in the JSONL file
                data = json.loads(line)
                pred_data[os.path.basename(data['orig_path'])] = data['merge_pairs']

    # Ground truth, keyed the same way, plus the language group of each sample.
    filenames_en = set()
    filenames_zh = set()
    gt_data = {}
    with open(args.gt_file, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            data = json.loads(line)
            pdf_name_1 = data['pdf_name_1'].split(".")[0]
            pdf_name_2 = data['pdf_name_2'].split(".")[0]
            # Split on the LAST underscore only, so a PDF stem that itself
            # contains underscores no longer breaks the two-way unpacking.
            pdf_name, page_1 = pdf_name_1.rsplit('_', 1)
            pdf_name, page_2 = pdf_name_2.rsplit('_', 1)
            json_name = pdf_name + '_' + page_1 + '_' + page_2 + '.json'
            gt_data[json_name] = data['merging_idx_pairs']
            if data['language'] == 'en':
                filenames_en.add(json_name)
            else:
                filenames_zh.add(json_name)

    keys = list(gt_data)
    # A sample without a prediction is scored as if nothing had been predicted.
    if args.n_jobs == 1:
        scores = [evaluate(pred_data.get(filename, []), gt_data[filename]) for filename in tqdm(keys)]
    else:
        inputs = [{'pred': pred_data.get(filename, []), 'gt': gt_data[filename]} for filename in keys]
        scores = parallel_process(inputs, evaluate, use_kwargs=True, n_jobs=args.n_jobs, front_num=1)

    # "tn" is tallied for completeness; none of the reported metrics use it.
    counters = {
        group: {"tp": 0, "tn": 0, "fp": 0, "fn": 0, "score": 0, "num": 0}
        for group in ("en", "zh")
    }
    for filename, score in zip(keys, scores):
        print(filename)
        print(score)
        print()
        # Same default as used for scoring; plain indexing raised KeyError here.
        pred_label = pred_data.get(filename, [])
        if filename in filenames_en:
            group = counters["en"]
        elif filename in filenames_zh:
            group = counters["zh"]
        else:
            continue
        # Empty prediction: correct -> TN, wrong -> FN.
        # Non-empty prediction: correct -> TP, wrong -> FP.
        if pred_label == []:
            outcome = "tn" if score == 1 else "fn"
        else:
            outcome = "tp" if score == 1 else "fp"
        group[outcome] += 1
        group["score"] += score
        group["num"] += 1

    en = counters["en"]
    zh = counters["zh"]
    overall = {name: en[name] + zh[name] for name in en}
    for label, group in (("EN", en), ("ZH", zh), ("ALL", overall)):
        precision, recall, f1, acc = _summarize(
            group["tp"], group["fp"], group["fn"], group["score"], group["num"]
        )
        print(f"{label}: {precision} / {recall} / {f1} / {acc}")
# Script entry point: run the evaluation only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()