WNJXYK commited on
Commit
22c93a7
·
verified ·
1 Parent(s): fba7b7c

Upload 16 files

Browse files
__init__.py ADDED
File without changes
all_exps.sh ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ for model in InternLM2-Math-Plus-7B Deepseek-Math-RL-7B InternLM2-Math-Plus-1.8B; do
2
+ for method in PPL SC RPC; do
3
+ python main.py --dataset MATH --model $model --method $method --K 64
4
+ done
5
+ for dataset in MathOdyssey AIME OlympiadBench; do
6
+ for method in PPL SC RPC; do
7
+ python main.py --dataset $dataset --model $model --method $method --K 128
8
+ done
9
+ done
10
+ done
app.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import json, os
3
+ from huggingface_hub import hf_hub_download
4
+ from compute_perp import prep_evaluator, numberic_compare, check_equal
5
+ from compute_sc import sc_evaluator
6
+ from compute_rpc import wpc_evaluator
7
+ import numpy as np
8
+
9
+ def greet(name):
10
+ return "Hello " + name + "!!"
11
+
12
+ json_file = {"predict": [], "answer": [], "completion": [], "mean_logprob": [], "prompt": []}
13
+
14
+ demo = gr.Blocks()
15
+ with demo:
16
+ paper_title = gr.HTML("""<div align='center'><h1>[NeurIPS 2025] A Theoretical Study on Bridging Internal Probability and Self-Consistency for LLM Reasoning</h1></div>""")
17
+ paper_info = gr.HTML("""<div align="center"><h3><a href="https://arxiv.org/pdf/2502.00511">📄 [Paper]</a> • <a href="https://wnjxyk.github.io/RPC">🌐 [Project]</a> • <a href="#" onclick="document.getElementById('bibtex-popup').style.display='block';">📚 [BibTeX]</a><h3>
18
+ <div id="bibtex-popup" style="display:none; position:fixed; top:50%; left:50%; transform:translate(-50%, -50%); background:white; padding:20px; border:1px solid #ccc; box-shadow:0 0 10px rgba(0,0,0,0.2); z-index:1000; max-width:80%; overflow:auto;">
19
+ <pre style="white-space:pre-wrap; font-size:12px; text-align:left;">@inproceedings{zhou24theoretical,
20
+ author = {Zhou, Zhi and Tan, Yuhao and Li, Zenan and Yao, Yuan and Guo, Lan-Zhe and Li, Yu-Feng and Ma, Xiaoxing},
21
+ title = {A Theorecial Study on Bridging Internal Probability and Self-Consistency for LLM Reasoning},
22
+ booktitle = {Advances in Neural Information Processing Systems},
23
+ year = {2025},
24
+ }</pre>
25
+ <button onclick="document.getElementById('bibtex-popup').style.display='none';" style="margin-top:10px; padding:5px 10px;">Close</button>
26
+ </div></div>""")
27
+
28
+ with gr.Column():
29
+ gr.Markdown("## 1. Experimental Settings")
30
+ with gr.Row():
31
+ dataset = gr.Dropdown(
32
+ choices=["MATH", "MathOdyssey", "AIME", "OlympiadBench"],
33
+ value="MathOdyssey",
34
+ label="Dataset",
35
+ interactive=True
36
+ )
37
+ model = gr.Dropdown(
38
+ choices=["Deepseek-Math-RL-7B", "InternLM2-Math-Plus-1.8B", "InternLM2-Math-Plus-7B"],
39
+ value="InternLM2-Math-Plus-7B",
40
+ label="Model",
41
+ interactive=True
42
+ )
43
+ k_value = gr.Dropdown(
44
+ choices=[8, 16, 32, 64, 128],
45
+ value=128,
46
+ label="K (Number of Sampled Reasoning Paths)",
47
+ interactive=True
48
+ )
49
+ seed = gr.Number(
50
+ label="Random Seed",
51
+ value=998244353,
52
+ step=1,
53
+ interactive=True
54
+ )
55
+ def update_k_value(dataset_choice):
56
+ if dataset_choice == "MATH":
57
+ return gr.update(choices=[8, 16, 32, 64], value=min(64, k_value.value))
58
+ else:
59
+ return gr.update(choices=[8, 16, 32, 64, 128], value=k_value.value)
60
+ dataset.change(fn=update_k_value, inputs=dataset, outputs=k_value)
61
+ load_btn = gr.Button("Load All Problems")
62
+
63
+ with gr.Column(visible=False) as content_column:
64
+ gr.Markdown("## 2. Problem Selection")
65
+ with gr.Group():
66
+ data_info = gr.Textbox(label="Experiment Info", value="")
67
+ problem_id = gr.Dropdown(
68
+ choices=[1],
69
+ value=1,
70
+ label="Problem ID (We removed (1) problems that were unlikely to be answered correctly using any of the methods; (2) easy problems)",
71
+ interactive=True
72
+ )
73
+ with gr.Row():
74
+ problem_prompt = gr.Textbox(label="Problem Prompt", value="", scale=3)
75
+ problem_answer = gr.Textbox(label="Problem Answer", value="", scale=1)
76
+ def update_problem_info(problem_id):
77
+ return gr.update(value=json_file['prompt'][problem_id-1], label=f"Problem#{problem_id} Prompt"), gr.update(value=json_file['answer'][problem_id-1], label=f"Problem#{problem_id} Answer")
78
+ problem_id.change(fn=update_problem_info, inputs=problem_id, outputs=[problem_prompt, problem_answer])
79
+ run_btn = gr.Button("Run Evaluation")
80
+
81
+ with gr.Column(visible=False) as result_column:
82
+ gr.Markdown("## 3. Experiment Result")
83
+ with gr.Row():
84
+ with gr.Column():
85
+ gr.Markdown("### PPL (Internal Probability)")
86
+ ppl_result = gr.Markdown()
87
+ with gr.Column():
88
+ gr.Markdown("### SC (Self-Consistency)")
89
+ sc_result = gr.Markdown(value="")
90
+ with gr.Column():
91
+ gr.Markdown("### RPC (Ours)")
92
+ rpc_result = gr.Markdown(value="")
93
+
94
+ def get_available_problems():
95
+ global json_file
96
+ answer = np.array(json_file["accuracy"]).mean(axis=0)
97
+ # print(answer.shape)
98
+ # Select indices where the answer is greater than 0.3
99
+ available_indices = np.where((answer > 0.3) & (answer < 0.5))[0]
100
+ available_indices = available_indices + 1
101
+ # print(available_indices)
102
+ return available_indices.tolist()
103
+
104
+
105
+ def load(dataset, model, k_value, seed):
106
+ try:
107
+ repo_id = {
108
+ "MATH": "WNJXYK/MATH-Reasoning-Paths",
109
+ "MathOdyssey": "WNJXYK/MathOdyssey-Reasoning-Paths",
110
+ "AIME": "WNJXYK/AIME_1983_2024-Reasoning-Paths",
111
+ "OlympiadBench": "WNJXYK/OlympiadBench-Reasoning-Paths"
112
+ }[dataset]
113
+ filename = f"{model}.json"
114
+
115
+ yield f"Downloading sampled reasoning paths from Hugging Face {repo_id}...", gr.update(visible=False), gr.update(), gr.update()
116
+ file_path = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset")
117
+
118
+ global json_file
119
+ with open(file_path, 'r', encoding='utf-8') as f:
120
+ json_file = json.load(f)
121
+ clist = get_available_problems()
122
+ # yield "Removing downloaded file..."
123
+ # print(file_path)
124
+ os.remove(file_path)
125
+
126
+ yield "Loading complete! You can now select a problem ID.", gr.update(visible=True), gr.update(value=f"Dataset: {dataset}\tModel: {model}\tK: {k_value}\tSeed: {seed}"), gr.update(choices=clist, value=clist[0])
127
+
128
+ except Exception as e:
129
+ yield f"Error: {str(e)}"
130
+
131
+ def res_to_str(correct, answers, topk=10):
132
+ answers = sorted(answers, key=lambda x: x[1], reverse=True)
133
+ response = "| # | Answer | Probability | Correct |\n|---|--------|------------|--------|\n"
134
+ for i in range(min(len(answers), topk)):
135
+ correct_mark = "✅" if answers[i][2] else "❌"
136
+ wrapped_answer = answers[i][0] if len(answers[i][0]) <= 10 else answers[i][0][:10] + "..."
137
+ response += f"| Top-{i+1} | {wrapped_answer} | {answers[i][1]:.2f} | {correct_mark} |\n"
138
+ return response
139
+
140
+ def evaluate(problem_id):
141
+ ppl_correct, ppl_answers = prep_evaluator(
142
+ json_file["predict"][problem_id-1],
143
+ json_file["completion"][problem_id-1],
144
+ json_file["mean_logprob"][problem_id-1],
145
+ json_file["answer"][problem_id-1],
146
+ numberic_compare,
147
+ check_equal
148
+ )
149
+
150
+ sc_correct, sc_answers = sc_evaluator(
151
+ json_file["predict"][problem_id-1],
152
+ json_file["completion"][problem_id-1],
153
+ json_file["mean_logprob"][problem_id-1],
154
+ json_file["answer"][problem_id-1],
155
+ numberic_compare,
156
+ check_equal
157
+ )
158
+
159
+ rpc_correct, rpc_answers = wpc_evaluator(
160
+ json_file["predict"][problem_id-1],
161
+ json_file["completion"][problem_id-1],
162
+ json_file["mean_logprob"][problem_id-1],
163
+ json_file["answer"][problem_id-1],
164
+ numberic_compare,
165
+ check_equal
166
+ )
167
+
168
+ return gr.update(visible=True), gr.update(value=res_to_str(ppl_correct, ppl_answers)), gr.update(value=res_to_str(sc_correct, sc_answers)), gr.update(value=res_to_str(rpc_correct, rpc_answers))
169
+
170
+ load_btn.click(fn=load, inputs=[dataset, model, k_value, seed],outputs=[load_btn, content_column, data_info, problem_id], show_progress="inside")
171
+ run_btn.click(fn=evaluate, inputs=problem_id, outputs=[result_column, ppl_result, sc_result, rpc_result])
172
+
173
+ if __name__ == "__main__":
174
+ demo.launch()
build_cache.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from compute_perp import check_equal
2
+ import multiprocessing, json, os, time
3
+
4
+ def solve(predict, answer):
5
+ cache_dict = {}
6
+ m = len(predict)
7
+
8
+ for i in range(m):
9
+ key = str(predict[i]) + "<##>" + str(answer)
10
+ rev_key = str(answer) + "<##>" + str(predict[i])
11
+ if key in cache_dict or rev_key in cache_dict:
12
+ continue
13
+ val = check_equal(predict[i], answer)
14
+ cache_dict[key] = val
15
+ cache_dict[rev_key] = val
16
+
17
+ for i in range(m):
18
+ for j in range(m):
19
+ key = str(predict[i]) + "<##>" + str(predict[j])
20
+ rev_key = str(predict[j]) + "<##>" + str(predict[i])
21
+ if key in cache_dict or rev_key in cache_dict:
22
+ continue
23
+ val = check_equal(predict[i], predict[j])
24
+ cache_dict[key] = val
25
+ cache_dict[rev_key] = val
26
+
27
+ return cache_dict
28
+
29
+ def cache(data, cache_path):
30
+ if os.path.exists(cache_path):
31
+ print(f"Cache file {cache_path} exists, skip!")
32
+ return
33
+ start_time = time.time()
34
+ predicts = data["predict"]
35
+ answers = data["answer"]
36
+ n = len(predicts)
37
+ cache_dict = {}
38
+ with multiprocessing.Pool() as pool:
39
+ results = pool.starmap(
40
+ solve, [(predicts[i], answers[i]) for i in range(n)]
41
+ )
42
+ for result in results:
43
+ cache_dict.update(result)
44
+ with open(cache_path, "w") as fw:
45
+ json.dump(cache_dict, fw)
46
+ print(f"Cache file {cache_path} built in {time.time() - start_time:.2f}S")
compute_perp.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import metrics
3
+ import argparse
4
+ import numpy as np
5
+ import multiprocessing
6
+ from tqdm import trange
7
+ import signal, functools
8
+ import re, os, sys, random, time
9
+ from fraction import Fraction
10
+ from data_processing.answer_extraction import *
11
+ from functools import lru_cache
12
+ from eval.eval_script import *
13
+ MAX_INT = sys.maxsize
14
+ INVALID_ANS = "[Invalid]"
15
+ INF = 1e9
16
+
17
+ __all__ = [
18
+ "check_equal",
19
+ "check_equal_without_timeout",
20
+ "numberic_compare",
21
+ "Evaluator",
22
+ ]
23
+
24
+ @lru_cache(maxsize=1000000)
25
+ def check_equal_without_timeout(ans_1, ans_2):
26
+ return math_equal(ans_1, ans_2)
27
+
28
+ def check_equal(ans_1, ans_2, cache_dict=None):
29
+ try:
30
+ if cache_dict is not None:
31
+ key = str(ans_1) + "<##>" + str(ans_2)
32
+ if key in cache_dict: return cache_dict[key]
33
+ print("Miss")
34
+ return check_equal_without_timeout(ans_1, ans_2)
35
+ except TimeoutError as e:
36
+ return False
37
+
38
+ def numberic_compare(ai, aj, ci, cj, cache_dict=None):
39
+ return check_equal(ai, aj, cache_dict)
40
+
41
+ def prep_evaluator(
42
+ predicts, completions, perplexities, answer, equal_func, check_equal
43
+ ):
44
+ m = len(predicts)
45
+
46
+ # Compute maximum probability
47
+ max_perplexity = -INF
48
+ max_perplexity_count = 0.0
49
+ for i in range(m):
50
+ if perplexities[i] > max_perplexity:
51
+ max_perplexity = perplexities[i]
52
+ max_perplexity_count = 0.0
53
+ if perplexities[i] >= max_perplexity:
54
+ max_perplexity_count += 1.0
55
+
56
+ # Compute accuracy
57
+ correct, answers = 0, []
58
+ for i in range(m):
59
+ ans_i = predicts[i]
60
+ answers.append([ans_i, np.exp(perplexities[i]), check_equal(ans_i, answer)])
61
+ if perplexities[i] < max_perplexity: continue
62
+ if check_equal(ans_i, answer):
63
+ correct += 1.0 / max_perplexity_count
64
+
65
+ return correct, answers
66
+
67
+ class Evaluator:
68
+ def __init__(self):
69
+ self.name = "Perplexity"
70
+
71
+ def process(self, json_file, cache_file, equal_func, evaluator, K, seed=0):
72
+ # with open(file_path, 'r', encoding='utf-8') as f:
73
+ # results = json.load(f)
74
+ results = json_file
75
+ n = len(results["predict"])
76
+ m = len(results["predict"][0])
77
+ indices = list(range(m))
78
+ random.seed(seed)
79
+ random.shuffle(indices)
80
+ indices = indices[: K]
81
+
82
+ if cache_file is not None:
83
+ def cache_equal_func(ai, aj, ci, cj):
84
+ return equal_func(ai, aj, ci, cj, cache_file)
85
+ def cache_check_equal(ai, aj):
86
+ return check_equal(ai, aj, cache_file)
87
+ else:
88
+ cache_equal_func = equal_func
89
+ cache_check_equal = check_equal
90
+
91
+
92
+ predicts, completions, perplexities, answers = [], [], [], []
93
+ for i in range(0, n):
94
+ predicts.append([results["predict"][i][j] for j in indices])
95
+ completions.append([results["completion"][i][j] for j in indices])
96
+ perplexities.append([results["mean_logprob"][i][j] for j in indices])
97
+ answers.append(results["answer"][i])
98
+ n = len(predicts)
99
+
100
+ start_time = time.time()
101
+ outputs = []
102
+ for idx in trange(n):
103
+ res = evaluator(
104
+ predicts[idx],
105
+ completions[idx],
106
+ perplexities[idx],
107
+ answers[idx],
108
+ cache_equal_func,
109
+ cache_check_equal,
110
+ )
111
+ outputs.append(res)
112
+ print(f"Running Time with Single Process Mode with Seed #{seed}: {time.time() - start_time:.2f}S")
113
+
114
+ for i in trange(n):
115
+ m = len(outputs[i][1])
116
+ for j in range(m):
117
+ ans, prob, flag = outputs[i][1][j]
118
+ maximum, max_bins = metrics.compute_maximum_metrics([x[1] for x in outputs])
119
+ average, avg_bins = metrics.compute_average_metrics([x[1] for x in outputs])
120
+ accs = np.mean([x[0] for x in outputs])
121
+ return accs * 100.0, maximum, average, max_bins, avg_bins
122
+
123
+ def worker(self, args):
124
+ json_file, cache_file, K, seed = args
125
+ acc, maximum, average, max_bins, avg_bins = self.process(
126
+ json_file=json_file,
127
+ cache_file=cache_file,
128
+ equal_func=numberic_compare,
129
+ evaluator=prep_evaluator,
130
+ K=K,
131
+ seed=seed
132
+ )
133
+ return acc, maximum, average
134
+
135
+ def solve(self, json_file, cache_file=None, repeats=10, K=128):
136
+ accs, maxs, avgs = [], [], []
137
+ with multiprocessing.Pool() as pool:
138
+ results = pool.map(self.worker, [(json_file, cache_file, K, seed) for seed in range(repeats)])
139
+ accs, maxs, _ = zip(*results)
140
+ accs, maxs = np.array(accs), np.array(maxs)
141
+ return {
142
+ "Accuracy": f"{accs.mean():.2f} ± {accs.std():.2f}",
143
+ "ECE": f"{maxs[:, 0].mean() * 100.0:.2f} ± {maxs[:, 0].std() * 100.0:.2f}",
144
+ }
compute_rpc.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import metrics
3
+ import argparse
4
+ import numpy as np
5
+ import multiprocessing
6
+ from tqdm import trange
7
+ import signal, functools
8
+ from scipy.special import gamma
9
+ import re, os, sys, random, time
10
+ from scipy.stats import weibull_min
11
+ from scipy.optimize import minimize
12
+ from fraction import Fraction
13
+ from data_processing.answer_extraction import *
14
+ from eval.eval_script import *
15
+ from compute_perp import Evaluator, numberic_compare
16
+ from compute_sc import DSU
17
+ MAX_INT = sys.maxsize
18
+ INVALID_ANS = "[Invalid]"
19
+
20
+ #### Reasoning Pruning Module: Model probability with Weibull distribution ####
21
+
22
+ def weibull_pdf(x, k, lam):
23
+ return (k / lam) * (x / lam) ** (k - 1) * np.exp(-((x / lam) ** k))
24
+
25
+ def weibull_mean(k, lam):
26
+ return lam * gamma(1 + 1 / k)
27
+
28
+ def mixture_pdf(x, w1, k1, lam1, k2, lam2):
29
+ return w1 * weibull_pdf(x, k1, lam1) + (1 - w1) * weibull_pdf(x, k2, lam2)
30
+
31
+ def neg_log_likelihood(params, data):
32
+ w1, k1, lam1, k2, lam2 = params
33
+ pdf_vals = mixture_pdf(data, w1, k1, lam1, k2, lam2)
34
+ return -np.sum(np.log(pdf_vals))
35
+
36
+ def calculate_membership_probabilities(data, w1, k1, lam1, k2, lam2):
37
+ pdf1 = weibull_pdf(data, k1, lam1)
38
+ pdf2 = weibull_pdf(data, k2, lam2)
39
+ prob1 = w1 * pdf1 / (w1 * pdf1 + (1 - w1) * pdf2)
40
+ prob2 = 1 - prob1
41
+ return prob1, prob2
42
+
43
+ ### Perplexity Consistency Module: Bridging the probability with self-consistency ####
44
+
45
+ def wpc_evaluator(predicts, completions, perplexities, answer, equal_func, check_equal):
46
+ m = len(predicts)
47
+ dsu = DSU(m)
48
+ probas = [np.exp(perplexities[i]) for i in range(m)]
49
+ mean_proba = np.mean(probas)
50
+
51
+ # Model probability with Weibull distribution
52
+ initial_guess = [0.5, 1.0, 1.0, 1.5, 2.0]
53
+ result = minimize(
54
+ neg_log_likelihood,
55
+ initial_guess,
56
+ args=(probas,),
57
+ bounds=[(0.2, 0.8), (0.01, None), (0.01, None), (0.01, None), (0.01, None)],
58
+ )
59
+ w1, k1, lam1, k2, lam2 = result.x
60
+ if weibull_mean(k1, lam1) < weibull_mean(k2, lam2):
61
+ k1, lam1, k2, lam2 = k2, lam2, k1, lam1
62
+ w1 = 1 - w1
63
+
64
+ # Pruning reasoning paths with low probabilities
65
+ remove = 0
66
+ for i in range(m):
67
+ completion_i = completions[i]
68
+ logprob_i = perplexities[i]
69
+ proba_i = np.exp(logprob_i)
70
+ p1, p2 = calculate_membership_probabilities(proba_i, w1, k1, lam1, k2, lam2)
71
+ if p1 < p2 and proba_i < mean_proba:
72
+ proba_i = 0
73
+ remove += 1
74
+ else:
75
+ dsu.attr[i][completion_i] = set([proba_i])
76
+
77
+ # Combining internal probabilities and self-consistency
78
+ for i in range(m):
79
+ if dsu.get_father(i) != i:
80
+ continue
81
+ for j in range(i):
82
+ ans_i = predicts[i]
83
+ ans_j = predicts[j]
84
+ completion_i = completions[i]
85
+ completion_j = completions[j]
86
+ if equal_func(ans_i, ans_j, completion_i, completion_j):
87
+ dsu.merge(i, j)
88
+
89
+ # Compute majority votes with probabilities
90
+ max_prob, max_prob_count = 0, 0
91
+ for i in range(m):
92
+ if dsu.get_father(i) != i:
93
+ continue
94
+ prob_i = np.sum([np.sum(list(dsu.attr[i][k])) for k in dsu.attr[i].keys()])
95
+ if prob_i > max_prob:
96
+ max_prob = prob_i
97
+ max_prob_count = 0
98
+ if prob_i >= max_prob:
99
+ max_prob_count += 1
100
+
101
+ # Compute accuracy
102
+ correct, answers = 0, []
103
+ for i in range(m):
104
+ if dsu.get_father(i) != i:
105
+ continue
106
+ ans_i = predicts[i]
107
+ prob_i = np.sum([np.sum(list(dsu.attr[i][k])) for k in dsu.attr[i].keys()])
108
+ answers.append([ans_i, prob_i, check_equal(ans_i, answer)])
109
+ if prob_i < max_prob:
110
+ continue
111
+ if check_equal(ans_i, answer):
112
+ correct += 1.0 / max_prob_count
113
+
114
+ # Normalize probabilities
115
+ sum_proba = np.sum([x[1] for x in answers])
116
+ for i in range(len(answers)):
117
+ answers[i][1] /= sum_proba
118
+
119
+ return correct, answers
120
+
121
+ class RPCEvaluator(Evaluator):
122
+ def __init__(self,):
123
+ self.name = "RPC"
124
+
125
+ def worker(self, args):
126
+ json_file, cache_file, K, seed = args
127
+ acc, maximum, average, max_bins, avg_bins = self.process(
128
+ json_file=json_file,
129
+ cache_file=cache_file,
130
+ equal_func=numberic_compare,
131
+ evaluator=wpc_evaluator,
132
+ K=K,
133
+ seed=seed
134
+ )
135
+ return acc, maximum, average
136
+
compute_sc.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import metrics
3
+ import argparse
4
+ import numpy as np
5
+ import multiprocessing
6
+ from tqdm import trange
7
+ import signal, functools
8
+ import re, os, sys, random, time
9
+ from fraction import Fraction
10
+ from data_processing.answer_extraction import *
11
+ from eval.eval_script import *
12
+ from compute_perp import Evaluator, numberic_compare
13
+ MAX_INT = sys.maxsize
14
+ INVALID_ANS = "[Invalid]"
15
+
16
+ __all__ = ["DSU"]
17
+
18
+ class DSU:
19
+ def __init__(self, n):
20
+ self.n = n
21
+ self.father = [i for i in range(n)]
22
+ self.size = [1 for i in range(n)]
23
+ self.attr = [{} for i in range(n)]
24
+
25
+ def get_father(self, x):
26
+ if self.father[x] == x:
27
+ return x
28
+ self.father[x] = self.get_father(self.father[x])
29
+ return self.father[x]
30
+
31
+ def merge(self, x, y):
32
+ fx = self.get_father(x)
33
+ fy = self.get_father(y)
34
+ if fx == fy:
35
+ return
36
+ self.father[fy] = fx
37
+ self.size[fx] += self.size[fy]
38
+ self.size[fy] = 0
39
+ for key in self.attr[fy].keys():
40
+ if key not in self.attr[fx]:
41
+ self.attr[fx][key] = self.attr[fy][key]
42
+ else:
43
+ self.attr[fx][key] |= self.attr[fy][key]
44
+ self.attr[fy] = {}
45
+
46
+
47
+ def sc_evaluator(predicts, completions, perplexities, answer, equal_func, check_equal):
48
+ m = len(predicts)
49
+ dsu = DSU(m)
50
+
51
+ # Merge answer for self-consistency
52
+ for i in range(m):
53
+ if dsu.get_father(i) != i:
54
+ continue
55
+ for j in range(i):
56
+ ans_i = predicts[i]
57
+ ans_j = predicts[j]
58
+ completion_i = completions[i]
59
+ completion_j = completions[j]
60
+ if equal_func(ans_i, ans_j, completion_i, completion_j):
61
+ dsu.merge(i, j)
62
+
63
+ # Compute majority votes
64
+ max_size, max_size_count = 0, 0
65
+ for i in range(m):
66
+ if dsu.get_father(i) != i:
67
+ continue
68
+ if dsu.size[i] > max_size:
69
+ max_size = dsu.size[i]
70
+ max_size_count = 0
71
+ if dsu.size[i] == max_size:
72
+ max_size_count += 1
73
+
74
+ # Compute accuracy
75
+ correct, answers = 0, []
76
+ for i in range(m):
77
+ if dsu.get_father(i) != i:
78
+ continue
79
+ ans_i = predicts[i]
80
+ answers.append([ans_i, dsu.size[i] / m, check_equal(ans_i, answer)])
81
+ if dsu.size[i] < max_size:
82
+ continue
83
+ if check_equal(ans_i, answer):
84
+ correct += 1.0 / max_size_count
85
+
86
+ # Normalize probabilities
87
+ sum_proba = np.sum([x[1] for x in answers])
88
+ for i in range(len(answers)):
89
+ answers[i][1] /= sum_proba
90
+
91
+ return correct, answers
92
+
93
+
94
+ class SCEvaluator(Evaluator):
95
+ def __init__(self):
96
+ self.name = "Self-Consistency"
97
+
98
+ def worker(self, args):
99
+ json_file, cache_file, K, seed = args
100
+ acc, maximum, average, max_bins, avg_bins = self.process(
101
+ json_file=json_file,
102
+ cache_file=cache_file,
103
+ equal_func=numberic_compare,
104
+ evaluator=sc_evaluator,
105
+ K=K,
106
+ seed=seed
107
+ )
108
+ return acc, maximum, average
data_processing/answer_extraction.py ADDED
@@ -0,0 +1,362 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import regex
3
+
4
+
5
+ def _fix_fracs(string):
6
+ substrs = string.split("\\frac")
7
+ new_str = substrs[0]
8
+ if len(substrs) > 1:
9
+ substrs = substrs[1:]
10
+ for substr in substrs:
11
+ new_str += "\\frac"
12
+ if len(substr) > 0 and substr[0] == "{":
13
+ new_str += substr
14
+ else:
15
+ try:
16
+ assert len(substr) >= 2
17
+ except:
18
+ return string
19
+ a = substr[0]
20
+ b = substr[1]
21
+ if b != "{":
22
+ if len(substr) > 2:
23
+ post_substr = substr[2:]
24
+ new_str += "{" + a + "}{" + b + "}" + post_substr
25
+ else:
26
+ new_str += "{" + a + "}{" + b + "}"
27
+ else:
28
+ if len(substr) > 2:
29
+ post_substr = substr[2:]
30
+ new_str += "{" + a + "}" + b + post_substr
31
+ else:
32
+ new_str += "{" + a + "}" + b
33
+ string = new_str
34
+ return string
35
+
36
+
37
+ def _fix_a_slash_b(string):
38
+ if len(string.split("/")) != 2:
39
+ return string
40
+ a = string.split("/")[0]
41
+ b = string.split("/")[1]
42
+ try:
43
+ if "sqrt" not in a:
44
+ a = int(a)
45
+ if "sqrt" not in b:
46
+ b = int(b)
47
+ assert string == "{}/{}".format(a, b)
48
+ new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
49
+ return new_string
50
+ except:
51
+ return string
52
+
53
+
54
+ def _fix_sqrt(string):
55
+ _string = re.sub(r"\\sqrt(-?[0-9.a-zA-Z]+)", r"\\sqrt{\1}", string)
56
+ _string = re.sub(r"\\sqrt\s+(\w+)$", r"\\sqrt{\1}", _string)
57
+ return _string
58
+
59
+
60
+ def _fix_tan(string):
61
+ _string = re.sub(r"\\tan(-?[0-9.a-zA-Z]+)", r"\\tan{\1}", string)
62
+ _string = re.sub(r"\\tan\s+(\w+)$", r"\\tan{\1}", _string)
63
+ return _string
64
+
65
+
66
+ def strip_string(string):
67
+ string = str(string).strip()
68
+ # linebreaks
69
+ string = string.replace("\n", "")
70
+
71
+ # right "."
72
+ string = string.rstrip(".")
73
+
74
+ # remove inverse spaces
75
+ string = string.replace("\\!", "")
76
+ # string = string.replace("\\ ", "")
77
+
78
+ # replace \\ with \
79
+ # string = string.replace("\\\\", "\\")
80
+ # string = string.replace("\\\\", "\\")
81
+
82
+ if string.startswith("\\text{") and string.endswith("}"):
83
+ string = string.split("{", 1)[1][:-1]
84
+
85
+ # replace tfrac and dfrac with frac
86
+ string = string.replace("tfrac", "frac")
87
+ string = string.replace("dfrac", "frac")
88
+ string = string.replace("cfrac", "frac")
89
+
90
+ # remove \left and \right
91
+ string = string.replace("\\left", "")
92
+ string = string.replace("\\right", "")
93
+
94
+ # Remove unit: miles, dollars if after is not none
95
+ _string = re.sub(r"\\text{.*?}$", "", string).strip()
96
+ if _string != "" and _string != string:
97
+ # print("Warning: unit not removed: '{}' -> '{}'".format(string, _string))
98
+ string = _string
99
+
100
+ # Remove circ (degrees)
101
+ string = string.replace("^{\\circ}", "").strip()
102
+ string = string.replace("^\\circ", "").strip()
103
+
104
+ string = regex.sub(r"\{(c|m)?m\}(\^(2|3))?", "", string).strip()
105
+ string = regex.sub(r"p\.m\.$", "", string).strip()
106
+ string = regex.sub(r"(\d)\s*t$", r"\1", string).strip()
107
+
108
+ # remove dollar signs
109
+ string = string.replace("\\$", "")
110
+ string = string.replace("$", "")
111
+
112
+ # string = string.replace("\\text", "")
113
+ string = string.replace("x\\in", "")
114
+
115
+ # remove percentage
116
+ string = string.replace("\\%", "%")
117
+ string = string.replace("\%", "%")
118
+ # string = string.replace("%", "")
119
+
120
+ # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
121
+ string = string.replace(" .", " 0.")
122
+ string = string.replace("{.", "{0.")
123
+
124
+ # cdot
125
+ string = string.replace("\\cdot", "")
126
+
127
+ # inf
128
+ string = string.replace("infinity", "\\infty")
129
+ if "\\infty" not in string:
130
+ string = string.replace("inf", "\\infty")
131
+ string = string.replace("+\\inity", "\\infty")
132
+
133
+ # and
134
+ # string = string.replace("and", "")
135
+ string = string.replace("\\mathbf", "")
136
+ string = string.replace("\\mathrm", "")
137
+
138
+ # use regex to remove \mbox{...}
139
+ string = re.sub(r"\\mbox{.*?}", "", string)
140
+
141
+ # quote
142
+ string.replace("'", "")
143
+ string.replace('"', "")
144
+
145
+ # i, j
146
+ if "j" in string and "i" not in string:
147
+ string = string.replace("j", "i")
148
+
149
+ # replace a.000b where b is not number or b is end, with ab, use regex
150
+ string = re.sub(r"(\d+)\.0+([^\d])", r"\1\2", string)
151
+ string = re.sub(r"(\d+)\.0+$", r"\1", string)
152
+
153
+ # if empty, return empty string
154
+ if len(string) == 0:
155
+ return string
156
+ if string[0] == ".":
157
+ string = "0" + string
158
+
159
+ # to consider: get rid of e.g. "k = " or "q = " at beginning
160
+ # if len(string.split("=")) == 2:
161
+ # if len(string.split("=")[0]) <= 2:
162
+ # string = string.split("=")[1]
163
+
164
+ string = _fix_sqrt(string)
165
+ string = _fix_tan(string)
166
+ string = string.replace(" ", "")
167
+
168
+ # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
169
+ string = _fix_fracs(string)
170
+
171
+ # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
172
+ string = _fix_a_slash_b(string)
173
+
174
+ string = regex.sub(r"(\\|,|\.)+$", "", string)
175
+
176
+ return string
177
+
178
+
179
+ def extract_boxed_answers(text):
180
+ answers = []
181
+ for piece in text.split("boxed{")[1:]:
182
+ n = 0
183
+ for i in range(len(piece)):
184
+ if piece[i] == "{":
185
+ n += 1
186
+ elif piece[i] == "}":
187
+ n -= 1
188
+ if n < 0:
189
+ if i + 1 < len(piece) and piece[i + 1] == "%":
190
+ answers.append(piece[: i + 1])
191
+ else:
192
+ answers.append(piece[:i])
193
+ break
194
+ return answers
195
+
196
+
197
+ def extract_program_output(pred_str):
198
+ """
199
+ extract output between the last ```output\n...\n```
200
+ """
201
+ if "```output" not in pred_str:
202
+ return ""
203
+ if "```output" in pred_str:
204
+ pred_str = pred_str.split("```output")[-1]
205
+ if "```" in pred_str:
206
+ pred_str = pred_str.split("```")[0]
207
+ output = pred_str.strip()
208
+ return output
209
+
210
+
211
+ def extract_answer(pred_str, exhaust=False):
212
+ pred = []
213
+ if "final answer is $" in pred_str and "$. I hope" in pred_str:
214
+ tmp = pred_str.split("final answer is $", 1)[1]
215
+ pred = [tmp.split("$. I hope", 1)[0].strip()]
216
+ elif "boxed" in pred_str:
217
+ pred = extract_boxed_answers(pred_str)
218
+ elif "he answer is" in pred_str:
219
+ pred = [pred_str.split("he answer is")[-1].strip()]
220
+ else:
221
+ program_output = extract_program_output(pred_str)
222
+ if program_output != "":
223
+ # fall back to program
224
+ pred.append(program_output)
225
+ else: # use the last number
226
+ pattern = "-?\d*\.?\d+"
227
+ ans = re.findall(pattern, pred_str.replace(",", ""))
228
+ if len(ans) >= 1:
229
+ ans = ans[-1]
230
+ else:
231
+ ans = ""
232
+ if ans:
233
+ pred.append(ans)
234
+
235
+ # multiple line
236
+ _pred = []
237
+ for ans in pred:
238
+ ans = ans.strip().split("\n")[0]
239
+ ans = ans.lstrip(":")
240
+ ans = ans.rstrip(".")
241
+ ans = ans.rstrip("/")
242
+ ans = strip_string(ans)
243
+ _pred.append(ans)
244
+ if exhaust:
245
+ return _pred
246
+ else:
247
+ return _pred[-1] if _pred else ""
248
+
249
+
250
+ def extract_math_answer(question, reasoning, task):
251
+ answer = []
252
+ for ans in extract_answer(reasoning, exhaust=True):
253
+ if "separated by commas" in question and all(ch not in ans for ch in "()[]"):
254
+ answer.extend([a.strip() for a in ans.split(",")])
255
+ elif regex.search(r"\\text\{\s*and\s*\}", ans):
256
+ answer.extend(
257
+ [
258
+ a.strip()
259
+ for a in regex.sub(r"\\text\{\s*and\s*\}", "[SEP]", ans).split(
260
+ "[SEP]"
261
+ )
262
+ ]
263
+ )
264
+ else:
265
+ answer.append(ans.strip())
266
+ return answer
267
+
268
+
269
+ def extract_math_few_shot_cot_answer(question, reasoning, task):
270
+ if "Problem:" in reasoning:
271
+ reasoning = reasoning.split("Problem:", 1)[0]
272
+ return extract_math_answer(question, reasoning, task)
273
+
274
+
275
+ def extract_last_single_answer(question, reasoning, task):
276
+ return extract_answer(reasoning, exhaust=False)
277
+
278
+
279
+ def extract_gsm_few_shot_cot_answer(question, reasoning, task):
280
+ if "Q: " in reasoning:
281
+ reasoning = reasoning.split("Q: ", 1)[0]
282
+ pred = [s for s in regex.findall(r"-?\d+\.?\d*", reasoning)]
283
+ if pred:
284
+ return pred[-1]
285
+ else:
286
+ return "[invalid]"
287
+
288
+
289
+ def extract_agieval_gaokao_mathcloze_few_shot_cot_test(question, reasoning, task):
290
+ if "问题 " in reasoning:
291
+ reasoning = reasoning.split("问题 ", 1)[0]
292
+ if "答案是" in reasoning:
293
+ ans = reasoning.split("答案是", 1)[1].strip()
294
+ ans = ans.split("\n")[0].strip()
295
+ ans = [ans.strip("$")]
296
+ else:
297
+ ans = ["placeholder"]
298
+ return ans
299
+
300
+
301
+ def extract_agieval_gaokao_mathqa_few_shot_cot_test(question, reasoning, task):
302
+ if "问题 " in reasoning:
303
+ reasoning = reasoning.split("问题 ", 1)[0]
304
+ if "答案是" in reasoning:
305
+ ans = reasoning.split("答案是", 1)[1].strip()
306
+ ans = ans.split("\n")[0].strip()
307
+ else:
308
+ ans = "placeholder"
309
+ return ans
310
+
311
+
312
+ def extract_sat_few_shot_answer(question, reasoning, task):
313
+ if "Problem:" in reasoning:
314
+ reasoning = reasoning.split("Problem:", 1)[0]
315
+ patt = regex.search(r"the final answer is \(?(?P<ans>[abcd])\)?", reasoning.lower())
316
+ if patt is not None:
317
+ return patt.group("ans").upper()
318
+ return "placeholder"
319
+
320
+
321
+ def extract_ocwcourses_few_shot_answer(question, reasoning, task):
322
+ if "Problem:" in reasoning:
323
+ reasoning = reasoning.split("Problem:", 1)[0]
324
+ patt = regex.search(
325
+ r"final answer is (?P<ans>.*)\. I hope it is correct.", reasoning
326
+ )
327
+ if patt is None:
328
+ pred = "[invalid]"
329
+ print(f"DEBUG >>>\n{reasoning}", flush=True)
330
+ else:
331
+ pred = patt.group("ans")
332
+ return pred
333
+
334
+
335
+ def extract_mmlu_stem(question, reasoning, task):
336
+ if "Problem:" in reasoning:
337
+ reasoning = reasoning.split("Problem:", 1)[0]
338
+ return extract_sat_few_shot_answer(question, reasoning, task)
339
+
340
+
341
+ def extract_minif2f_isabelle(question, reasoning, task):
342
+ if "Informal:" in reasoning:
343
+ reasoning = reasoning.split("Informal:", 1)[0]
344
+ return reasoning.strip()
345
+
346
+
347
+ def extract_cmath_few_shot_test(question, reasoning, task):
348
+ if "问题:" in reasoning:
349
+ reasoning = reasoning.split("问题:", 1)[0]
350
+ if "答案是" in reasoning:
351
+ ans = reasoning.split("答案是", 1)[1].strip()
352
+ ans = ans.split("\n")[0]
353
+ ans = ans.strip(":")
354
+ ans = ans.strip("。")
355
+ try:
356
+ ans = [s for s in regex.findall(r"-?\d+\.?\d*", ans)][-1]
357
+ except:
358
+ print(f"DEBUG CMATH: {reasoning}", flush=True)
359
+ ans = "[invalid]"
360
+ else:
361
+ ans = extract_last_single_answer(question, reasoning, task)
362
+ return ans
data_processing/process_utils.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import regex
2
+
3
+ from data_processing.answer_extraction import extract_math_answer, strip_string
4
+
5
+
6
+ def process_gsm8k_test(item):
7
+ sample = {
8
+ "dataset": "gsm8k-cot",
9
+ "id": item["id"],
10
+ "messages": [
11
+ {"role": "user", "content": item["question"]},
12
+ {
13
+ "role": "assistant",
14
+ "content": regex.sub(r"<<[^<>]*>>", "", item["cot"])
15
+ + "\nSo the answer is $\\boxed{"
16
+ + item["answer"].strip()
17
+ + "}$.",
18
+ },
19
+ ],
20
+ "answer": item["answer"].replace(",", ""),
21
+ }
22
+ yield sample
23
+
24
+
25
+ def process_math_test(item):
26
+ question = item["problem"]
27
+ try:
28
+ answer = extract_math_answer(question, item["solution"], task="cot")
29
+ except:
30
+ return
31
+ sample = {
32
+ "dataset": "math-cot",
33
+ "id": item["id"],
34
+ "level": item["level"],
35
+ "type": item["type"],
36
+ "category": item["category"],
37
+ "messages": [
38
+ {"role": "user", "content": question},
39
+ {
40
+ "role": "assistant",
41
+ "content": "\n".join(
42
+ regex.split(r"(?<=\.) (?=[A-Z])", item["solution"])
43
+ ),
44
+ },
45
+ ],
46
+ "answer": answer,
47
+ }
48
+ yield sample
49
+
50
+
51
+ def process_math_sat(item):
52
+ options = item["options"].strip()
53
+ assert "A" == options[0]
54
+ options = "(" + options
55
+ for ch in "BCDEFG":
56
+ if f" {ch}) " in options:
57
+ options = regex.sub(f" {ch}\) ", f" ({ch}) ", options)
58
+ question = f"{item['question'].strip()}\nWhat of the following is the right choice? Explain your answer.\n{options.strip()}"
59
+ messages = [
60
+ {"role": "user", "content": question},
61
+ {"role": "assistant", "content": item["Answer"]},
62
+ ]
63
+ item = {
64
+ "dataset": "math_sat",
65
+ "id": item["id"],
66
+ "language": "en",
67
+ "messages": messages,
68
+ "answer": item["Answer"],
69
+ }
70
+ yield item
71
+
72
+
73
+ def process_ocwcourses(item):
74
+ messages = [
75
+ {"role": "user", "content": item["problem"].strip()},
76
+ {"role": "assistant", "content": item["solution"].strip()},
77
+ ]
78
+ item = {
79
+ "dataset": "OCWCourses",
80
+ "id": item["id"],
81
+ "language": "en",
82
+ "messages": messages,
83
+ "answer": item["answer"],
84
+ }
85
+ yield item
86
+
87
+
88
+ def process_mmlu_stem(item):
89
+ options = item["options"]
90
+ for i, (label, option) in enumerate(zip("ABCD", options)):
91
+ options[i] = f"({label}) {str(option).strip()}"
92
+ options = ", ".join(options)
93
+ question = f"{item['question'].strip()}\nWhat of the following is the right choice? Explain your answer.\n{options}"
94
+ messages = [
95
+ {"role": "user", "content": question},
96
+ {"role": "assistant", "content": item["answer"]},
97
+ ]
98
+ item = {
99
+ "dataset": "MMLU-STEM",
100
+ "id": item["id"],
101
+ "language": "en",
102
+ "messages": messages,
103
+ "answer": item["answer"],
104
+ }
105
+ yield item
106
+
107
+
108
+ def process_mgsm_zh(item):
109
+ item["answer"] = item["answer"].replace(",", "")
110
+ yield item
111
+
112
+
113
+ def process_cmath(item):
114
+ item = {
115
+ "dataset": "cmath",
116
+ "id": item["id"],
117
+ "grade": item["grade"],
118
+ "reasoning_step": item["reasoning_step"],
119
+ "messages": [
120
+ {"role": "user", "content": item["question"].strip()},
121
+ {"role": "assistant", "content": ""},
122
+ ],
123
+ "answer": item["golden"].strip().replace(",", ""),
124
+ }
125
+ yield item
126
+
127
+
128
+ def process_agieval_gaokao_math_cloze(item):
129
+ item = {
130
+ "dataset": "agieval-gaokao-math-cloze",
131
+ "id": item["id"],
132
+ "messages": [
133
+ {"role": "user", "content": item["question"].strip()},
134
+ {"role": "assistant", "content": ""},
135
+ ],
136
+ "answer": [strip_string(ans) for ans in item["answer"].strip().split(";")],
137
+ }
138
+ yield item
139
+
140
+
141
+ def process_agieval_gaokao_mathqa(item):
142
+ question = item["question"].strip()
143
+ options = []
144
+ for option in item["options"]:
145
+ option = option.strip()
146
+ assert option[0] == "("
147
+ assert option[2] == ")"
148
+ assert option[1] in "ABCD"
149
+ option = f"{option[1]}: {option[3:].strip()}"
150
+ options.append(option.strip())
151
+ question = f"{question}\n{options}"
152
+ item = {
153
+ "dataset": "agieval-gaokao-mathqa",
154
+ "id": item["id"],
155
+ "messages": [
156
+ {"role": "user", "content": question},
157
+ {"role": "assistant", "content": ""},
158
+ ],
159
+ "answer": item["label"],
160
+ }
161
+ yield item
162
+
163
+
164
+ def process_agieval_gaokao_mathqa_few_shot_cot_test(item):
165
+ question = item["question"].strip().rstrip("\\")
166
+ options = " ".join([opt.strip() for opt in item["options"]])
167
+ question = f"{question}\n从以下选项中选择: {options}"
168
+ item = {
169
+ "dataset": "agieval-gaokao-mathqa",
170
+ "id": item["id"],
171
+ "messages": [
172
+ {"role": "user", "content": question},
173
+ {"role": "assistant", "content": ""},
174
+ ],
175
+ "answer": item["label"],
176
+ }
177
+ yield item
178
+
179
+
180
+ def process_minif2f_isabelle(item):
181
+ question = f"(*### Problem\n\n{item['informal_statement'].strip()}\n\n### Solution\n\n{item['informal_proof'].strip()} *)\n\nFormal:\n{item['formal_statement'].strip()}"
182
+ item = {
183
+ "dataset": "minif2f-isabelle",
184
+ "id": item["id"],
185
+ "messages": [
186
+ {"role": "user", "content": question},
187
+ {"role": "assistant", "content": ""},
188
+ ],
189
+ "answer": "placeholder",
190
+ }
191
+ yield item
eval/eval_script.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import regex
2
+ from copy import deepcopy
3
+ from eval.eval_utils import math_equal
4
+ from eval.ocwcourses_eval_utils import (
5
+ normalize_numeric,
6
+ numeric_equality,
7
+ normalize_symbolic_equation,
8
+ SymbolicMathMixin,
9
+ )
10
+
11
+
12
+ def is_correct(item, pred_key="prediction", prec=1e-3):
13
+ pred = item[pred_key]
14
+ ans = item["answer"]
15
+ if isinstance(pred, list) and isinstance(ans, list):
16
+ pred_matched = set()
17
+ ans_matched = set()
18
+ for i in range(len(pred)):
19
+ for j in range(len(ans)):
20
+ item_cpy = deepcopy(item)
21
+ item_cpy.update({pred_key: pred[i], "answer": ans[j]})
22
+ if is_correct(item_cpy, pred_key=pred_key, prec=prec):
23
+ pred_matched.add(i)
24
+ ans_matched.add(j)
25
+ if item_cpy[pred_key] == "2,3,4":
26
+ print(item, flush=True)
27
+ print("wtf", flush=True)
28
+ return len(pred_matched) == len(pred) and len(ans_matched) == len(ans)
29
+ elif isinstance(pred, str) and isinstance(ans, str):
30
+ if "\\cup" in pred and "\\cup" in ans:
31
+ item = deepcopy(item)
32
+ item.update(
33
+ {
34
+ pred_key: pred.split("\\cup"),
35
+ "answer": ans.split("\\cup"),
36
+ }
37
+ )
38
+ return is_correct(item, pred_key=pred_key, prec=prec)
39
+ else:
40
+ label = False
41
+ try:
42
+ label = (
43
+ abs(
44
+ float(regex.sub(r",", "", str(pred)))
45
+ - float(regex.sub(r",", "", str(ans)))
46
+ )
47
+ < prec
48
+ )
49
+ except:
50
+ pass
51
+ label = label or (ans and pred == ans) or math_equal(pred, ans)
52
+ return label
53
+ else:
54
+ print(item, flush=True)
55
+ raise NotImplementedError()
56
+
57
+
58
+ def eval_math(item, pred_key="prediction", prec=1e-3):
59
+ pred = item[pred_key]
60
+ if pred_key == "program_output" and isinstance(pred, str):
61
+ pred = [pred]
62
+ ans = item["answer"]
63
+ if isinstance(pred, list) and isinstance(ans, list):
64
+ # for some questions in MATH, `reference` repeats answers
65
+ _ans = []
66
+ for a in ans:
67
+ if a not in _ans:
68
+ _ans.append(a)
69
+ ans = _ans
70
+ # some predictions for MATH questions also repeats answers
71
+ _pred = []
72
+ for a in pred:
73
+ if a not in _pred:
74
+ _pred.append(a)
75
+ # some predictions mistakenly box non-answer strings
76
+ pred = _pred[-len(ans) :]
77
+
78
+ item.update({pred_key: pred, "answer": ans})
79
+ return is_correct(item, pred_key=pred_key, prec=prec)
80
+
81
+
82
+ def eval_last_single_answer(item, pred_key="prediction", prec=1e-3):
83
+ for key in [pred_key, "answer"]:
84
+ assert isinstance(item[key], str), f"{key} = `{item[key]}` is not a str"
85
+ return is_correct(item, pred_key=pred_key, prec=prec)
86
+
87
+
88
+ def eval_agieval_gaokao_math_cloze(item, pred_key="prediction", prec=1e-3):
89
+ if pred_key == "program_output" and isinstance(item[pred_key], str):
90
+ item[pred_key] = [item[pred_key]]
91
+ for key in [pred_key, "answer"]:
92
+ assert isinstance(item[key], list), f"{key} = `{item[key]}` is not a list"
93
+ pred = item[pred_key]
94
+ ans = item["answer"]
95
+ _pred = []
96
+ for p in pred:
97
+ p = p + ";"
98
+ while p:
99
+ left_brackets = 0
100
+ for i in range(len(p)):
101
+ if p[i] == ";" or (p[i] == "," and left_brackets == 0):
102
+ _p, p = p[:i].strip(), p[i + 1 :].strip()
103
+ if _p not in _pred:
104
+ _pred.append(_p)
105
+ break
106
+ elif p[i] in "([{":
107
+ left_brackets += 1
108
+ elif p[i] in ")]}":
109
+ left_brackets -= 1
110
+ pred = _pred[-len(ans) :]
111
+ if len(pred) == len(ans):
112
+ for p, a in zip(pred, ans):
113
+ item.update(
114
+ {
115
+ pred_key: p,
116
+ "answer": a,
117
+ }
118
+ )
119
+ if not is_correct(item, pred_key=pred_key, prec=prec):
120
+ return False
121
+ return True
122
+ else:
123
+ return False
124
+
125
+
126
+ def eval_agieval_gaokao_mathqa(item, pred_key="prediction", prec=1e-3):
127
+ if pred_key == "program_output" and isinstance(item[pred_key], str):
128
+ item[pred_key] = [item[pred_key]]
129
+ pred_str = " ".join(item[pred_key])
130
+ ans = item["answer"]
131
+ tag = None
132
+ idx = -1
133
+ for t in "ABCD":
134
+ if t in pred_str and pred_str.index(t) > idx:
135
+ tag = t
136
+ idx = pred_str.index(t)
137
+ return tag == ans
138
+
139
+
140
+ def eval_math_sat(item, pred_key="prediction", prec=1e-3):
141
+ for key in [pred_key, "answer"]:
142
+ assert isinstance(item[key], str), f"{key} = `{item[key]}` is not a str"
143
+ return item[pred_key].lower() == item["answer"].lower()
144
+
145
+
146
+ def eval_mmlu_stem(item, pred_key="prediction", prec=1e-3):
147
+ return eval_math_sat(item, pred_key=pred_key, prec=prec)
148
+
149
+
150
+ def eval_ocwcourses(item, pred_key="prediction", prec=1e-3):
151
+ INVALID_ANSWER = "[invalidanswer]"
152
+ for key in [pred_key, "answer"]:
153
+ assert isinstance(item[key], str), f"{key} = `{item[key]}` is not a str"
154
+ pred = item[pred_key]
155
+ ans = item["answer"]
156
+
157
+ try:
158
+ float(ans)
159
+ normalize_fn = normalize_numeric
160
+ is_equiv = numeric_equality
161
+ answer_type = "numeric"
162
+ except ValueError:
163
+ if "=" in ans:
164
+ normalize_fn = normalize_symbolic_equation
165
+ is_equiv = lambda x, y: x == y
166
+ answer_type = "equation"
167
+ else:
168
+ normalize_fn = SymbolicMathMixin().normalize_tex
169
+ is_equiv = SymbolicMathMixin().is_tex_equiv
170
+ answer_type = "expression"
171
+
172
+ correct_answer = normalize_fn(ans)
173
+
174
+ unnormalized_answer = pred if pred else INVALID_ANSWER
175
+ model_answer = normalize_fn(unnormalized_answer)
176
+
177
+ if unnormalized_answer == INVALID_ANSWER:
178
+ acc = 0
179
+ elif model_answer == INVALID_ANSWER:
180
+ acc = 0
181
+ elif is_equiv(model_answer, correct_answer):
182
+ acc = 1
183
+ else:
184
+ acc = 0
185
+
186
+ return acc
187
+
188
+
189
+ def eval_minif2f_isabelle(item, pred_key="prediction", prec=1e-3):
190
+ return True
eval/eval_utils.py ADDED
@@ -0,0 +1,400 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import multiprocessing
2
+ from math import isclose
3
+ import numpy as np
4
+ from typing import Union, Any, Dict
5
+
6
+ from sympy import simplify, N
7
+ from sympy.parsing.sympy_parser import parse_expr
8
+ from sympy.parsing.latex import parse_latex
9
+ import re
10
+ import regex
11
+
12
+ from data_processing.answer_extraction import (
13
+ extract_answer,
14
+ extract_program_output,
15
+ strip_string,
16
+ )
17
+
18
+
19
+ def extract_program(result: str, last_only=True):
20
+ """
21
+ extract the program after "```python", and before "```"
22
+ """
23
+ program = ""
24
+ start = False
25
+ for line in result.split("\n"):
26
+ if line.startswith("```python"):
27
+ if last_only:
28
+ program = "" # only extract the last program
29
+ else:
30
+ program += "\n# ========\n"
31
+ start = True
32
+ elif line.startswith("```"):
33
+ start = False
34
+ elif start:
35
+ program += line + "\n"
36
+ return program
37
+
38
+
39
+ def parse_ground_truth(example: Dict[str, Any], data_name):
40
+ if "gt_cot" in example:
41
+ return example["gt_cot"], strip_string(example["gt"])
42
+
43
+ # parse ground truth
44
+ if data_name in ["math", "ocw"]:
45
+ gt_cot = example["solution"]
46
+ gt_ans = extract_answer(gt_cot)
47
+ elif data_name == "gsm8k":
48
+ gt_cot, gt_ans = example["answer"].split("####")
49
+ elif data_name == "gsm-hard":
50
+ gt_cot, gt_ans = example["code"], example["target"]
51
+ elif data_name == "svamp":
52
+ gt_cot, gt_ans = example["Equation"], example["Answer"]
53
+ elif data_name == "asdiv":
54
+ gt_cot = example["formula"]
55
+ gt_ans = re.sub(r"\(.*?\)", "", example["answer"])
56
+ elif data_name == "mawps":
57
+ gt_cot, gt_ans = None, example["target"]
58
+ elif data_name == "tabmwp":
59
+ gt_cot = example["solution"]
60
+ gt_ans = example["answer"]
61
+ if example["ans_type"] in ["integer_number", "decimal_number"]:
62
+ if "/" in gt_ans:
63
+ gt_ans = int(gt_ans.split("/")[0]) / int(gt_ans.split("/")[1])
64
+ elif "," in gt_ans:
65
+ gt_ans = float(gt_ans.replace(",", ""))
66
+ elif "%" in gt_ans:
67
+ gt_ans = float(gt_ans.split("%")[0]) / 100
68
+ else:
69
+ gt_ans = float(gt_ans)
70
+ elif data_name == "bbh":
71
+ gt_cot, gt_ans = None, example["target"]
72
+ else:
73
+ raise NotImplementedError(data_name)
74
+ # post process
75
+ gt_cot = str(gt_cot).strip()
76
+ gt_ans = strip_string(gt_ans)
77
+ return gt_cot, gt_ans
78
+
79
+
80
+ def parse_question(example, data_name):
81
+ question = ""
82
+ if data_name == "asdiv":
83
+ question = f"{example['body'].strip()} {example['question'].strip()}"
84
+ elif data_name == "svamp":
85
+ body = example["Body"].strip()
86
+ if not body.endswith("."):
87
+ body = body + "."
88
+ question = f'{body} {example["Question"].strip()}'
89
+ elif data_name == "tabmwp":
90
+ title_str = (
91
+ f'regarding "{example["table_title"]}" ' if example["table_title"] else ""
92
+ )
93
+ question = f"Read the following table {title_str}and answer a question:\n"
94
+ question += f'{example["table"]}\n{example["question"]}'
95
+ if example["choices"]:
96
+ question += (
97
+ f' Please select from the following options: {example["choices"]}'
98
+ )
99
+ else:
100
+ for key in ["question", "problem", "Question", "input"]:
101
+ if key in example:
102
+ question = example[key]
103
+ break
104
+ assert question != ""
105
+ return question.strip()
106
+
107
+
108
+ def run_execute(executor, result, prompt_type, execute=False):
109
+ if not result or result == "error":
110
+ return None, None
111
+ report = None
112
+
113
+ if "program_only" in prompt_type:
114
+ prediction = extract_program_output(result)
115
+ elif prompt_type in ["pot", "pal"] and execute:
116
+ code = extract_program(result)
117
+ prediction, report = executor.apply(code)
118
+ else:
119
+ prediction = extract_answer(result)
120
+
121
+ prediction = strip_string(prediction)
122
+ return prediction, report
123
+
124
+
125
+ def parse_digits(num):
126
+ # format: 234.23 || 23%
127
+ num = regex.sub(",", "", str(num))
128
+ try:
129
+ return float(num)
130
+ except:
131
+ if num.endswith("%"):
132
+ num = num[:-1]
133
+ if num.endswith("\\"):
134
+ num = num[:-1]
135
+ try:
136
+ return float(num) / 100
137
+ except:
138
+ pass
139
+ return None
140
+
141
+
142
+ def is_digit(num):
143
+ # paired with parse_digits
144
+ return parse_digits(num) is not None
145
+
146
+
147
+ def normalize_prediction(prediction):
148
+ try: # 1. numerical equal
149
+ if is_digit(prediction):
150
+ prediction = np.round(float(str(prediction).replace(",", "")), 6)
151
+ return str(prediction)
152
+ except:
153
+ pass
154
+
155
+ # 2. symbolic equal
156
+ prediction = str(prediction).strip()
157
+
158
+ ## deal with [], (), {}
159
+ brackets = []
160
+ while (
161
+ prediction.startswith("[")
162
+ and prediction.endswith("]")
163
+ or (prediction.startswith("(") and prediction.endswith(")"))
164
+ ):
165
+ bracket = prediction[0]
166
+ prediction = prediction[1:-1]
167
+ if brackets and "," in prediction:
168
+ pred_parts = [normalize_prediction(part) for part in prediction.split(",")]
169
+ prediction = ",".join(pred_parts)
170
+
171
+ if brackets:
172
+ for b in reversed(brackets):
173
+ if b == "[":
174
+ prediction = "[" + prediction + "]"
175
+ else:
176
+ assert b == "("
177
+ prediction = "(" + prediction + ")"
178
+
179
+ def _parse(s):
180
+ for f in [parse_latex, parse_expr]:
181
+ try:
182
+ return f(s)
183
+ except:
184
+ pass
185
+ return s
186
+
187
+ prediction = _parse(prediction)
188
+
189
+ for s in ["{", "}", "(", ")"]:
190
+ prediction = prediction.replace(s, "")
191
+
192
+ return prediction
193
+
194
+
195
+ def math_equal(
196
+ prediction: Union[bool, float, str],
197
+ reference: Union[float, str],
198
+ include_percentage: bool = True,
199
+ is_close: bool = True,
200
+ timeout: bool = False,
201
+ ) -> bool:
202
+ """
203
+ Exact match of math if and only if:
204
+ 1. numerical equal: both can convert to float and are equal
205
+ 2. symbolic equal: both can convert to sympy expression and are equal
206
+ """
207
+ if str(prediction) == str(reference):
208
+ return True
209
+
210
+ try: # 1. numerical equal
211
+ if is_digit(prediction) and is_digit(reference):
212
+ prediction = parse_digits(prediction)
213
+ reference = parse_digits(reference)
214
+ # number questions
215
+ if include_percentage:
216
+ gt_result = [reference / 100, reference, reference * 100]
217
+ else:
218
+ gt_result = [reference]
219
+ for item in gt_result:
220
+ try:
221
+ if is_close:
222
+ if isclose(item, prediction, abs_tol=1e-3):
223
+ return True
224
+ else:
225
+ if item == prediction:
226
+ return True
227
+ except Exception:
228
+ continue
229
+ return False
230
+ except:
231
+ pass
232
+
233
+ if not prediction and prediction not in [0, False]:
234
+ return False
235
+
236
+ # 2. symbolic equal
237
+ reference = str(reference).strip()
238
+ prediction = str(prediction).strip()
239
+
240
+ if (
241
+ regex.match(r"(\(|\[).+(\)|\])", prediction) is not None
242
+ and regex.match(r"(\(|\[).+(\)|\])", reference) is not None
243
+ ):
244
+ pred_parts = prediction[1:-1].split(",")
245
+ ref_parts = reference[1:-1].split(",")
246
+ if len(pred_parts) == len(ref_parts):
247
+ if all(
248
+ [
249
+ math_equal(
250
+ pred_parts[i], ref_parts[i], include_percentage, is_close
251
+ )
252
+ for i in range(len(pred_parts))
253
+ ]
254
+ ):
255
+ return True
256
+
257
+ if (
258
+ (
259
+ prediction.startswith("\\begin{pmatrix}")
260
+ or prediction.startswith("\\begin{bmatrix}")
261
+ )
262
+ and (
263
+ prediction.endswith("\\end{pmatrix}")
264
+ or prediction.endswith("\\end{bmatrix}")
265
+ )
266
+ and (
267
+ reference.startswith("\\begin{pmatrix}")
268
+ or reference.startswith("\\begin{bmatrix}")
269
+ )
270
+ and (
271
+ reference.endswith("\\end{pmatrix}") or reference.endswith("\\end{bmatrix}")
272
+ )
273
+ ):
274
+ pred_lines = [
275
+ line.strip()
276
+ for line in prediction[
277
+ len("\\begin{pmatrix}") : -len("\\end{pmatrix}")
278
+ ].split("\\\\")
279
+ if line.strip()
280
+ ]
281
+ ref_lines = [
282
+ line.strip()
283
+ for line in reference[
284
+ len("\\begin{pmatrix}") : -len("\\end{pmatrix}")
285
+ ].split("\\\\")
286
+ if line.strip()
287
+ ]
288
+ matched = True
289
+ if len(pred_lines) == len(ref_lines):
290
+ for pred_line, ref_line in zip(pred_lines, ref_lines):
291
+ pred_parts = pred_line.split("&")
292
+ ref_parts = ref_line.split("&")
293
+ if len(pred_parts) == len(ref_parts):
294
+ if not all(
295
+ [
296
+ math_equal(
297
+ pred_parts[i],
298
+ ref_parts[i],
299
+ include_percentage,
300
+ is_close,
301
+ )
302
+ for i in range(len(pred_parts))
303
+ ]
304
+ ):
305
+ matched = False
306
+ break
307
+ else:
308
+ matched = False
309
+ if not matched:
310
+ break
311
+ else:
312
+ matched = False
313
+ if matched:
314
+ return True
315
+
316
+ if prediction.count("=") == 1 and reference.count("=") == 1:
317
+ pred = prediction.split("=")
318
+ pred = f"{pred[0].strip()} - ({pred[1].strip()})"
319
+ ref = reference.split("=")
320
+ ref = f"{ref[0].strip()} - ({ref[1].strip()})"
321
+ if symbolic_equal(pred, ref) or symbolic_equal(f"-({pred})", ref):
322
+ return True
323
+ elif (
324
+ prediction.count("=") == 1
325
+ and len(prediction.split("=")[0].strip()) <= 2
326
+ and "=" not in reference
327
+ ):
328
+ if math_equal(
329
+ prediction.split("=")[1], reference, include_percentage, is_close
330
+ ):
331
+ return True
332
+ elif (
333
+ reference.count("=") == 1
334
+ and len(reference.split("=")[0].strip()) <= 2
335
+ and "=" not in prediction
336
+ ):
337
+ if math_equal(
338
+ prediction, reference.split("=")[1], include_percentage, is_close
339
+ ):
340
+ return True
341
+
342
+ # symbolic equal with sympy
343
+ if timeout:
344
+ if call_with_timeout(symbolic_equal_process, prediction, reference):
345
+ return True
346
+ else:
347
+ if symbolic_equal(prediction, reference):
348
+ return True
349
+
350
+ return False
351
+
352
+
353
+ def math_equal_process(param):
354
+ return math_equal(param[-2], param[-1])
355
+
356
+
357
+ def symbolic_equal(a, b):
358
+ def _parse(s):
359
+ for f in [parse_latex, parse_expr]:
360
+ try:
361
+ return f(s)
362
+ except:
363
+ pass
364
+ return s
365
+
366
+ a = _parse(a)
367
+ b = _parse(b)
368
+
369
+ try:
370
+ if simplify(a - b) == 0:
371
+ return True
372
+ except:
373
+ pass
374
+
375
+ try:
376
+ if isclose(N(a), N(b), abs_tol=1e-3):
377
+ return True
378
+ except:
379
+ pass
380
+ return False
381
+
382
+
383
+ def symbolic_equal_process(a, b, output_queue):
384
+ result = symbolic_equal(a, b)
385
+ output_queue.put(result)
386
+
387
+
388
+ def call_with_timeout(func, *args, timeout=1, **kwargs):
389
+ output_queue = multiprocessing.Queue()
390
+ process_args = args + (output_queue,)
391
+ process = multiprocessing.Process(target=func, args=process_args, kwargs=kwargs)
392
+ process.start()
393
+ process.join(timeout)
394
+
395
+ if process.is_alive():
396
+ process.terminate()
397
+ process.join()
398
+ return False
399
+
400
+ return output_queue.get()
eval/ocwcourses_eval_utils.py ADDED
@@ -0,0 +1,266 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import numpy as np
3
+ import sympy
4
+ from sympy.core.sympify import SympifyError
5
+ from sympy.parsing.latex import parse_latex
6
+
7
+ import signal
8
+
9
+ INVALID_ANSWER = "[invalidanswer]"
10
+
11
+
12
+ class timeout:
13
+ def __init__(self, seconds=1, error_message="Timeout"):
14
+ self.seconds = seconds
15
+ self.error_message = error_message
16
+
17
+ def handle_timeout(self, signum, frame):
18
+ raise TimeoutError(self.error_message)
19
+
20
+ def __enter__(self):
21
+ signal.signal(signal.SIGALRM, self.handle_timeout)
22
+ signal.alarm(self.seconds)
23
+
24
+ def __exit__(self, type, value, traceback):
25
+ signal.alarm(0)
26
+
27
+
28
+ def normalize_numeric(s):
29
+ if s is None:
30
+ return None
31
+ for unit in [
32
+ "eV",
33
+ " \\mathrm{~kg} \\cdot \\mathrm{m} / \\mathrm{s}",
34
+ " kg m/s",
35
+ "kg*m/s",
36
+ "kg",
37
+ "m/s",
38
+ "m / s",
39
+ "m s^{-1}",
40
+ "\\text{ m/s}",
41
+ " \\mathrm{m/s}",
42
+ " \\text{ m/s}",
43
+ "g/mole",
44
+ "g/mol",
45
+ "\\mathrm{~g}",
46
+ "\\mathrm{~g} / \\mathrm{mol}",
47
+ "W",
48
+ "erg/s",
49
+ "years",
50
+ "year",
51
+ "cm",
52
+ ]:
53
+ s = s.replace(unit, "")
54
+ s = s.strip()
55
+ for maybe_unit in ["m", "s", "cm"]:
56
+ s = s.replace("\\mathrm{" + maybe_unit + "}", "")
57
+ s = s.replace("\\mathrm{~" + maybe_unit + "}", "")
58
+ s = s.strip()
59
+ s = s.strip("$")
60
+ try:
61
+ return float(eval(s))
62
+ except:
63
+ try:
64
+ expr = parse_latex(s)
65
+ if expr.is_number:
66
+ return float(expr)
67
+ return INVALID_ANSWER
68
+ except:
69
+ return INVALID_ANSWER
70
+
71
+
72
+ def numeric_equality(n1, n2, threshold=0.01):
73
+ if n1 is None or n2 is None:
74
+ return False
75
+ if np.isclose(n1, 0) or np.isclose(n2, 0) or np.isclose(n1 - n2, 0):
76
+ return np.abs(n1 - n2) < threshold * (n1 + n2) / 2
77
+ else:
78
+ return np.isclose(n1, n2)
79
+
80
+
81
+ def normalize_symbolic_equation(s):
82
+ if not isinstance(s, str):
83
+ return INVALID_ANSWER
84
+ if s.startswith("\\["):
85
+ s = s[2:]
86
+ if s.endswith("\\]"):
87
+ s = s[:-2]
88
+ s = s.replace("\\left(", "(")
89
+ s = s.replace("\\right)", ")")
90
+ s = s.replace("\\\\", "\\")
91
+ if s.startswith("$") or s.endswith("$"):
92
+ s = s.strip("$")
93
+ try:
94
+ maybe_expression = parse_latex(s)
95
+ if not isinstance(maybe_expression, sympy.core.relational.Equality):
96
+ # we have equation, not expression
97
+ return INVALID_ANSWER
98
+ else:
99
+ return maybe_expression
100
+ except:
101
+ return INVALID_ANSWER
102
+
103
+
104
+ class SymbolicMathMixin:
105
+ """
106
+ Methods useful for parsing mathematical expressions from text and determining equivalence of expressions.
107
+ """
108
+
109
+ SUBSTITUTIONS = [ # used for text normalize
110
+ ("an ", ""),
111
+ ("a ", ""),
112
+ (".$", "$"),
113
+ ("\\$", ""),
114
+ (r"\ ", ""),
115
+ (" ", ""),
116
+ ("mbox", "text"),
117
+ (",\\text{and}", ","),
118
+ ("\\text{and}", ","),
119
+ ("\\text{m}", "\\text{}"),
120
+ ]
121
+ REMOVED_EXPRESSIONS = [ # used for text normalizer
122
+ "square",
123
+ "ways",
124
+ "integers",
125
+ "dollars",
126
+ "mph",
127
+ "inches",
128
+ "ft",
129
+ "hours",
130
+ "km",
131
+ "units",
132
+ "\\ldots",
133
+ "sue",
134
+ "points",
135
+ "feet",
136
+ "minutes",
137
+ "digits",
138
+ "cents",
139
+ "degrees",
140
+ "cm",
141
+ "gm",
142
+ "pounds",
143
+ "meters",
144
+ "meals",
145
+ "edges",
146
+ "students",
147
+ "childrentickets",
148
+ "multiples",
149
+ "\\text{s}",
150
+ "\\text{.}",
151
+ "\\text{\ns}",
152
+ "\\text{}^2",
153
+ "\\text{}^3",
154
+ "\\text{\n}",
155
+ "\\text{}",
156
+ r"\mathrm{th}",
157
+ r"^\circ",
158
+ r"^{\circ}",
159
+ r"\;",
160
+ r",\!",
161
+ "{,}",
162
+ '"',
163
+ "\\dots",
164
+ ]
165
+
166
+ def normalize_tex(self, final_answer: str) -> str:
167
+ """
168
+ Normalizes a string representing a mathematical expression.
169
+ Used as a preprocessing step before parsing methods.
170
+
171
+ Copied character for character from appendix D of Lewkowycz et al. (2022)
172
+ """
173
+ final_answer = final_answer.split("=")[-1]
174
+
175
+ for before, after in self.SUBSTITUTIONS:
176
+ final_answer = final_answer.replace(before, after)
177
+ for expr in self.REMOVED_EXPRESSIONS:
178
+ final_answer = final_answer.replace(expr, "")
179
+
180
+ # Extract answer that is in LaTeX math, is bold,
181
+ # is surrounded by a box, etc.
182
+ final_answer = re.sub(r"(.*?)(\$)(.*?)(\$)(.*)", "$\\3$", final_answer)
183
+ final_answer = re.sub(r"(\\text\{)(.*?)(\})", "\\2", final_answer)
184
+ final_answer = re.sub(r"(\\textbf\{)(.*?)(\})", "\\2", final_answer)
185
+ final_answer = re.sub(r"(\\overline\{)(.*?)(\})", "\\2", final_answer)
186
+ final_answer = re.sub(r"(\\boxed\{)(.*)(\})", "\\2", final_answer)
187
+
188
+ # Normalize shorthand TeX:
189
+ # \fracab -> \frac{a}{b}
190
+ # \frac{abc}{bef} -> \frac{abc}{bef}
191
+ # \fracabc -> \frac{a}{b}c
192
+ # \sqrta -> \sqrt{a}
193
+ # \sqrtab -> sqrt{a}b
194
+ final_answer = re.sub(r"(frac)([^{])(.)", "frac{\\2}{\\3}", final_answer)
195
+ final_answer = re.sub(r"(sqrt)([^{])", "sqrt{\\2}", final_answer)
196
+ final_answer = final_answer.replace("$", "")
197
+
198
+ # Normalize 100,000 -> 100000
199
+ if final_answer.replace(",", "").isdigit():
200
+ final_answer = final_answer.replace(",", "")
201
+
202
+ return final_answer
203
+
204
+ def parse_tex(self, text: str, time_limit: int = 5) -> sympy.Basic:
205
+ """
206
+ Wrapper around `sympy.parse_text` that outputs a SymPy expression.
207
+ Typically, you want to apply `normalize_text` as a preprocessing step.
208
+ """
209
+ try:
210
+ with timeout(seconds=time_limit):
211
+ parsed = parse_latex(text)
212
+ except (
213
+ # general error handling: there is a long tail of possible sympy/other
214
+ # errors we would like to catch
215
+ Exception
216
+ ) as e:
217
+ print(f"failed to parse {text} with exception {e}")
218
+ return None
219
+
220
+ return parsed
221
+
222
+ def is_exp_equiv(self, x1: sympy.Basic, x2: sympy.Basic, time_limit=5) -> bool:
223
+ """
224
+ Determines whether two sympy expressions are equal.
225
+ """
226
+ try:
227
+ with timeout(seconds=time_limit):
228
+ try:
229
+ diff = x1 - x2
230
+ except (SympifyError, ValueError, TypeError) as e:
231
+ print(f"Couldn't subtract {x1} and {x2} with exception {e}")
232
+ return False
233
+
234
+ try:
235
+ if sympy.simplify(diff) == 0:
236
+ return True
237
+ else:
238
+ return False
239
+ except (SympifyError, ValueError, TypeError) as e:
240
+ print(f"Failed to simplify {x1}-{x2} with {e}")
241
+ return False
242
+ except TimeoutError as e:
243
+ print(f"Timed out comparing {x1} and {x2}")
244
+ return False
245
+ except Exception as e:
246
+ print(f"failed on unrecognized exception {e}")
247
+ return False
248
+
249
+ def is_tex_equiv(self, x1: str, x2: str, time_limit=5) -> bool:
250
+ """
251
+ Determines whether two (ideally normalized using `normalize_text`) TeX expressions are equal.
252
+
253
+ Does so by first checking for string exact-match, then falls back on sympy-equivalence,
254
+ following the (Lewkowycz et al. 2022) methodology.
255
+ """
256
+ if x1 == x2:
257
+ # don't resort to sympy if we have full string match, post-normalization
258
+ return True
259
+ else:
260
+ return False
261
+ parsed_x2 = self.parse_tex(x2)
262
+ if not parsed_x2:
263
+ # if our reference fails to parse into a Sympy object,
264
+ # we forgo parsing + checking our generated answer.
265
+ return False
266
+ return self.is_exp_equiv(self.parse_tex(x1), parsed_x2, time_limit=time_limit)
eval/python_executor.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import io
3
+ from contextlib import redirect_stdout
4
+ import pickle
5
+ import regex
6
+ import copy
7
+ from typing import Any, Dict, Optional
8
+ import multiprocess
9
+ from pebble import ProcessPool
10
+ from concurrent.futures import TimeoutError
11
+ from functools import partial
12
+ import traceback
13
+ from timeout_decorator import timeout
14
+
15
+
16
+ class GenericRuntime:
17
+ GLOBAL_DICT = {}
18
+ LOCAL_DICT = None
19
+ HEADERS = []
20
+
21
+ def __init__(self):
22
+ self._global_vars = copy.copy(self.GLOBAL_DICT)
23
+ self._local_vars = copy.copy(self.LOCAL_DICT) if self.LOCAL_DICT else None
24
+
25
+ for c in self.HEADERS:
26
+ self.exec_code(c)
27
+
28
+ def exec_code(self, code_piece: str) -> None:
29
+ if regex.search(r"(\s|^)?input\(", code_piece) or regex.search(
30
+ r"(\s|^)?os.system\(", code_piece
31
+ ):
32
+ raise RuntimeError()
33
+ exec(code_piece, self._global_vars)
34
+
35
+ def eval_code(self, expr: str) -> Any:
36
+ return eval(expr, self._global_vars)
37
+
38
+ def inject(self, var_dict: Dict[str, Any]) -> None:
39
+ for k, v in var_dict.items():
40
+ self._global_vars[k] = v
41
+
42
+ @property
43
+ def answer(self):
44
+ return self._global_vars["answer"]
45
+
46
+
47
+ class PythonExecutor:
48
+ def __init__(
49
+ self,
50
+ runtime: Optional[Any] = None,
51
+ get_answer_symbol: Optional[str] = None,
52
+ get_answer_expr: Optional[str] = None,
53
+ get_answer_from_stdout: bool = False,
54
+ ) -> None:
55
+ self.runtime = runtime if runtime else GenericRuntime()
56
+ self.answer_symbol = get_answer_symbol
57
+ self.answer_expr = get_answer_expr
58
+ self.get_answer_from_stdout = get_answer_from_stdout
59
+
60
+ def process_generation_to_code(self, gens: str):
61
+ batch_code = []
62
+ for g in gens:
63
+ multiline_comments = False
64
+ code = []
65
+ for line in g.split("\n"):
66
+ strip_line = line.strip()
67
+ if strip_line.startswith("#"):
68
+ line = line.split("#", 1)[0] + "# comments"
69
+ elif (
70
+ not multiline_comments
71
+ and strip_line.startswith('"""')
72
+ and strip_line.endswith('"""')
73
+ and len(strip_line) >= 6
74
+ ):
75
+ line = line.split('"""', 1)[0] + '"""comments"""'
76
+ elif not multiline_comments and strip_line.startswith('"""'):
77
+ multiline_comments = True
78
+ elif multiline_comments and strip_line.endswith('"""'):
79
+ multiline_comments = False
80
+ line = ""
81
+ if not multiline_comments:
82
+ code.append(line)
83
+ batch_code.append(code)
84
+ return batch_code
85
+
86
+ @staticmethod
87
+ def execute(
88
+ code,
89
+ get_answer_from_stdout=None,
90
+ runtime=None,
91
+ answer_symbol=None,
92
+ answer_expr=None,
93
+ timeout_length=10,
94
+ ):
95
+ try:
96
+ if get_answer_from_stdout:
97
+ program_io = io.StringIO()
98
+ with redirect_stdout(program_io):
99
+ timeout(timeout_length)(runtime.exec_code)("\n".join(code))
100
+ program_io.seek(0)
101
+ result = "".join(program_io.readlines()) # [-1]
102
+ elif answer_symbol:
103
+ timeout(timeout_length)(runtime.exec_code)("\n".join(code))
104
+ result = runtime._global_vars[answer_symbol]
105
+ elif answer_expr:
106
+ timeout(timeout_length)(runtime.exec_code)("\n".join(code))
107
+ result = timeout(timeout_length)(runtime.eval_code)(answer_expr)
108
+ else:
109
+ timeout(timeout_length)(runtime.exec_code)("\n".join(code[:-1]))
110
+ result = timeout(timeout_length)(runtime.eval_code)(code[-1])
111
+ concise_exec_info = ""
112
+ exec_info = ""
113
+ str(result)
114
+ pickle.dumps(result) # serialization check
115
+ except:
116
+ # traceback.print_exc()
117
+ result = ""
118
+ concise_exec_info = traceback.format_exc().split("\n")[-2]
119
+ exec_info = traceback.format_exc()
120
+ if (
121
+ get_answer_from_stdout
122
+ and "exec(code_piece, self._global_vars)" in exec_info
123
+ ):
124
+ exec_info = exec_info.split("exec(code_piece, self._global_vars)")[
125
+ -1
126
+ ].strip()
127
+ msg = []
128
+ for line in exec_info.split("\n"):
129
+ patt = regex.search(
130
+ r'(?P<start>.*)File "(?P<file>.*)", line (?P<lno>\d+), (?P<end>.*)',
131
+ line,
132
+ )
133
+ if patt is not None:
134
+ if "<module>" in patt.group("end"):
135
+ continue
136
+ fname = patt.group("file")
137
+ if "site-packages" in fname:
138
+ fname = f"site-packages{fname.split('site-packages', 1)[1]}"
139
+ line = f'{patt.group("start")}File "{fname}", {patt.group("end")}'
140
+ else:
141
+ line = f'{patt.group("start")}{patt.group("end")}'
142
+ else:
143
+ patt = regex.search(
144
+ r"(?P<start>.*)(?P<file>/.*site-packages/.*\.py)(?P<end>.*)",
145
+ line,
146
+ )
147
+ if patt is not None:
148
+ line = f'{patt.group("start")}site-packages{patt.group("file").split("site-packages", 1)[1]}{patt.group("end")}'
149
+ msg.append(line)
150
+ exec_info = "\n".join(msg)
151
+ return result, concise_exec_info, exec_info
152
+
153
+ def apply(self, code):
154
+ return self.batch_apply([code])[0]
155
+
156
+ def batch_apply(self, batch_code):
157
+ all_code_snippets = self.process_generation_to_code(batch_code)
158
+ all_exec_results = []
159
+ executor = partial(
160
+ self.execute,
161
+ get_answer_from_stdout=self.get_answer_from_stdout,
162
+ runtime=self.runtime,
163
+ answer_symbol=self.answer_symbol,
164
+ answer_expr=self.answer_expr,
165
+ timeout_length=10,
166
+ )
167
+ with ProcessPool(max_workers=multiprocess.cpu_count()) as pool:
168
+ iterator = pool.map(executor, all_code_snippets, timeout=10).result()
169
+
170
+ while True:
171
+ try:
172
+ result = next(iterator)
173
+ all_exec_results.append(result)
174
+ except StopIteration:
175
+ break
176
+ except TimeoutError as error:
177
+ all_exec_results.append(("", "Timeout Error", "Timeout Error"))
178
+ except Exception as error:
179
+ print(error)
180
+ exit()
181
+
182
+ batch_results = []
183
+ for code, (result, concise_exec_info, exec_info) in zip(
184
+ all_code_snippets, all_exec_results
185
+ ):
186
+ metadata = {
187
+ "code": code,
188
+ "exec_result": result,
189
+ "concise_exec_info": concise_exec_info,
190
+ "exec_info": exec_info,
191
+ }
192
+ batch_results.append((result, metadata))
193
+ return batch_results
main.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import argparse
3
+ from huggingface_hub import hf_hub_download
4
+ import json
5
+ from build_cache import cache
6
+ from compute_perp import Evaluator as PPLEvaluator
7
+ from compute_sc import SCEvaluator
8
+ from compute_rpc import RPCEvaluator
9
+
10
+ REPOID = {
11
+ "MATH": "WNJXYK/MATH-Reasoning-Paths",
12
+ "MathOdyssey": "WNJXYK/MathOdyssey-Reasoning-Paths",
13
+ "AIME": "WNJXYK/AIME_1983_2024-Reasoning-Paths",
14
+ "OlympiadBench": "WNJXYK/OlympiadBench-Reasoning-Paths"
15
+ }
16
+
17
+ EVALUATOR_MAP = {
18
+ "PPL": PPLEvaluator,
19
+ "SC": SCEvaluator,
20
+ "RPC": RPCEvaluator
21
+ }
22
+
23
+ args = argparse.ArgumentParser()
24
+ args.add_argument("--dataset", type=str, choices=["MATH", "MathOdyssey", "AIME", "OlympiadBench"], default="MathOdyssey")
25
+ args.add_argument("--model", type=str, choices=["Deepseek-Math-RL-7B", "InternLM2-Math-Plus-1.8B", "InternLM2-Math-Plus-7B"], default="InternLM2-Math-Plus-7B")
26
+ args.add_argument("--K", type=int, default=128)
27
+ args.add_argument("--method", type=str, default="PPL", choices=["PPL", "SC", "RPC"])
28
+ args = args.parse_args()
29
+
30
+ repo_id = REPOID[args.dataset]
31
+ filename = args.model + ".json"
32
+
33
+ # Download sampled reasoning paths from Hugging Face
34
+ try:
35
+ file_path = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset")
36
+ with open(file_path, 'r', encoding='utf-8') as f:
37
+ json_file = json.load(f)
38
+ print(f"Load sampled reasoning paths {filename} from {repo_id} successfully!")
39
+ except Exception as e:
40
+ print(f"Failed to load sampled reasoning paths {filename} from {repo_id}: {e}")
41
+
42
+ # Build cache for checking equality
43
+ cache_path = file_path.replace(".json", ".cache.json")
44
+ cache(json_file, cache_path)
45
+ with open(cache_path, 'r', encoding='utf-8') as f:
46
+ cache_file = json.load(f)
47
+
48
+ # Run!
49
+ results = EVALUATOR_MAP[args.method]().solve(json_file=json_file, cache_file=cache_file, K=args.K)
50
+
51
+ # Report results
52
+ result_str = f"{args.method} {args.dataset} {args.model} {args.K} {results}"
53
+ with open("results.txt", "a") as f:
54
+ f.write(result_str + "\n")
55
+ print(result_str)
metrics.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+ EPS = 1e-10
4
+ NLLEPS = 1e-6
5
+
6
+ def compute_maximum_metrics(predicts, n_bins=10):
7
+ n = len(predicts)
8
+ acc, cnf, siz = np.zeros(n_bins), np.zeros(n_bins), np.zeros(n_bins)
9
+ brier_score = []
10
+ negative_ll = []
11
+
12
+ for idx in range(n):
13
+ m = len(predicts[idx])
14
+
15
+ # Compute maximum probabilities and corresponding counts within each problem
16
+ max_prob, max_prob_counts = -1e6, 0
17
+ for i in range(m):
18
+ ans, prob, flag = predicts[idx][i]
19
+ if prob > max_prob:
20
+ max_prob, max_prob_counts = prob, 0
21
+ if prob >= max_prob - EPS:
22
+ max_prob_counts += 1
23
+ # print(max_prob, max_prob_counts)
24
+ # Compute the maximum accuracy for each problem as well as the ECE metric
25
+ vote_acc = 0
26
+ for i in range(m):
27
+ ans, prob, flag = predicts[idx][i]
28
+ if prob < max_prob:
29
+ continue
30
+ if np.isnan(prob):
31
+ continue
32
+ if flag:
33
+ vote_acc += 1.0 / max_prob_counts
34
+ # Compute Expected Calibration Error
35
+ for cur in range(n_bins):
36
+ lower, upper = cur / n_bins, (cur + 1) / n_bins
37
+ if lower < max_prob <= upper:
38
+ if flag:
39
+ acc[cur] += 1.0 / max_prob_counts
40
+ cnf[cur] += prob / max_prob_counts
41
+ siz[cur] += 1.0 / max_prob_counts
42
+
43
+ # Compute Brier Score
44
+ brier_score.append((vote_acc - max_prob) ** 2)
45
+
46
+ # Compute Negative Likelihhod
47
+ cliped_max_prob = max(min(max_prob, 1 - NLLEPS), NLLEPS)
48
+ cliped_vote_acc = max(min(vote_acc, 1 - NLLEPS), NLLEPS)
49
+ negative_ll.append(
50
+ -np.log(cliped_max_prob) * cliped_vote_acc
51
+ - np.log(1 - cliped_max_prob) * (1 - cliped_vote_acc)
52
+ )
53
+
54
+ # Turn each metrics into values
55
+ ece = 0
56
+ for cur in range(n_bins):
57
+ if siz[cur] > 0:
58
+ acc[cur] = acc[cur] / siz[cur]
59
+ cnf[cur] = cnf[cur] / siz[cur]
60
+ ece += siz[cur] * np.abs(acc[cur] - cnf[cur])
61
+ # print(siz[cur], acc[cur], cnf[cur])
62
+ ece = ece / sum(siz)
63
+ bs = np.mean(brier_score)
64
+ nll = np.mean(negative_ll)
65
+
66
+ return (ece, bs, nll), (acc, cnf, siz)
67
+
68
+
69
+ def compute_average_metrics(predicts, n_bins=10):
70
+ n = len(predicts)
71
+ acc, cnf, siz = np.zeros(n_bins), np.zeros(n_bins), np.zeros(n_bins)
72
+ brier_score = []
73
+ negative_ll = []
74
+
75
+ for idx in range(n):
76
+ m = len(predicts[idx])
77
+
78
+ problem_brier_score = []
79
+ problem_negative_ll = []
80
+ for i in range(m):
81
+ ans, prob, flag = predicts[idx][i]
82
+ # Compute Expected Calibration Error
83
+ for cur in range(n_bins):
84
+ lower, upper = cur / n_bins, (cur + 1) / n_bins
85
+ if lower < prob <= upper:
86
+ if flag:
87
+ acc[cur] += 1.0 / m
88
+ cnf[cur] += prob / m
89
+ siz[cur] += 1.0 / m
90
+
91
+ # Compute Brier Score
92
+ problem_brier_score.append(((1 if flag else 0) - prob) ** 2)
93
+
94
+ # Compute Negative Likelyhood
95
+ cliped_max_prob = max(min(prob, 1 - NLLEPS), NLLEPS)
96
+ cliped_vote_acc = max(min(1 if flag else 0, 1 - NLLEPS), NLLEPS)
97
+ problem_negative_ll.append(
98
+ -np.log(cliped_max_prob) * cliped_vote_acc
99
+ - np.log(1 - cliped_max_prob) * (1 - cliped_vote_acc)
100
+ )
101
+
102
+ brier_score.append(np.mean(problem_brier_score))
103
+ negative_ll.append(np.mean(problem_negative_ll))
104
+
105
+ ece = 0
106
+ for cur in range(n_bins):
107
+ if siz[cur] > 0:
108
+ acc[cur] = acc[cur] / siz[cur]
109
+ cnf[cur] = cnf[cur] / siz[cur]
110
+ ece += siz[cur] * np.abs(acc[cur] - cnf[cur])
111
+ ece = ece / sum(siz)
112
+ bs = np.mean(brier_score)
113
+ nll = np.mean(negative_ll)
114
+
115
+ return (ece, bs, nll), (acc, cnf, siz)
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ antlr4-python3-runtime==4.13.2
2
+ datasets==4.1.1
3
+ Fraction==2.2.0
4
+ huggingface-hub==0.35.3
5
+ multiprocess==0.70.16
6
+ numpy==2.0.2
7
+ regex==2025.9.18
8
+ scipy==1.13.1
9
+ sympy==1.14.0
10
+ tqdm==4.67.1