Spaces:
Running
Running
File size: 4,187 Bytes
ca5b08e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
import os
import json
import argparse
import nltk
from tqdm import tqdm
from eval.parallel import parallel_process
def evaluate(pred, gt):
    """Score one file: exact match between prediction and ground truth.

    Both pair lists are canonicalized by sorting on their first two fields,
    so the comparison is insensitive to pair ordering.

    Returns 1 on an exact match, 0 otherwise.
    """
    canonical_pred = sorted(pred, key=lambda pair: (pair[0], pair[1]))
    canonical_gt = sorted(gt, key=lambda pair: (pair[0], pair[1]))
    return int(canonical_pred == canonical_gt)
def _load_predictions(root_dir):
    """Read every ``*.jsonl`` file under *root_dir* and return a dict mapping
    ``basename(orig_path)`` -> predicted ``merge_pairs``."""
    pred_data = {}
    for jsonl_file in os.listdir(root_dir):
        if jsonl_file.endswith(".jsonl"):
            with open(os.path.join(root_dir, jsonl_file), "r") as f:
                for line in f:
                    data = json.loads(line)
                    pred_data[os.path.basename(data['orig_path'])] = data['merge_pairs']
    return pred_data


def _load_ground_truth(gt_file):
    """Parse the ground-truth JSONL file.

    Returns ``(gt_data, filenames_en, filenames_zh)`` where *gt_data* maps the
    derived ``<pdf>_<page1>_<page2>.json`` name to its ``merging_idx_pairs``,
    and the two lists partition those names by the record's ``language``.
    """
    gt_data = {}
    filenames_en = []
    filenames_zh = []
    with open(gt_file, "r") as f:
        for line in f:
            data = json.loads(line)
            base_1 = data['pdf_name_1'].split(".")[0]
            base_2 = data['pdf_name_2'].split(".")[0]
            # rsplit keeps pdf names that themselves contain '_' intact;
            # plain split('_') would raise ValueError on them.
            pdf_name, page_1 = base_1.rsplit('_', 1)
            _, page_2 = base_2.rsplit('_', 1)
            json_name = f"{pdf_name}_{page_1}_{page_2}.json"
            gt_data[json_name] = data['merging_idx_pairs']
            if data['language'] == 'en':
                filenames_en.append(json_name)
            else:
                filenames_zh.append(json_name)
    return gt_data, filenames_en, filenames_zh


def _metrics(tp, fp, fn, score, num):
    """Return ``(precision, recall, f1, accuracy)`` with zero-denominator
    guards, so an empty bucket reports 0.0 instead of raising."""
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0
    acc = score / num if num else 0.0
    return precision, recall, f1, acc


def main():
    """Evaluate the element_merge_detect task.

    Loads predictions from ``<workspace>/results/*.jsonl`` and ground truth
    from ``--gt_file``, scores each file with :func:`evaluate`, then prints
    precision / recall / F1 / accuracy for English, Chinese, and combined.
    """
    parser = argparse.ArgumentParser(description="Evaluate element_merge_detect task")
    parser.add_argument(
        "workspace",
        help="The filesystem path where work will be stored, can be a local folder",
    )
    parser.add_argument(
        "--gt_file",
        help="Ground truth file",
    )
    parser.add_argument("--n_jobs", type=int, default=40, help="Number of jobs to run in parallel")
    args = parser.parse_args()

    pred_data = _load_predictions(os.path.join(args.workspace, "results"))
    gt_data, filename_list_en, filename_list_zh = _load_ground_truth(args.gt_file)

    keys = list(gt_data.keys())
    if args.n_jobs == 1:
        scores = [evaluate(pred_data.get(filename, []), gt_data.get(filename, [])) for filename in tqdm(keys)]
    else:
        inputs = [{'pred': pred_data.get(filename, []), 'gt': gt_data.get(filename, [])} for filename in keys]
        scores = parallel_process(inputs, evaluate, use_kwargs=True, n_jobs=args.n_jobs, front_num=1)

    # Sets for O(1) membership instead of scanning the lists per filename.
    en_set = set(filename_list_en)
    zh_set = set(filename_list_zh)

    tp_en = tn_en = fp_en = fn_en = 0
    tp_zh = tn_zh = fp_zh = fn_zh = 0
    score_en = score_zh = 0
    num_en = num_zh = 0
    for filename, score in zip(keys, scores):
        print(filename)
        print(score)
        print()
        # .get mirrors the scoring calls above; a missing prediction counts
        # as an empty one instead of raising KeyError.
        pred_label = pred_data.get(filename, [])
        if filename in en_set:
            if pred_label == []:
                if score == 1:
                    tn_en += 1
                else:
                    fn_en += 1
            else:
                if score == 1:
                    tp_en += 1
                else:
                    fp_en += 1
            score_en += score
            num_en += 1
        elif filename in zh_set:
            if pred_label == []:
                if score == 1:
                    tn_zh += 1
                else:
                    fn_zh += 1
            else:
                if score == 1:
                    tp_zh += 1
                else:
                    fp_zh += 1
            score_zh += score
            num_zh += 1

    precision_en, recall_en, f1_en, acc_en = _metrics(tp_en, fp_en, fn_en, score_en, num_en)
    precision_zh, recall_zh, f1_zh, acc_zh = _metrics(tp_zh, fp_zh, fn_zh, score_zh, num_zh)
    precision, recall, f1, acc = _metrics(
        tp_en + tp_zh, fp_en + fp_zh, fn_en + fn_zh,
        score_en + score_zh, num_en + num_zh,
    )

    print(f"EN: {precision_en} / {recall_en} / {f1_en} / {acc_en}")
    print(f"ZH: {precision_zh} / {recall_zh} / {f1_zh} / {acc_zh}")
    print(f"ALL: {precision} / {recall} / {f1} / {acc}")
# Script entry point; the stray " |" extraction artifact after main() was
# removed — it made this line invalid Python.
if __name__ == "__main__":
    main()