# Evaluate element_merge_detect predictions: precision / recall / F1 / accuracy per language.
import argparse
import json
import os
from collections import Counter

import nltk
from tqdm import tqdm

from eval.parallel import parallel_process
def evaluate(pred, gt):
    """Return 1 if *pred* and *gt* contain the same merge pairs, else 0.

    Both sequences are ordered by the first two elements of each item before
    comparing, so the order in which pairs are listed does not matter (ties on
    those two elements keep their original relative order).

    Args:
        pred: Predicted merge pairs; each item must support ``item[0]`` and ``item[1]``.
        gt: Ground-truth merge pairs in the same format.

    Returns:
        int: ``1`` on an exact match after sorting, ``0`` otherwise.
    """
    def by_first_two(pair):
        return pair[0], pair[1]

    return 1 if sorted(pred, key=by_first_two) == sorted(gt, key=by_first_two) else 0
def _read_jsonl(path):
    """Yield the JSON object parsed from each non-blank line of the file at *path*.

    The file is decoded as UTF-8 explicitly: the data may include zh samples,
    so the platform default encoding is not safe to rely on.
    """
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            if line.strip():  # tolerate blank/trailing lines instead of failing in json.loads
                yield json.loads(line)


def _load_predictions(results_dir):
    """Collect predicted merge pairs from every ``*.jsonl`` file in *results_dir*.

    Returns a dict mapping ``os.path.basename(record["orig_path"])`` to
    ``record["merge_pairs"]``. Files are read in sorted name order, so when a
    key appears more than once the surviving value (the last one read) is
    deterministic rather than dependent on ``os.listdir`` order.
    """
    pred_data = {}
    for jsonl_file in sorted(os.listdir(results_dir)):
        if not jsonl_file.endswith(".jsonl"):
            continue
        for record in _read_jsonl(os.path.join(results_dir, jsonl_file)):
            pred_data[os.path.basename(record["orig_path"])] = record["merge_pairs"]
    return pred_data


def _gt_key(pdf_name_1, pdf_name_2):
    """Build the ``<pdf>_<page1>_<page2>.json`` key used to match predictions.

    Each argument is expected to look like ``<pdf>_<page>.<ext>``. Everything
    from the first ``.`` onward is dropped. The page is taken after the *last*
    underscore, so PDF names that themselves contain ``_`` still parse. The PDF
    name is taken from *pdf_name_2*; the two inputs are assumed to share it.

    Raises:
        ValueError: if a name (before its first ``.``) contains no ``_``.
    """
    _, page_1 = pdf_name_1.split(".")[0].rsplit("_", 1)
    pdf_name, page_2 = pdf_name_2.split(".")[0].rsplit("_", 1)
    return f"{pdf_name}_{page_1}_{page_2}.json"


def _load_ground_truth(gt_file):
    """Parse the ground-truth JSONL file.

    Returns:
        tuple: ``(gt_data, en_keys, zh_keys)``. ``gt_data`` maps each key built
        by ``_gt_key`` to that record's ``merging_idx_pairs``. ``en_keys`` holds
        keys whose record has ``language == "en"``. ``zh_keys`` holds keys whose
        record has any other language.
    """
    gt_data = {}
    en_keys = set()
    zh_keys = set()
    for record in _read_jsonl(gt_file):
        key = _gt_key(record["pdf_name_1"], record["pdf_name_2"])
        gt_data[key] = record["merging_idx_pairs"]
        (en_keys if record["language"] == "en" else zh_keys).add(key)
    return gt_data, en_keys, zh_keys


def _safe_div(num, den):
    """Return ``num / den``, or ``0.0`` when *den* is zero (metric undefined for the bucket)."""
    return num / den if den else 0.0


def _metrics(bucket):
    """Return ``(precision, recall, f1, accuracy)`` from a Counter with tp/fp/fn/score/num counts."""
    precision = _safe_div(bucket["tp"], bucket["tp"] + bucket["fp"])
    recall = _safe_div(bucket["tp"], bucket["tp"] + bucket["fn"])
    f1 = _safe_div(2 * precision * recall, precision + recall)
    accuracy = _safe_div(bucket["score"], bucket["num"])
    return precision, recall, f1, accuracy


def main():
    """Score element_merge_detect predictions against the ground truth.

    Reads predictions from ``<workspace>/results/*.jsonl`` and ground truth
    from ``--gt_file``, then scores every ground-truth sample with ``evaluate``
    (exact match of merge pairs). Per-file scores are printed first, followed
    by ``precision / recall / f1 / accuracy`` lines for EN, ZH and ALL.

    Confusion counts per sample:
      * non-empty prediction: TP on an exact match, FP otherwise;
      * empty prediction: TN on an exact match, FN otherwise.

    A ground-truth sample with no prediction is scored as an empty prediction.
    Any metric whose denominator is zero is reported as 0.0.
    """
    parser = argparse.ArgumentParser(description="Evaluate element_merge_detect task")
    parser.add_argument(
        "workspace",
        help="The filesystem path where work will be stored, can be a local folder",
    )
    parser.add_argument(
        "--gt_file",
        required=True,  # opened unconditionally below; fail early with a usage message
        help="Ground truth file",
    )
    parser.add_argument("--n_jobs", type=int, default=40, help="Number of jobs to run in parallel")
    args = parser.parse_args()

    pred_data = _load_predictions(os.path.join(args.workspace, "results"))
    gt_data, en_keys, zh_keys = _load_ground_truth(args.gt_file)

    keys = list(gt_data)
    if args.n_jobs == 1:
        scores = [evaluate(pred_data.get(filename, []), gt_data[filename]) for filename in tqdm(keys)]
    else:
        inputs = [{"pred": pred_data.get(filename, []), "gt": gt_data[filename]} for filename in keys]
        scores = parallel_process(inputs, evaluate, use_kwargs=True, n_jobs=args.n_jobs, front_num=1)

    counts = {"en": Counter(), "zh": Counter()}
    for filename, score in zip(keys, scores):
        print(filename)
        print(score)
        print()
        # Every key is in en_keys or zh_keys; "en" takes precedence if listed under both.
        bucket = counts["en"] if filename in en_keys else counts["zh"]
        # .get: a ground-truth sample may have no prediction at all (treated as empty).
        if pred_data.get(filename, []):
            bucket["tp" if score == 1 else "fp"] += 1
        else:
            bucket["tn" if score == 1 else "fn"] += 1
        bucket["score"] += score
        bucket["num"] += 1

    # Counter addition drops zero counts, but missing keys read back as 0, so totals are exact.
    for label, bucket in (
        ("EN", counts["en"]),
        ("ZH", counts["zh"]),
        ("ALL", counts["en"] + counts["zh"]),
    ):
        precision, recall, f1, acc = _metrics(bucket)
        print(f"{label}: {precision} / {recall} / {f1} / {acc}")
# Run the evaluation only when executed as a script, not on import.
if __name__ == "__main__":
    main()