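"""Evaluation script for the element_merge_detect task.

Reads predicted merge pairs from <workspace>/results/*.jsonl, compares them
against the ground-truth JSONL file with exact matching, and reports
precision / recall / F1 / accuracy for English, Chinese, and all samples.
"""
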
import os
import json
import argparse
from tqdm import tqdm
from eval.parallel import parallel_process


def evaluate(pred, gt):
    """Exact-match score: 1 if the predicted merge pairs equal the ground truth
    (order-insensitive), else 0."""
    pred = sorted(pred, key=lambda x: (x[0], x[1]))
    gt = sorted(gt, key=lambda x: (x[0], x[1]))
    return 1 if pred == gt else 0

def main():
    parser = argparse.ArgumentParser(description="Evaluate element_merge_detect task")
    parser.add_argument(
        "workspace",
        help="The filesystem path where work will be stored; can be a local folder",
    )
    parser.add_argument(
        "--gt_file",
        required=True,
        help="Ground truth JSONL file (one JSON object per line)",
    )
    parser.add_argument("--n_jobs", type=int, default=40, help="Number of jobs to run in parallel")
    args = parser.parse_args()
    
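    # Collect predictions: each line of results/*.jsonl maps a page-pair file
    # (keyed by the basename of its orig_path) to its predicted merge pairs.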
    pred_data = {}
    root_dir = os.path.join(args.workspace, "results")
    for jsonl_file in os.listdir(root_dir):
        if jsonl_file.endswith(".jsonl"):
            with open(os.path.join(root_dir, jsonl_file), "r") as f:
                for line in f:
                    data = json.loads(line)
                    pred_data[os.path.basename(data['orig_path'])] = data['merge_pairs']

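    # Load ground truth: rebuild the page-pair JSON name from the two pdf page
    # names and keep the gold merge pairs, split by language.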
    filename_list_en = []
    filename_list_zh = []
    gt_data = {}
    with open(args.gt_file, "r") as f:
        for line in f:
            data = json.loads(line)
            # Strip the extension, then split off the trailing page index with
            # rsplit so pdf names that themselves contain '_' stay intact.
            pdf_name_1 = os.path.splitext(data['pdf_name_1'])[0]
            pdf_name_2 = os.path.splitext(data['pdf_name_2'])[0]

            pdf_name, page_1 = pdf_name_1.rsplit('_', 1)
            pdf_name, page_2 = pdf_name_2.rsplit('_', 1)

            json_name = pdf_name + '_' + page_1 + '_' + page_2 + '.json'
            gt_data[json_name] = data['merging_idx_pairs']
            
            if data['language'] == 'en':
                filename_list_en.append(json_name)
            else:
                filename_list_zh.append(json_name)

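    # Score every ground-truth page pair; a file with no prediction is scored
    # against an empty pair list.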
    keys = list(gt_data.keys())
    if args.n_jobs == 1:
        scores = [evaluate(pred_data.get(filename, []), gt_data.get(filename, [])) for filename in tqdm(keys)]
    else:
        inputs = [{'pred': pred_data.get(filename, []), 'gt': gt_data.get(filename, [])} for filename in keys]
        scores = parallel_process(inputs, evaluate, use_kwargs=True, n_jobs=args.n_jobs, front_num=1)

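    # Confusion counts per language: a sample is treated as positive when the
    # model predicts at least one merge pair, negative when it predicts none.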
    tp_en = 0
    tn_en = 0
    fp_en = 0
    fn_en = 0
    tp_zh = 0
    tn_zh = 0
    fp_zh = 0
    fn_zh = 0
    score_en = 0
    score_zh = 0
    num_en = 0
    num_zh = 0
    for filename, score in zip(keys, scores):
        print(filename, score)
        pred_label = pred_data.get(filename, [])
        if filename in filename_list_en:
            if pred_label == []:
                if score == 1:
                    tn_en += 1
                else:
                    fn_en += 1
            else:
                if score == 1:
                    tp_en += 1
                else:
                    fp_en += 1 
            score_en += score
            num_en += 1
        
        elif filename in filename_list_zh:
            if pred_label == []:
                if score == 1:
                    tn_zh += 1
                else:
                    fn_zh += 1
            else:
                if score == 1:
                    tp_zh += 1
                else:
                    fp_zh += 1
            score_zh += score
            num_zh += 1

    # Guard every denominator so an empty split (e.g. no positive predictions)
    # yields 0.0 instead of raising ZeroDivisionError.
    precision_en = tp_en / (tp_en + fp_en) if (tp_en + fp_en) else 0.0
    recall_en = tp_en / (tp_en + fn_en) if (tp_en + fn_en) else 0.0
    f1_en = 2 * precision_en * recall_en / (precision_en + recall_en) if (precision_en + recall_en) else 0.0
    acc_en = score_en / num_en if num_en else 0.0

    precision_zh = tp_zh / (tp_zh + fp_zh) if (tp_zh + fp_zh) else 0.0
    recall_zh = tp_zh / (tp_zh + fn_zh) if (tp_zh + fn_zh) else 0.0
    f1_zh = 2 * precision_zh * recall_zh / (precision_zh + recall_zh) if (precision_zh + recall_zh) else 0.0
    acc_zh = score_zh / num_zh if num_zh else 0.0

    tp = tp_en + tp_zh
    fp = fp_en + fp_zh
    fn = fn_en + fn_zh
    score = score_en + score_zh
    num = num_en + num_zh

    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0
    acc = score / num if num else 0.0

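    # Each line below reports: precision / recall / f1 / accuracy.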
    print(f"EN: {precision_en} / {recall_en} / {f1_en} / {acc_en}")
    print(f"ZH: {precision_zh} / {recall_zh} / {f1_zh} / {acc_zh}")
    print(f"ALL: {precision} / {recall} / {f1} / {acc}")

if __name__ == "__main__":
    main()