Upload 16 files
Browse files- __init__.py +0 -0
- all_exps.sh +10 -0
- app.py +174 -0
- build_cache.py +46 -0
- compute_perp.py +144 -0
- compute_rpc.py +136 -0
- compute_sc.py +108 -0
- data_processing/answer_extraction.py +362 -0
- data_processing/process_utils.py +191 -0
- eval/eval_script.py +190 -0
- eval/eval_utils.py +400 -0
- eval/ocwcourses_eval_utils.py +266 -0
- eval/python_executor.py +193 -0
- main.py +55 -0
- metrics.py +115 -0
- requirements.txt +10 -0
__init__.py
ADDED
File without changes
|
all_exps.sh
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Run every method (PPL, SC, RPC) on every benchmark for each model.
# MATH is evaluated with K=64 sampled paths; the other datasets use K=128.
for model in InternLM2-Math-Plus-7B Deepseek-Math-RL-7B InternLM2-Math-Plus-1.8B; do
    for method in PPL SC RPC; do
        python main.py --dataset MATH --model "$model" --method "$method" --K 64
    done
    for dataset in MathOdyssey AIME OlympiadBench; do
        for method in PPL SC RPC; do
            python main.py --dataset "$dataset" --model "$model" --method "$method" --K 128
        done
    done
done
|
app.py
ADDED
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import json, os
|
3 |
+
from huggingface_hub import hf_hub_download
|
4 |
+
from compute_perp import prep_evaluator, numberic_compare, check_equal
|
5 |
+
from compute_sc import sc_evaluator
|
6 |
+
from compute_rpc import wpc_evaluator
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
def greet(name):
    """Return the greeting ``Hello <name>!!`` for the string *name*."""
    pieces = ("Hello ", name, "!!")
    return "".join(pieces)
|
11 |
+
|
12 |
+
json_file = {"predict": [], "answer": [], "completion": [], "mean_logprob": [], "prompt": []}
|
13 |
+
|
14 |
+
demo = gr.Blocks()
|
15 |
+
with demo:
|
16 |
+
paper_title = gr.HTML("""<div align='center'><h1>[NeurIPS 2025] A Theoretical Study on Bridging Internal Probability and Self-Consistency for LLM Reasoning</h1></div>""")
|
17 |
+
paper_info = gr.HTML("""<div align="center"><h3><a href="https://arxiv.org/pdf/2502.00511">📄 [Paper]</a> • <a href="https://wnjxyk.github.io/RPC">🌐 [Project]</a> • <a href="#" onclick="document.getElementById('bibtex-popup').style.display='block';">📚 [BibTeX]</a><h3>
|
18 |
+
<div id="bibtex-popup" style="display:none; position:fixed; top:50%; left:50%; transform:translate(-50%, -50%); background:white; padding:20px; border:1px solid #ccc; box-shadow:0 0 10px rgba(0,0,0,0.2); z-index:1000; max-width:80%; overflow:auto;">
|
19 |
+
<pre style="white-space:pre-wrap; font-size:12px; text-align:left;">@inproceedings{zhou24theoretical,
|
20 |
+
author = {Zhou, Zhi and Tan, Yuhao and Li, Zenan and Yao, Yuan and Guo, Lan-Zhe and Li, Yu-Feng and Ma, Xiaoxing},
|
21 |
+
title = {A Theoretical Study on Bridging Internal Probability and Self-Consistency for LLM Reasoning},
|
22 |
+
booktitle = {Advances in Neural Information Processing Systems},
|
23 |
+
year = {2025},
|
24 |
+
}</pre>
|
25 |
+
<button onclick="document.getElementById('bibtex-popup').style.display='none';" style="margin-top:10px; padding:5px 10px;">Close</button>
|
26 |
+
</div></div>""")
|
27 |
+
|
28 |
+
with gr.Column():
|
29 |
+
gr.Markdown("## 1. Experimental Settings")
|
30 |
+
with gr.Row():
|
31 |
+
dataset = gr.Dropdown(
|
32 |
+
choices=["MATH", "MathOdyssey", "AIME", "OlympiadBench"],
|
33 |
+
value="MathOdyssey",
|
34 |
+
label="Dataset",
|
35 |
+
interactive=True
|
36 |
+
)
|
37 |
+
model = gr.Dropdown(
|
38 |
+
choices=["Deepseek-Math-RL-7B", "InternLM2-Math-Plus-1.8B", "InternLM2-Math-Plus-7B"],
|
39 |
+
value="InternLM2-Math-Plus-7B",
|
40 |
+
label="Model",
|
41 |
+
interactive=True
|
42 |
+
)
|
43 |
+
k_value = gr.Dropdown(
|
44 |
+
choices=[8, 16, 32, 64, 128],
|
45 |
+
value=128,
|
46 |
+
label="K (Number of Sampled Reasoning Paths)",
|
47 |
+
interactive=True
|
48 |
+
)
|
49 |
+
seed = gr.Number(
|
50 |
+
label="Random Seed",
|
51 |
+
value=998244353,
|
52 |
+
step=1,
|
53 |
+
interactive=True
|
54 |
+
)
|
55 |
+
def update_k_value(dataset_choice):
    """Return a Gradio update for the K dropdown that matches the chosen dataset.

    MATH offers K values up to 64; every other dataset additionally offers 128.

    NOTE(review): ``k_value.value`` looks like the value stored on the component
    object rather than the user's live selection -- confirm this is intended.
    """
    is_math = dataset_choice == "MATH"
    allowed = [8, 16, 32, 64] if is_math else [8, 16, 32, 64, 128]
    selected = min(64, k_value.value) if is_math else k_value.value
    return gr.update(choices=allowed, value=selected)
|
60 |
+
dataset.change(fn=update_k_value, inputs=dataset, outputs=k_value)
|
61 |
+
load_btn = gr.Button("Load All Problems")
|
62 |
+
|
63 |
+
with gr.Column(visible=False) as content_column:
|
64 |
+
gr.Markdown("## 2. Problem Selection")
|
65 |
+
with gr.Group():
|
66 |
+
data_info = gr.Textbox(label="Experiment Info", value="")
|
67 |
+
problem_id = gr.Dropdown(
|
68 |
+
choices=[1],
|
69 |
+
value=1,
|
70 |
+
label="Problem ID (We removed (1) problems that were unlikely to be answered correctly using any of the methods; (2) easy problems)",
|
71 |
+
interactive=True
|
72 |
+
)
|
73 |
+
with gr.Row():
|
74 |
+
problem_prompt = gr.Textbox(label="Problem Prompt", value="", scale=3)
|
75 |
+
problem_answer = gr.Textbox(label="Problem Answer", value="", scale=1)
|
76 |
+
def update_problem_info(problem_id):
    """Build updates for the prompt and answer textboxes of the selected problem.

    ``problem_id`` is 1-based, so entry ``problem_id - 1`` of the loaded
    ``json_file`` lists is displayed.
    """
    row = problem_id - 1
    prompt_update = gr.update(value=json_file['prompt'][row], label=f"Problem#{problem_id} Prompt")
    answer_update = gr.update(value=json_file['answer'][row], label=f"Problem#{problem_id} Answer")
    return prompt_update, answer_update
|
78 |
+
problem_id.change(fn=update_problem_info, inputs=problem_id, outputs=[problem_prompt, problem_answer])
|
79 |
+
run_btn = gr.Button("Run Evaluation")
|
80 |
+
|
81 |
+
with gr.Column(visible=False) as result_column:
|
82 |
+
gr.Markdown("## 3. Experiment Result")
|
83 |
+
with gr.Row():
|
84 |
+
with gr.Column():
|
85 |
+
gr.Markdown("### PPL (Internal Probability)")
|
86 |
+
ppl_result = gr.Markdown()
|
87 |
+
with gr.Column():
|
88 |
+
gr.Markdown("### SC (Self-Consistency)")
|
89 |
+
sc_result = gr.Markdown(value="")
|
90 |
+
with gr.Column():
|
91 |
+
gr.Markdown("### RPC (Ours)")
|
92 |
+
rpc_result = gr.Markdown(value="")
|
93 |
+
|
94 |
+
def get_available_problems():
    """Return 1-based IDs of problems whose mean accuracy lies strictly in (0.3, 0.5).

    Reads ``json_file["accuracy"]`` from the module-level ``json_file`` and
    averages it over axis 0 before applying the band filter.
    """
    mean_accuracy = np.array(json_file["accuracy"]).mean(axis=0)
    in_band = (mean_accuracy > 0.3) & (mean_accuracy < 0.5)
    zero_based = np.where(in_band)[0]
    # Problem IDs shown in the UI start at 1.
    return (zero_based + 1).tolist()
|
103 |
+
|
104 |
+
|
105 |
+
def load(dataset, model, k_value, seed):
    """Download the sampled reasoning paths for *dataset*/*model* and publish them to the UI.

    Generator used as a Gradio event handler. Every yield is a 4-tuple whose
    positions match the wired outputs: load button, content column,
    experiment-info textbox, problem-ID dropdown.

    Args:
        dataset: Dataset name; must be a key of the repository table below.
        model: Model name; ``<model>.json`` is fetched from the dataset repo.
        k_value: Selected K, only echoed in the experiment-info text.
        seed: Selected seed, only echoed in the experiment-info text.

    Yields:
        A progress message first, then either the "loading complete" updates
        or an error message (with the other three outputs left unchanged).
    """
    global json_file
    try:
        repo_id = {
            "MATH": "WNJXYK/MATH-Reasoning-Paths",
            "MathOdyssey": "WNJXYK/MathOdyssey-Reasoning-Paths",
            "AIME": "WNJXYK/AIME_1983_2024-Reasoning-Paths",
            "OlympiadBench": "WNJXYK/OlympiadBench-Reasoning-Paths"
        }[dataset]
        filename = f"{model}.json"

        yield f"Downloading sampled reasoning paths from Hugging Face {repo_id}...", gr.update(visible=False), gr.update(), gr.update()
        file_path = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset")

        with open(file_path, 'r', encoding='utf-8') as f:
            json_file = json.load(f)
        clist = get_available_problems()
        # The parsed data now lives in memory; delete the downloaded file.
        os.remove(file_path)

        if not clist:
            # Without this guard ``clist[0]`` below would raise an opaque IndexError.
            raise ValueError("no problems are available for the selected dataset and model")

        yield "Loading complete! You can now select a problem ID.", gr.update(visible=True), gr.update(value=f"Dataset: {dataset}\tModel: {model}\tK: {k_value}\tSeed: {seed}"), gr.update(choices=clist, value=clist[0])

    except Exception as e:
        # Yield one value per wired output; a single value would not match the
        # four outputs and the error message would never reach the user.
        yield f"Error: {str(e)}", gr.update(), gr.update(), gr.update()
|
130 |
+
|
131 |
+
def res_to_str(correct, answers, topk=10):
    """Render the highest-probability answers as a Markdown table.

    Args:
        correct: Not used by this function; accepted so callers can pass an
            evaluator's ``(correct, answers)`` pair directly.
        answers: Entries of the form ``[answer_text, probability, is_correct]``.
        topk: Maximum number of rows to emit.

    Returns:
        Markdown text: a header followed by one row per answer, ordered by
        descending probability. Answer text longer than 10 characters is
        truncated and suffixed with ``...``.
    """
    ranked = sorted(answers, key=lambda x: x[1], reverse=True)
    rows = ["| # | Answer | Probability | Correct |\n|---|--------|------------|--------|\n"]
    for rank, entry in enumerate(ranked[:max(topk, 0)], start=1):
        mark = "✅" if entry[2] else "❌"
        text = entry[0]
        shown = text if len(text) <= 10 else text[:10] + "..."
        rows.append(f"| Top-{rank} | {shown} | {entry[1]:.2f} | {mark} |\n")
    return "".join(rows)
|
139 |
+
|
140 |
+
def evaluate(problem_id):
    """Run PPL, SC and RPC on one loaded problem and return the result-panel updates.

    ``problem_id`` is 1-based. All three evaluators receive the same arguments:
    the problem's predictions, completions, mean log-probabilities and reference
    answer from ``json_file``, plus the two comparison functions.

    Returns:
        Four Gradio updates: make the result column visible, then the PPL, SC
        and RPC Markdown tables.
    """
    row = problem_id - 1
    shared_args = (
        json_file["predict"][row],
        json_file["completion"][row],
        json_file["mean_logprob"][row],
        json_file["answer"][row],
        numberic_compare,
        check_equal,
    )

    ppl_correct, ppl_answers = prep_evaluator(*shared_args)
    sc_correct, sc_answers = sc_evaluator(*shared_args)
    rpc_correct, rpc_answers = wpc_evaluator(*shared_args)

    ppl_table = res_to_str(ppl_correct, ppl_answers)
    sc_table = res_to_str(sc_correct, sc_answers)
    rpc_table = res_to_str(rpc_correct, rpc_answers)
    return (
        gr.update(visible=True),
        gr.update(value=ppl_table),
        gr.update(value=sc_table),
        gr.update(value=rpc_table),
    )
|
169 |
+
|
170 |
+
load_btn.click(fn=load, inputs=[dataset, model, k_value, seed],outputs=[load_btn, content_column, data_info, problem_id], show_progress="inside")
|
171 |
+
run_btn.click(fn=evaluate, inputs=problem_id, outputs=[result_column, ppl_result, sc_result, rpc_result])
|
172 |
+
|
173 |
+
if __name__ == "__main__":
|
174 |
+
demo.launch()
|
build_cache.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from compute_perp import check_equal
|
2 |
+
import multiprocessing, json, os, time
|
3 |
+
|
4 |
+
def solve(predict, answer):
    """Precompute ``check_equal`` verdicts for a single problem.

    Every prediction is compared with the reference answer and with every
    prediction. Each verdict is stored under both ``"a<##>b"`` and ``"b<##>a"``
    so later lookups succeed in either order.

    Args:
        predict: Predictions for one problem.
        answer: Reference answer for that problem.

    Returns:
        Dict mapping ``"<left><##><right>"`` keys to the comparison result.
    """
    cache_dict = {}

    def remember(left, right):
        # Skip pairs already known in either orientation.
        key = str(left) + "<##>" + str(right)
        rev_key = str(right) + "<##>" + str(left)
        if key in cache_dict or rev_key in cache_dict:
            return
        val = check_equal(left, right)
        cache_dict[key] = val
        cache_dict[rev_key] = val

    for candidate in predict:
        remember(candidate, answer)

    for first in predict:
        for second in predict:
            remember(first, second)

    return cache_dict
|
28 |
+
|
29 |
+
def cache(data, cache_path):
    """Build the equality cache for a whole dataset and write it to *cache_path* as JSON.

    Does nothing (besides printing a notice) when *cache_path* already exists.
    Otherwise ``solve`` is run for every problem in a multiprocessing pool and
    the per-problem dictionaries are merged into one.

    Args:
        data: Mapping with ``"predict"`` and ``"answer"`` lists, indexed by problem.
        cache_path: Destination path of the JSON cache file.
    """
    if os.path.exists(cache_path):
        print(f"Cache file {cache_path} exists, skip!")
        return

    start_time = time.time()
    predicts, answers = data["predict"], data["answer"]
    merged = {}
    with multiprocessing.Pool() as pool:
        per_problem = pool.starmap(
            solve, [(predicts[idx], answers[idx]) for idx in range(len(predicts))]
        )
    for table in per_problem:
        merged.update(table)

    with open(cache_path, "w") as fw:
        json.dump(merged, fw)
    print(f"Cache file {cache_path} built in {time.time() - start_time:.2f}S")
|
compute_perp.py
ADDED
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import metrics
|
3 |
+
import argparse
|
4 |
+
import numpy as np
|
5 |
+
import multiprocessing
|
6 |
+
from tqdm import trange
|
7 |
+
import signal, functools
|
8 |
+
import re, os, sys, random, time
|
9 |
+
from fraction import Fraction
|
10 |
+
from data_processing.answer_extraction import *
|
11 |
+
from functools import lru_cache
|
12 |
+
from eval.eval_script import *
|
13 |
+
MAX_INT = sys.maxsize
|
14 |
+
INVALID_ANS = "[Invalid]"
|
15 |
+
INF = 1e9
|
16 |
+
|
17 |
+
__all__ = [
|
18 |
+
"check_equal",
|
19 |
+
"check_equal_without_timeout",
|
20 |
+
"numberic_compare",
|
21 |
+
"Evaluator",
|
22 |
+
]
|
23 |
+
|
24 |
+
@lru_cache(maxsize=1000000)
def check_equal_without_timeout(ans_1, ans_2):
    """Return ``math_equal(ans_1, ans_2)``, memoised per argument pair by ``lru_cache``."""
    verdict = math_equal(ans_1, ans_2)
    return verdict
|
27 |
+
|
28 |
+
def check_equal(ans_1, ans_2, cache_dict=None):
    """Decide whether two answers are equal, consulting an optional precomputed cache.

    Args:
        ans_1: First answer.
        ans_2: Second answer.
        cache_dict: Optional mapping keyed by ``"<ans_1><##><ans_2>"``. A hit is
            returned directly; a miss prints ``Miss`` and falls back to
            ``check_equal_without_timeout``.

    Returns:
        The cached or computed verdict, or ``False`` if a ``TimeoutError`` is raised.
    """
    try:
        if cache_dict is None:
            return check_equal_without_timeout(ans_1, ans_2)
        lookup = "<##>".join((str(ans_1), str(ans_2)))
        if lookup in cache_dict:
            return cache_dict[lookup]
        print("Miss")
        return check_equal_without_timeout(ans_1, ans_2)
    except TimeoutError:
        return False
|
37 |
+
|
38 |
+
def numberic_compare(ai, aj, ci, cj, cache_dict=None):
    """Compare answers ``ai`` and ``aj`` via ``check_equal``.

    The completions ``ci`` and ``cj`` are accepted but not used here.
    """
    outcome = check_equal(ai, aj, cache_dict)
    return outcome
|
40 |
+
|
41 |
+
def prep_evaluator(
    predicts, completions, perplexities, answer, equal_func, check_equal
):
    """Score one problem by picking the prediction(s) with the highest log-probability.

    Args:
        predicts: Predicted answers, one per sampled path.
        completions: Not used by this evaluator.
        perplexities: Per-path values that are compared to find the maximum and
            exponentiated for reporting (callers pass mean log-probabilities).
        answer: Reference answer.
        equal_func: Not used by this evaluator.
        check_equal: Callable ``(prediction, answer)`` returning whether they match.

    Returns:
        ``(correct, answers)``: ``correct`` is the share of top-scoring paths
        whose prediction matches ``answer`` (ties split evenly); ``answers`` is
        a list of ``[prediction, exp(value), is_correct]`` for every path.
    """
    # Starting from -INF mirrors a running maximum that values below -INF never replace.
    max_perplexity = max([-INF, *perplexities])
    max_perplexity_count = float(sum(1.0 for value in perplexities if value >= max_perplexity))

    correct, answers = 0, []
    for i, ans_i in enumerate(predicts):
        answers.append([ans_i, np.exp(perplexities[i]), check_equal(ans_i, answer)])
        is_top = not (perplexities[i] < max_perplexity)
        if is_top and check_equal(ans_i, answer):
            correct += 1.0 / max_perplexity_count

    return correct, answers
|
66 |
+
|
67 |
+
class Evaluator:
    """Perplexity baseline evaluator and shared evaluation driver.

    ``process`` runs a per-problem evaluator over a dataset, ``worker`` binds
    the evaluator to use, and ``solve`` repeats the run over several seeds.
    """

    def __init__(self):
        self.name = "Perplexity"

    def process(self, json_file, cache_file, equal_func, evaluator, K, seed=0):
        """Evaluate every problem on a seeded subsample of K reasoning paths.

        Args:
            json_file: Loaded results with ``"predict"``, ``"completion"`` and
                ``"mean_logprob"`` (indexed ``[problem][path]``) and ``"answer"``
                (indexed ``[problem]``).
            cache_file: Optional equality cache; when given it is passed as the
                trailing argument to ``equal_func`` and ``check_equal``.
            equal_func: Callable ``(ai, aj, ci, cj[, cache])`` comparing two answers.
            evaluator: Per-problem evaluator returning ``(correct, answers)``.
            K: Number of paths to keep after shuffling.
            seed: Seed for the shuffle that selects the paths.

        Returns:
            ``(accuracy_percent, maximum, average, max_bins, avg_bins)`` where the
            last four values come from the ``metrics`` module.
        """
        results = json_file
        n = len(results["predict"])
        m = len(results["predict"][0])
        indices = list(range(m))
        random.seed(seed)
        random.shuffle(indices)
        indices = indices[:K]

        if cache_file is not None:
            def cache_equal_func(ai, aj, ci, cj):
                return equal_func(ai, aj, ci, cj, cache_file)

            def cache_check_equal(ai, aj):
                return check_equal(ai, aj, cache_file)
        else:
            cache_equal_func = equal_func
            cache_check_equal = check_equal

        predicts, completions, perplexities, answers = [], [], [], []
        for i in range(n):
            predicts.append([results["predict"][i][j] for j in indices])
            completions.append([results["completion"][i][j] for j in indices])
            perplexities.append([results["mean_logprob"][i][j] for j in indices])
            answers.append(results["answer"][i])

        start_time = time.time()
        outputs = []
        for idx in trange(n):
            res = evaluator(
                predicts[idx],
                completions[idx],
                perplexities[idx],
                answers[idx],
                cache_equal_func,
                cache_check_equal,
            )
            outputs.append(res)
        print(f"Running Time with Single Process Mode with Seed #{seed}: {time.time() - start_time:.2f}S")

        # Each output is (correct, answers); the metrics use only the answers.
        maximum, max_bins = metrics.compute_maximum_metrics([x[1] for x in outputs])
        average, avg_bins = metrics.compute_average_metrics([x[1] for x in outputs])
        accs = np.mean([x[0] for x in outputs])
        return accs * 100.0, maximum, average, max_bins, avg_bins

    def worker(self, args):
        """Run ``process`` with the perplexity evaluator for one ``(json_file, cache_file, K, seed)`` tuple."""
        json_file, cache_file, K, seed = args
        acc, maximum, average, max_bins, avg_bins = self.process(
            json_file=json_file,
            cache_file=cache_file,
            equal_func=numberic_compare,
            evaluator=prep_evaluator,
            K=K,
            seed=seed
        )
        return acc, maximum, average

    def solve(self, json_file, cache_file=None, repeats=10, K=128):
        """Repeat the evaluation for seeds ``0..repeats-1`` in a process pool.

        Returns:
            Dict with ``"Accuracy"`` and ``"ECE"`` formatted as ``mean ± std``;
            ECE is read from element 0 of each run's ``maximum`` metrics.
        """
        with multiprocessing.Pool() as pool:
            results = pool.map(self.worker, [(json_file, cache_file, K, seed) for seed in range(repeats)])
        accs, maxs, _ = zip(*results)
        accs, maxs = np.array(accs), np.array(maxs)
        return {
            "Accuracy": f"{accs.mean():.2f} ± {accs.std():.2f}",
            "ECE": f"{maxs[:, 0].mean() * 100.0:.2f} ± {maxs[:, 0].std() * 100.0:.2f}",
        }
|
compute_rpc.py
ADDED
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import metrics
|
3 |
+
import argparse
|
4 |
+
import numpy as np
|
5 |
+
import multiprocessing
|
6 |
+
from tqdm import trange
|
7 |
+
import signal, functools
|
8 |
+
from scipy.special import gamma
|
9 |
+
import re, os, sys, random, time
|
10 |
+
from scipy.stats import weibull_min
|
11 |
+
from scipy.optimize import minimize
|
12 |
+
from fraction import Fraction
|
13 |
+
from data_processing.answer_extraction import *
|
14 |
+
from eval.eval_script import *
|
15 |
+
from compute_perp import Evaluator, numberic_compare
|
16 |
+
from compute_sc import DSU
|
17 |
+
MAX_INT = sys.maxsize
|
18 |
+
INVALID_ANS = "[Invalid]"
|
19 |
+
|
20 |
+
#### Reasoning Pruning Module: Model probability with Weibull distribution ####
|
21 |
+
|
22 |
+
def weibull_pdf(x, k, lam):
    """Weibull density with shape ``k`` and scale ``lam``, evaluated at ``x``."""
    prefactor = k / lam
    scaled = x / lam
    return prefactor * scaled ** (k - 1) * np.exp(-(scaled ** k))
|
24 |
+
|
25 |
+
def weibull_mean(k, lam):
    """Mean of a Weibull distribution: ``lam * Gamma(1 + 1/k)``."""
    shape_term = gamma(1 + 1 / k)
    return lam * shape_term
|
27 |
+
|
28 |
+
def mixture_pdf(x, w1, k1, lam1, k2, lam2):
    """Density of a two-component Weibull mixture with weights ``w1`` and ``1 - w1``."""
    first = w1 * weibull_pdf(x, k1, lam1)
    second = (1 - w1) * weibull_pdf(x, k2, lam2)
    return first + second
|
30 |
+
|
31 |
+
def neg_log_likelihood(params, data):
    """Negative log-likelihood of ``data`` under the two-component Weibull mixture.

    ``params`` is ``(w1, k1, lam1, k2, lam2)`` in the order expected by ``mixture_pdf``.
    """
    w1, k1, lam1, k2, lam2 = params
    log_density = np.log(mixture_pdf(data, w1, k1, lam1, k2, lam2))
    return -np.sum(log_density)
|
35 |
+
|
36 |
+
def calculate_membership_probabilities(data, w1, k1, lam1, k2, lam2):
    """Return the posterior probabilities that ``data`` belongs to each mixture component.

    Returns:
        ``(prob1, prob2)`` with ``prob2 == 1 - prob1``.
    """
    weighted_first = w1 * weibull_pdf(data, k1, lam1)
    weighted_second = (1 - w1) * weibull_pdf(data, k2, lam2)
    prob1 = weighted_first / (weighted_first + weighted_second)
    return prob1, 1 - prob1
|
42 |
+
|
43 |
+
### Perplexity Consistency Module: Bridging the probability with self-consistency ####
|
44 |
+
|
45 |
+
def wpc_evaluator(predicts, completions, perplexities, answer, equal_func, check_equal):
    """Score one problem with pruned, probability-weighted answer voting.

    Steps: fit a two-component Weibull mixture to the path probabilities
    (``exp`` of the given values); drop paths that are more likely to belong
    to the lower-mean component and are below the mean probability; group
    equal answers; vote by the summed probability of each group.

    Args:
        predicts: Predicted answers, one per sampled path.
        completions: Completion texts, used as keys when collecting probabilities.
        perplexities: Per-path values that are exponentiated into probabilities.
        answer: Reference answer.
        equal_func: Callable ``(ans_i, ans_j, completion_i, completion_j)``.
        check_equal: Callable ``(prediction, answer)``.

    Returns:
        ``(correct, answers)``: ``correct`` is the share of top-weighted groups
        matching ``answer`` (ties split evenly); ``answers`` holds
        ``[prediction, normalised_weight, is_correct]`` per group.
    """
    m = len(predicts)
    dsu = DSU(m)
    probas = [np.exp(perplexities[idx]) for idx in range(m)]
    mean_proba = np.mean(probas)

    # Fit the mixture; parameter order is (w1, k1, lam1, k2, lam2).
    fit = minimize(
        neg_log_likelihood,
        [0.5, 1.0, 1.0, 1.5, 2.0],
        args=(probas,),
        bounds=[(0.2, 0.8), (0.01, None), (0.01, None), (0.01, None), (0.01, None)],
    )
    w1, k1, lam1, k2, lam2 = fit.x
    # Make component 1 the one with the larger mean.
    if weibull_mean(k1, lam1) < weibull_mean(k2, lam2):
        k1, lam1, k2, lam2 = k2, lam2, k1, lam1
        w1 = 1 - w1

    # Keep a path's probability only when it survives pruning.
    for idx in range(m):
        proba = np.exp(perplexities[idx])
        p1, p2 = calculate_membership_probabilities(proba, w1, k1, lam1, k2, lam2)
        pruned = p1 < p2 and proba < mean_proba
        if not pruned:
            dsu.attr[idx][completions[idx]] = {proba}

    # Group paths whose answers compare equal.
    for i in range(m):
        if dsu.get_father(i) != i:
            continue
        for j in range(i):
            if equal_func(predicts[i], predicts[j], completions[i], completions[j]):
                dsu.merge(i, j)

    def group_weight(root):
        # Sum of all probabilities collected under this group's root.
        return np.sum([np.sum(list(dsu.attr[root][k])) for k in dsu.attr[root].keys()])

    roots = [idx for idx in range(m) if dsu.get_father(idx) == idx]
    weights = {root: group_weight(root) for root in roots}

    # Find the largest group weight and how many groups share it.
    max_prob, max_prob_count = 0, 0
    for root in roots:
        weight = weights[root]
        if weight > max_prob:
            max_prob = weight
            max_prob_count = 0
        if weight >= max_prob:
            max_prob_count += 1

    correct, answers = 0, []
    for root in roots:
        ans = predicts[root]
        weight = weights[root]
        answers.append([ans, weight, check_equal(ans, answer)])
        if weight < max_prob:
            continue
        if check_equal(ans, answer):
            correct += 1.0 / max_prob_count

    # Normalise the reported weights so they sum to one.
    sum_proba = np.sum([entry[1] for entry in answers])
    for entry in answers:
        entry[1] /= sum_proba

    return correct, answers
|
120 |
+
|
121 |
+
class RPCEvaluator(Evaluator):
    """Evaluator that plugs ``wpc_evaluator`` into the shared ``Evaluator.process`` driver."""

    def __init__(self,):
        self.name = "RPC"

    def worker(self, args):
        """Run ``process`` with ``wpc_evaluator`` for one ``(json_file, cache_file, K, seed)`` tuple.

        Returns:
            ``(accuracy, maximum, average)``; the bin outputs of ``process`` are dropped.
        """
        json_file, cache_file, K, seed = args
        settings = dict(
            json_file=json_file,
            cache_file=cache_file,
            equal_func=numberic_compare,
            evaluator=wpc_evaluator,
            K=K,
            seed=seed,
        )
        acc, maximum, average, _max_bins, _avg_bins = self.process(**settings)
        return acc, maximum, average
|
136 |
+
|
compute_sc.py
ADDED
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import metrics
|
3 |
+
import argparse
|
4 |
+
import numpy as np
|
5 |
+
import multiprocessing
|
6 |
+
from tqdm import trange
|
7 |
+
import signal, functools
|
8 |
+
import re, os, sys, random, time
|
9 |
+
from fraction import Fraction
|
10 |
+
from data_processing.answer_extraction import *
|
11 |
+
from eval.eval_script import *
|
12 |
+
from compute_perp import Evaluator, numberic_compare
|
13 |
+
MAX_INT = sys.maxsize
|
14 |
+
INVALID_ANS = "[Invalid]"
|
15 |
+
|
16 |
+
__all__ = ["DSU"]
|
17 |
+
|
18 |
+
class DSU:
    """Disjoint-set union over ``n`` items with a per-set attribute dictionary.

    Attributes (all lists of length ``n``, indexed by item):
        father: Parent pointer; a root points to itself.
        size: Number of items in the set, valid at roots (0 for absorbed roots).
        attr: Dict of ``key -> set`` stored at roots; merged on ``merge``.
    """

    def __init__(self, n):
        self.n = n
        self.father = list(range(n))
        self.size = [1] * n
        self.attr = [{} for _ in range(n)]

    def get_father(self, x):
        """Return the root of ``x``'s set, compressing the path to it.

        Implemented iteratively so that long parent chains cannot exceed the
        interpreter's recursion limit.
        """
        root = x
        while self.father[root] != root:
            root = self.father[root]
        # Second pass: point every node on the path directly at the root.
        while self.father[x] != root:
            parent = self.father[x]
            self.father[x] = root
            x = parent
        return root

    def merge(self, x, y):
        """Merge ``y``'s set into ``x``'s set; ``x``'s root stays the root.

        Sizes are added, and for every attribute key the value sets are
        unioned (or moved over when the key is new). The absorbed root ends up
        with size 0 and an empty attribute dict.
        """
        fx = self.get_father(x)
        fy = self.get_father(y)
        if fx == fy:
            return
        self.father[fy] = fx
        self.size[fx] += self.size[fy]
        self.size[fy] = 0
        for key, values in self.attr[fy].items():
            if key not in self.attr[fx]:
                self.attr[fx][key] = values
            else:
                self.attr[fx][key] |= values
        self.attr[fy] = {}
|
45 |
+
|
46 |
+
|
47 |
+
def sc_evaluator(predicts, completions, perplexities, answer, equal_func, check_equal):
    """Score one problem by majority vote over groups of equal answers.

    Args:
        predicts: Predicted answers, one per sampled path.
        completions: Completion texts, forwarded to ``equal_func``.
        perplexities: Not used by this evaluator.
        answer: Reference answer.
        equal_func: Callable ``(ans_i, ans_j, completion_i, completion_j)``.
        check_equal: Callable ``(prediction, answer)``.

    Returns:
        ``(correct, answers)``: ``correct`` is the share of largest groups
        matching ``answer`` (ties split evenly); ``answers`` holds
        ``[prediction, normalised_vote_share, is_correct]`` per group.
    """
    m = len(predicts)
    dsu = DSU(m)

    # Group paths whose answers compare equal.
    for i in range(m):
        if dsu.get_father(i) != i:
            continue
        for j in range(i):
            if equal_func(predicts[i], predicts[j], completions[i], completions[j]):
                dsu.merge(i, j)

    roots = [idx for idx in range(m) if dsu.get_father(idx) == idx]

    # Largest group size and the number of groups that reach it.
    max_size = max((dsu.size[root] for root in roots), default=0)
    max_size_count = sum(1 for root in roots if dsu.size[root] == max_size)

    correct, answers = 0, []
    for root in roots:
        ans = predicts[root]
        answers.append([ans, dsu.size[root] / m, check_equal(ans, answer)])
        if dsu.size[root] < max_size:
            continue
        if check_equal(ans, answer):
            correct += 1.0 / max_size_count

    # Normalise the reported vote shares so they sum to one.
    sum_proba = np.sum([entry[1] for entry in answers])
    for entry in answers:
        entry[1] /= sum_proba

    return correct, answers
|
92 |
+
|
93 |
+
|
94 |
+
class SCEvaluator(Evaluator):
    """Evaluator that plugs ``sc_evaluator`` into the shared ``Evaluator.process`` driver."""

    def __init__(self):
        self.name = "Self-Consistency"

    def worker(self, args):
        """Run ``process`` with ``sc_evaluator`` for one ``(json_file, cache_file, K, seed)`` tuple.

        Returns:
            ``(accuracy, maximum, average)``; the bin outputs of ``process`` are dropped.
        """
        json_file, cache_file, K, seed = args
        settings = dict(
            json_file=json_file,
            cache_file=cache_file,
            equal_func=numberic_compare,
            evaluator=sc_evaluator,
            K=K,
            seed=seed,
        )
        acc, maximum, average, _max_bins, _avg_bins = self.process(**settings)
        return acc, maximum, average
|
data_processing/answer_extraction.py
ADDED
@@ -0,0 +1,362 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
import regex
|
3 |
+
|
4 |
+
|
5 |
+
def _fix_fracs(string):
|
6 |
+
substrs = string.split("\\frac")
|
7 |
+
new_str = substrs[0]
|
8 |
+
if len(substrs) > 1:
|
9 |
+
substrs = substrs[1:]
|
10 |
+
for substr in substrs:
|
11 |
+
new_str += "\\frac"
|
12 |
+
if len(substr) > 0 and substr[0] == "{":
|
13 |
+
new_str += substr
|
14 |
+
else:
|
15 |
+
try:
|
16 |
+
assert len(substr) >= 2
|
17 |
+
except:
|
18 |
+
return string
|
19 |
+
a = substr[0]
|
20 |
+
b = substr[1]
|
21 |
+
if b != "{":
|
22 |
+
if len(substr) > 2:
|
23 |
+
post_substr = substr[2:]
|
24 |
+
new_str += "{" + a + "}{" + b + "}" + post_substr
|
25 |
+
else:
|
26 |
+
new_str += "{" + a + "}{" + b + "}"
|
27 |
+
else:
|
28 |
+
if len(substr) > 2:
|
29 |
+
post_substr = substr[2:]
|
30 |
+
new_str += "{" + a + "}" + b + post_substr
|
31 |
+
else:
|
32 |
+
new_str += "{" + a + "}" + b
|
33 |
+
string = new_str
|
34 |
+
return string
|
35 |
+
|
36 |
+
|
37 |
+
def _fix_a_slash_b(string):
|
38 |
+
if len(string.split("/")) != 2:
|
39 |
+
return string
|
40 |
+
a = string.split("/")[0]
|
41 |
+
b = string.split("/")[1]
|
42 |
+
try:
|
43 |
+
if "sqrt" not in a:
|
44 |
+
a = int(a)
|
45 |
+
if "sqrt" not in b:
|
46 |
+
b = int(b)
|
47 |
+
assert string == "{}/{}".format(a, b)
|
48 |
+
new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
|
49 |
+
return new_string
|
50 |
+
except:
|
51 |
+
return string
|
52 |
+
|
53 |
+
|
54 |
+
def _fix_sqrt(string):
|
55 |
+
_string = re.sub(r"\\sqrt(-?[0-9.a-zA-Z]+)", r"\\sqrt{\1}", string)
|
56 |
+
_string = re.sub(r"\\sqrt\s+(\w+)$", r"\\sqrt{\1}", _string)
|
57 |
+
return _string
|
58 |
+
|
59 |
+
|
60 |
+
def _fix_tan(string):
|
61 |
+
_string = re.sub(r"\\tan(-?[0-9.a-zA-Z]+)", r"\\tan{\1}", string)
|
62 |
+
_string = re.sub(r"\\tan\s+(\w+)$", r"\\tan{\1}", _string)
|
63 |
+
return _string
|
64 |
+
|
65 |
+
|
66 |
+
def strip_string(string):
    r"""Normalize a math answer string so equivalent answers compare equal.

    The input is converted with ``str()`` and then passed through a fixed
    sequence of textual rewrites: whitespace/line-break removal, dropping
    ``\left``/``\right``, trailing ``\text{...}`` units, degree marks,
    dollar signs, ``\cdot``, ``\mathbf``/``\mathrm``/``\mbox{...}``,
    canonicalising ``*frac`` variants, infinity spellings, trailing ``.0``
    decimals, and finally bracing ``\sqrt``/``\tan``/``\frac`` arguments.
    The order of the steps matters; later steps see the output of earlier
    ones.

    Args:
        string: Raw answer (any object; converted with ``str``).

    Returns:
        The normalized answer string (possibly empty).
    """
    string = str(string).strip()
    # Line breaks.
    string = string.replace("\n", "")

    # Trailing periods.
    string = string.rstrip(".")

    # LaTeX negative thin space.
    string = string.replace("\\!", "")

    # A fully \text{...}-wrapped answer: keep only the inner text.
    if string.startswith("\\text{") and string.endswith("}"):
        string = string.split("{", 1)[1][:-1]

    # tfrac / dfrac / cfrac -> frac.
    for variant in ("tfrac", "dfrac", "cfrac"):
        string = string.replace(variant, "frac")

    # \left and \right sizing commands.
    string = string.replace("\\left", "")
    string = string.replace("\\right", "")

    # Drop a trailing unit such as "\text{miles}", unless that would leave
    # nothing behind.
    without_unit = re.sub(r"\\text{.*?}$", "", string).strip()
    if without_unit != "" and without_unit != string:
        string = without_unit

    # Degree marks.
    string = string.replace("^{\\circ}", "").strip()
    string = string.replace("^\\circ", "").strip()

    # Braced length units ({m}, {cm}, {mm}, optionally ^2/^3), a trailing
    # "p.m.", and a trailing "t" after a digit.
    string = regex.sub(r"\{(c|m)?m\}(\^(2|3))?", "", string).strip()
    string = regex.sub(r"p\.m\.$", "", string).strip()
    string = regex.sub(r"(\d)\s*t$", r"\1", string).strip()

    # Dollar signs (escaped and bare).
    string = string.replace("\\$", "")
    string = string.replace("$", "")

    string = string.replace("x\\in", "")

    # Unescape percent signs. (The original had a second, identical
    # replace spelled with the invalid escape "\%"; it was redundant.)
    string = string.replace("\\%", "%")

    # " ." -> " 0." and "{." -> "{0.".
    string = string.replace(" .", " 0.")
    string = string.replace("{.", "{0.")

    # \cdot
    string = string.replace("\\cdot", "")

    # Infinity spellings.
    string = string.replace("infinity", "\\infty")
    if "\\infty" not in string:
        string = string.replace("inf", "\\infty")
    # Fixed typo: the original searched for "+\\inity", which never matches.
    string = string.replace("+\\infty", "\\infty")

    string = string.replace("\\mathbf", "")
    string = string.replace("\\mathrm", "")

    # \mbox{...}
    string = re.sub(r"\\mbox{.*?}", "", string)

    # NOTE(review): the original called string.replace("'", "") and
    # string.replace('"', "") here but discarded the results, so quotes
    # have never been stripped. That behavior is kept (the dead calls are
    # removed); actually stripping quotes would change existing scores and
    # would mangle prime notation such as f'(x) -- confirm before enabling.

    # Treat j as the imaginary unit when no i is present.
    if "j" in string and "i" not in string:
        string = string.replace("j", "i")

    # a.000b (b not a digit, or end of string) -> ab
    string = re.sub(r"(\d+)\.0+([^\d])", r"\1\2", string)
    string = re.sub(r"(\d+)\.0+$", r"\1", string)

    if len(string) == 0:
        return string
    if string[0] == ".":
        string = "0" + string

    string = _fix_sqrt(string)
    string = _fix_tan(string)
    string = string.replace(" ", "")

    # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}; also \frac1{72}.
    string = _fix_fracs(string)

    # Simple X/Y --> \frac{X}{Y} in case the model output uses a slash.
    string = _fix_a_slash_b(string)

    # Trailing backslashes, commas and periods.
    string = regex.sub(r"(\\|,|\.)+$", "", string)

    return string
|
177 |
+
|
178 |
+
|
179 |
+
def extract_boxed_answers(text):
    r"""Return the contents of every balanced ``boxed{...}`` in *text*.

    The text is split on the literal ``boxed{`` and each following piece is
    scanned for the brace that closes the box, tracking nested braces. A box
    that is never closed contributes nothing.

    If the closing brace is immediately followed by ``%`` (for example
    ``\boxed{50}%``) the percent sign is kept as part of the answer. The
    original code sliced ``piece[: i + 1]`` in that case, which returned the
    content plus the closing ``}`` (``"50}"``) rather than ``"50%"``.

    Args:
        text: Model output or reference solution.

    Returns:
        List of box contents in order of appearance.
    """
    answers = []
    for piece in text.split("boxed{")[1:]:
        depth = 0
        for i, ch in enumerate(piece):
            if ch == "{":
                depth += 1
            elif ch == "}":
                depth -= 1
                if depth < 0:
                    # This brace closes the box itself.
                    answer = piece[:i]
                    if piece[i + 1 : i + 2] == "%":
                        answer += "%"
                    answers.append(answer)
                    break
    return answers
|
195 |
+
|
196 |
+
|
197 |
+
def extract_program_output(pred_str):
    """Return the stripped text of the last ```output fenced block.

    Returns an empty string when *pred_str* contains no ```output marker.
    If the last block is never closed, everything after the marker is used.
    """
    marker = "```output"
    if marker not in pred_str:
        return ""
    # Text after the last marker ...
    after_marker = pred_str.rpartition(marker)[2]
    # ... up to the next fence, if there is one.
    block = after_marker.partition("```")[0]
    return block.strip()
|
209 |
+
|
210 |
+
|
211 |
+
def extract_answer(pred_str, exhaust=False):
    r"""Extract the final answer(s) from a model response.

    Extraction strategies are tried in this order, and the first that
    applies is used:

    1. ``final answer is $...$. I hope`` (Minerva-style).
    2. Every ``boxed{...}`` in the text.
    3. Text after the last ``he answer is``.
    4. The last ```output block, if any.
    5. The last number in the text (commas removed first).

    Each candidate is cut to its first line, stripped of a leading ``:``
    and trailing ``.`` / ``/``, and normalized with ``strip_string``.

    Args:
        pred_str: Raw model output.
        exhaust: If true, return all candidates; otherwise only the last.

    Returns:
        A list of normalized answers when *exhaust* is true, else the last
        normalized answer or ``""`` if none was found.
    """
    if "final answer is $" in pred_str and "$. I hope" in pred_str:
        tail = pred_str.split("final answer is $", 1)[1]
        candidates = [tail.split("$. I hope", 1)[0].strip()]
    elif "boxed" in pred_str:
        candidates = extract_boxed_answers(pred_str)
    elif "he answer is" in pred_str:
        candidates = [pred_str.split("he answer is")[-1].strip()]
    else:
        program_output = extract_program_output(pred_str)
        if program_output != "":
            # Fall back to the program output.
            candidates = [program_output]
        else:
            # Fall back to the last number. Raw string: the original
            # non-raw pattern relied on invalid escape sequences.
            numbers = re.findall(r"-?\d*\.?\d+", pred_str.replace(",", ""))
            candidates = numbers[-1:]

    cleaned = []
    for ans in candidates:
        # Keep only the first line of a multi-line candidate.
        ans = ans.strip().split("\n")[0]
        ans = ans.lstrip(":")
        ans = ans.rstrip(".")
        ans = ans.rstrip("/")
        cleaned.append(strip_string(ans))

    if exhaust:
        return cleaned
    return cleaned[-1] if cleaned else ""
|
248 |
+
|
249 |
+
|
250 |
+
def extract_math_answer(question, reasoning, task):
    r"""Extract a list of answers for a MATH-style problem.

    Every candidate returned by ``extract_answer(..., exhaust=True)`` is
    split further when it holds several answers: on commas if the question
    says "separated by commas" and the candidate contains no brackets, or
    on ``\text{and}`` separators. *task* is accepted but not used here.

    Returns:
        Flat list of stripped answer strings.
    """
    and_pattern = r"\\text\{\s*and\s*\}"
    wants_comma_list = "separated by commas" in question
    collected = []
    for candidate in extract_answer(reasoning, exhaust=True):
        has_bracket = any(ch in candidate for ch in "()[]")
        if wants_comma_list and not has_bracket:
            parts = candidate.split(",")
        elif regex.search(and_pattern, candidate):
            parts = regex.sub(and_pattern, "[SEP]", candidate).split("[SEP]")
        else:
            parts = [candidate]
        for part in parts:
            collected.append(part.strip())
    return collected
|
267 |
+
|
268 |
+
|
269 |
+
def extract_math_few_shot_cot_answer(question, reasoning, task):
    """Extract MATH answers from a few-shot completion.

    Anything from the first "Problem:" onwards is discarded before
    delegating to ``extract_math_answer``.
    """
    # partition() returns the whole string when the marker is absent.
    own_solution = reasoning.partition("Problem:")[0]
    return extract_math_answer(question, own_solution, task)
|
273 |
+
|
274 |
+
|
275 |
+
def extract_last_single_answer(question, reasoning, task):
    """Return only the last answer found in *reasoning*.

    *question* and *task* are accepted for signature compatibility with the
    other extractors and are not used.
    """
    last_answer = extract_answer(reasoning, exhaust=False)
    return last_answer
|
277 |
+
|
278 |
+
|
279 |
+
def extract_gsm_few_shot_cot_answer(question, reasoning, task):
    """Return the last number in a GSM8K few-shot completion.

    Text from the first "Q: " onwards is ignored. Returns "[invalid]" when
    no number is found.
    """
    own_solution = reasoning.partition("Q: ")[0]
    numbers = regex.findall(r"-?\d+\.?\d*", own_solution)
    return numbers[-1] if numbers else "[invalid]"
|
287 |
+
|
288 |
+
|
289 |
+
def extract_agieval_gaokao_mathcloze_few_shot_cot_test(question, reasoning, task):
    """Extract the cloze answer from a Gaokao few-shot completion.

    Text from the first "问题 " (next question) onwards is ignored. The
    answer is the rest of the line following the first "答案是" ("the
    answer is"), with surrounding "$" removed.

    Returns:
        A one-element list with the answer, or ["placeholder"] when the
        marker is missing.
    """
    own_solution = reasoning.partition("问题 ")[0]
    _, marker, after = own_solution.partition("答案是")
    if not marker:
        return ["placeholder"]
    first_line = after.strip().split("\n")[0].strip()
    return [first_line.strip("$")]
|
299 |
+
|
300 |
+
|
301 |
+
def extract_agieval_gaokao_mathqa_few_shot_cot_test(question, reasoning, task):
    """Extract the multiple-choice answer from a Gaokao few-shot completion.

    Text from the first "问题 " (next question) onwards is ignored. The
    answer is the rest of the line following the first "答案是" ("the
    answer is").

    Returns:
        The answer string, or "placeholder" when the marker is missing.
    """
    own_solution = reasoning.partition("问题 ")[0]
    _, marker, after = own_solution.partition("答案是")
    if not marker:
        return "placeholder"
    return after.strip().split("\n")[0].strip()
|
310 |
+
|
311 |
+
|
312 |
+
def extract_sat_few_shot_answer(question, reasoning, task):
    """Extract the chosen option letter from a SAT few-shot completion.

    Text from the first "Problem:" onwards is ignored. The completion is
    lower-cased and searched for "the final answer is" followed by a letter
    a-d, optionally in parentheses.

    Returns:
        The upper-cased letter, or "placeholder" when nothing matches.
    """
    own_solution = reasoning.partition("Problem:")[0]
    match = regex.search(
        r"the final answer is \(?(?P<ans>[abcd])\)?", own_solution.lower()
    )
    if match is None:
        return "placeholder"
    return match.group("ans").upper()
|
319 |
+
|
320 |
+
|
321 |
+
def extract_ocwcourses_few_shot_answer(question, reasoning, task):
    """Extract the answer from an OCWCourses few-shot completion.

    Text from the first "Problem:" onwards is ignored. The answer is
    whatever sits between "final answer is " and ". I hope it is correct.".

    Returns:
        The captured answer, or "[invalid]" when the pattern is missing (in
        which case the completion is also printed for debugging).
    """
    own_solution = reasoning.partition("Problem:")[0]
    match = regex.search(
        r"final answer is (?P<ans>.*)\. I hope it is correct.", own_solution
    )
    if match is not None:
        return match.group("ans")
    print(f"DEBUG >>>\n{own_solution}", flush=True)
    return "[invalid]"
|
333 |
+
|
334 |
+
|
335 |
+
def extract_mmlu_stem(question, reasoning, task):
    """Extract the chosen option letter from an MMLU-STEM completion.

    Text from the first "Problem:" onwards is dropped, then the SAT
    extractor is reused.
    """
    own_solution = reasoning.partition("Problem:")[0]
    return extract_sat_few_shot_answer(question, own_solution, task)
|
339 |
+
|
340 |
+
|
341 |
+
def extract_minif2f_isabelle(question, reasoning, task):
    """Return the completion up to the first "Informal:" marker, stripped.

    *question* and *task* are accepted for signature compatibility and are
    not used.
    """
    own_proof = reasoning.partition("Informal:")[0]
    return own_proof.strip()
|
345 |
+
|
346 |
+
|
347 |
+
def extract_cmath_few_shot_test(question, reasoning, task):
    """Extract the numeric answer from a CMATH few-shot completion.

    Text from the first "问题:" (next question) onwards is ignored. If the
    completion contains "答案是" ("the answer is"), the last number on the
    rest of that line is returned. Otherwise extraction falls back to
    ``extract_last_single_answer``.

    Returns:
        The answer string, or "[invalid]" when the answer line contains no
        number (the completion is then printed for debugging).
    """
    if "问题:" in reasoning:
        reasoning = reasoning.split("问题:", 1)[0]
    if "答案是" not in reasoning:
        return extract_last_single_answer(question, reasoning, task)

    ans = reasoning.split("答案是", 1)[1].strip()
    ans = ans.split("\n")[0]
    ans = ans.strip(":")
    ans = ans.strip("。")
    numbers = regex.findall(r"-?\d+\.?\d*", ans)
    if not numbers:
        # Explicit check replaces the original bare ``except`` around
        # ``[...][-1]``, which hid every error, not just "no number".
        print(f"DEBUG CMATH: {reasoning}", flush=True)
        return "[invalid]"
    return numbers[-1]
|
data_processing/process_utils.py
ADDED
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import regex
|
2 |
+
|
3 |
+
from data_processing.answer_extraction import extract_math_answer, strip_string
|
4 |
+
|
5 |
+
|
6 |
+
def process_gsm8k_test(item):
    r"""Yield one chat-formatted GSM8K test sample built from *item*.

    The reference chain of thought has its ``<<...>>`` annotations removed
    and a closing "So the answer is $\boxed{...}$." sentence appended. The
    stored answer has commas removed.
    """
    raw_answer = item["answer"]
    cot = regex.sub(r"<<[^<>]*>>", "", item["cot"])
    user_turn = {"role": "user", "content": item["question"]}
    assistant_turn = {
        "role": "assistant",
        "content": cot + "\nSo the answer is $\\boxed{" + raw_answer.strip() + "}$.",
    }
    yield {
        "dataset": "gsm8k-cot",
        "id": item["id"],
        "messages": [user_turn, assistant_turn],
        "answer": raw_answer.replace(",", ""),
    }
|
23 |
+
|
24 |
+
|
25 |
+
def process_math_test(item):
    """Yield one chat-formatted MATH test sample built from *item*.

    The reference answer list is extracted from ``item["solution"]``. If
    extraction raises, the item is skipped and nothing is yielded. The
    assistant message is the solution with a line break inserted after each
    ". " that precedes a capital letter.
    """
    question = item["problem"]
    try:
        answer = extract_math_answer(question, item["solution"], task="cot")
    except Exception:
        # Best-effort: drop samples whose reference cannot be parsed.
        # ``Exception`` (not a bare except) so Ctrl-C / SystemExit still
        # propagate.
        return
    solution_lines = regex.split(r"(?<=\.) (?=[A-Z])", item["solution"])
    yield {
        "dataset": "math-cot",
        "id": item["id"],
        "level": item["level"],
        "type": item["type"],
        "category": item["category"],
        "messages": [
            {"role": "user", "content": question},
            {"role": "assistant", "content": "\n".join(solution_lines)},
        ],
        "answer": answer,
    }
|
49 |
+
|
50 |
+
|
51 |
+
def process_math_sat(item):
    """Yield one chat-formatted SAT-math sample built from *item*.

    ``item["options"]`` is expected to look like ``"A) ... B) ... C) ..."``;
    each label is rewritten to the parenthesised form ``(A) ... (B) ...``.

    Raises:
        AssertionError: If the options string does not start with "A".
    """
    options = item["options"].strip()
    assert "A" == options[0]
    options = "(" + options
    for ch in "BCDEFG":
        # Literal replacement; the original used regex.sub with a non-raw
        # f-string containing the invalid escape "\)".
        options = options.replace(f" {ch}) ", f" ({ch}) ")
    question = f"{item['question'].strip()}\nWhat of the following is the right choice? Explain your answer.\n{options.strip()}"
    messages = [
        {"role": "user", "content": question},
        {"role": "assistant", "content": item["Answer"]},
    ]
    yield {
        "dataset": "math_sat",
        "id": item["id"],
        "language": "en",
        "messages": messages,
        "answer": item["Answer"],
    }
|
71 |
+
|
72 |
+
|
73 |
+
def process_ocwcourses(item):
    """Yield one chat-formatted OCWCourses sample built from *item*.

    Problem and solution are stripped of surrounding whitespace; the answer
    is passed through unchanged.
    """
    user_turn = {"role": "user", "content": item["problem"].strip()}
    assistant_turn = {"role": "assistant", "content": item["solution"].strip()}
    yield {
        "dataset": "OCWCourses",
        "id": item["id"],
        "language": "en",
        "messages": [user_turn, assistant_turn],
        "answer": item["answer"],
    }
|
86 |
+
|
87 |
+
|
88 |
+
def process_mmlu_stem(item):
    """Yield one chat-formatted MMLU-STEM sample built from *item*.

    The first four options are labelled ``(A)``..``(D)`` and joined with
    ", " below the question. The caller's ``item["options"]`` list is left
    untouched; the original rewrote it in place.
    """
    # Work on a copy so the input record is not mutated.
    options = list(item["options"])
    for i, (label, option) in enumerate(zip("ABCD", options)):
        options[i] = f"({label}) {str(option).strip()}"
    options = ", ".join(options)
    question = f"{item['question'].strip()}\nWhat of the following is the right choice? Explain your answer.\n{options}"
    messages = [
        {"role": "user", "content": question},
        {"role": "assistant", "content": item["answer"]},
    ]
    yield {
        "dataset": "MMLU-STEM",
        "id": item["id"],
        "language": "en",
        "messages": messages,
        "answer": item["answer"],
    }
|
106 |
+
|
107 |
+
|
108 |
+
def process_mgsm_zh(item):
    """Yield *item* itself after removing commas from its answer in place."""
    comma_free = item["answer"].replace(",", "")
    item["answer"] = comma_free
    yield item
|
111 |
+
|
112 |
+
|
113 |
+
def process_cmath(item):
    """Yield one chat-formatted CMATH sample built from *item*.

    The assistant turn is left empty. The answer is taken from
    ``item["golden"]``, stripped, with commas removed.
    """
    user_turn = {"role": "user", "content": item["question"].strip()}
    empty_reply = {"role": "assistant", "content": ""}
    golden = item["golden"].strip().replace(",", "")
    yield {
        "dataset": "cmath",
        "id": item["id"],
        "grade": item["grade"],
        "reasoning_step": item["reasoning_step"],
        "messages": [user_turn, empty_reply],
        "answer": golden,
    }
|
126 |
+
|
127 |
+
|
128 |
+
def process_agieval_gaokao_math_cloze(item):
    """Yield one chat-formatted AGIEval Gaokao math-cloze sample.

    The reference answer string is split on ";" and each part is
    normalized with ``strip_string``. The assistant turn is left empty.
    """
    reference_parts = item["answer"].strip().split(";")
    normalized = [strip_string(part) for part in reference_parts]
    user_turn = {"role": "user", "content": item["question"].strip()}
    yield {
        "dataset": "agieval-gaokao-math-cloze",
        "id": item["id"],
        "messages": [user_turn, {"role": "assistant", "content": ""}],
        "answer": normalized,
    }
|
139 |
+
|
140 |
+
|
141 |
+
def process_agieval_gaokao_mathqa(item):
    """Yield one chat-formatted AGIEval Gaokao math-QA sample.

    Each option must look like ``(X) text`` with X in A-D and is rewritten
    to ``X: text``. The assistant turn is left empty and the answer is
    ``item["label"]``.

    Raises:
        AssertionError: If an option does not match the ``(X)`` layout.
    """
    stem = item["question"].strip()
    relabelled = []
    for raw in item["options"]:
        opt = raw.strip()
        assert opt[0] == "("
        assert opt[2] == ")"
        assert opt[1] in "ABCD"
        relabelled.append(f"{opt[1]}: {opt[3:].strip()}".strip())
    # NOTE(review): the options are embedded as the Python list repr
    # (e.g. "['A: 1', 'B: 2']"), exactly as the original did. That looks
    # unintended, but changing it would alter the prompts -- confirm.
    prompt = f"{stem}\n{relabelled}"
    yield {
        "dataset": "agieval-gaokao-mathqa",
        "id": item["id"],
        "messages": [
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": ""},
        ],
        "answer": item["label"],
    }
|
162 |
+
|
163 |
+
|
164 |
+
def process_agieval_gaokao_mathqa_few_shot_cot_test(item):
    """Yield one few-shot-formatted AGIEval Gaokao math-QA sample.

    Trailing backslashes are removed from the question, and the stripped
    options are appended on a new line after "从以下选项中选择: " ("choose
    from the following options"). The assistant turn is left empty.
    """
    stem = item["question"].strip().rstrip("\\")
    joined_options = " ".join(opt.strip() for opt in item["options"])
    prompt = f"{stem}\n从以下选项中选择: {joined_options}"
    yield {
        "dataset": "agieval-gaokao-mathqa",
        "id": item["id"],
        "messages": [
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": ""},
        ],
        "answer": item["label"],
    }
|
178 |
+
|
179 |
+
|
180 |
+
def process_minif2f_isabelle(item):
    """Yield one chat-formatted miniF2F (Isabelle) sample.

    The user prompt puts the informal statement and proof inside an
    Isabelle comment, followed by the formal statement. The answer field is
    the fixed string "placeholder".
    """
    informal_statement = item["informal_statement"].strip()
    informal_proof = item["informal_proof"].strip()
    formal_statement = item["formal_statement"].strip()
    prompt = f"(*### Problem\n\n{informal_statement}\n\n### Solution\n\n{informal_proof} *)\n\nFormal:\n{formal_statement}"
    yield {
        "dataset": "minif2f-isabelle",
        "id": item["id"],
        "messages": [
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": ""},
        ],
        "answer": "placeholder",
    }
|
eval/eval_script.py
ADDED
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import regex
|
2 |
+
from copy import deepcopy
|
3 |
+
from eval.eval_utils import math_equal
|
4 |
+
from eval.ocwcourses_eval_utils import (
|
5 |
+
normalize_numeric,
|
6 |
+
numeric_equality,
|
7 |
+
normalize_symbolic_equation,
|
8 |
+
SymbolicMathMixin,
|
9 |
+
)
|
10 |
+
|
11 |
+
|
12 |
+
def is_correct(item, pred_key="prediction", prec=1e-3):
    r"""Decide whether ``item[pred_key]`` matches ``item["answer"]``.

    * Two lists: correct when every prediction matches at least one answer
      and every answer is matched by at least one prediction.
    * Two strings that both contain ``\cup``: split on ``\cup`` and compare
      as lists.
    * Two strings otherwise: correct when they are numerically within
      *prec* (commas ignored), identical (and the answer is non-empty), or
      ``math_equal`` accepts them.

    Args:
        item: Mapping holding the prediction and the reference answer.
        pred_key: Key of the prediction inside *item*.
        prec: Absolute tolerance for the numeric comparison.

    Returns:
        Truthy when the prediction is accepted.

    Raises:
        NotImplementedError: If prediction and answer are not both lists or
            both strings (the item is printed first).
    """
    pred = item[pred_key]
    ans = item["answer"]

    if isinstance(pred, list) and isinstance(ans, list):
        pred_matched = set()
        ans_matched = set()
        for i, single_pred in enumerate(pred):
            for j, single_ans in enumerate(ans):
                # A shallow copy with the two fields replaced is enough:
                # the recursive call only reads from it. The original
                # deep-copied the whole item for every pair.
                pair = {**item, pred_key: single_pred, "answer": single_ans}
                if is_correct(pair, pred_key=pred_key, prec=prec):
                    pred_matched.add(i)
                    ans_matched.add(j)
        return len(pred_matched) == len(pred) and len(ans_matched) == len(ans)

    if isinstance(pred, str) and isinstance(ans, str):
        if "\\cup" in pred and "\\cup" in ans:
            # Compare unions piecewise, ignoring the order of the pieces.
            pieces = {
                **item,
                pred_key: pred.split("\\cup"),
                "answer": ans.split("\\cup"),
            }
            return is_correct(pieces, pred_key=pred_key, prec=prec)

        label = False
        try:
            label = (
                abs(float(pred.replace(",", "")) - float(ans.replace(",", "")))
                < prec
            )
        except ValueError:
            # Not plain numbers; fall through to string/symbolic checks.
            pass
        return label or (ans and pred == ans) or math_equal(pred, ans)

    print(item, flush=True)
    raise NotImplementedError()
|
56 |
+
|
57 |
+
|
58 |
+
def eval_math(item, pred_key="prediction", prec=1e-3):
    """Evaluate a MATH prediction against the reference answers.

    A string under the "program_output" key is wrapped in a list. When both
    prediction and answer are lists, duplicates are removed from each
    (keeping first occurrences) and only the last ``len(answer)``
    predictions are kept. *item* is updated in place with the cleaned
    values before delegating to ``is_correct``.
    """

    def _dedupe(values):
        # Order-preserving; membership test on a list so unhashable
        # entries still work.
        unique = []
        for value in values:
            if value not in unique:
                unique.append(value)
        return unique

    pred = item[pred_key]
    if pred_key == "program_output" and isinstance(pred, str):
        pred = [pred]
    ans = item["answer"]
    if isinstance(pred, list) and isinstance(ans, list):
        # Some MATH references repeat answers, and so do some predictions.
        ans = _dedupe(ans)
        # Some predictions box non-answer strings first; keep the tail.
        pred = _dedupe(pred)[-len(ans) :]

    item.update({pred_key: pred, "answer": ans})
    return is_correct(item, pred_key=pred_key, prec=prec)
|
80 |
+
|
81 |
+
|
82 |
+
def eval_last_single_answer(item, pred_key="prediction", prec=1e-3):
    """Evaluate a single string prediction against a single string answer.

    Raises:
        AssertionError: If the prediction or the answer is not a ``str``.
    """
    for key in (pred_key, "answer"):
        value = item[key]
        assert isinstance(value, str), f"{key} = `{item[key]}` is not a str"
    verdict = is_correct(item, pred_key=pred_key, prec=prec)
    return verdict
|
86 |
+
|
87 |
+
|
88 |
+
def eval_agieval_gaokao_math_cloze(item, pred_key="prediction", prec=1e-3):
    """Evaluate a Gaokao math-cloze prediction list against the answers.

    Each prediction string is split on ";" and on top-level "," (commas
    inside brackets are kept); duplicates are dropped. The last
    ``len(answer)`` pieces are then compared position by position with
    ``is_correct``.

    Note: ``item[pred_key]`` and ``item["answer"]`` are overwritten in
    place during the pairwise comparison.

    Raises:
        AssertionError: If prediction or answer is not a list.
    """
    # A raw program output string is treated as a one-element list.
    if pred_key == "program_output" and isinstance(item[pred_key], str):
        item[pred_key] = [item[pred_key]]
    for key in [pred_key, "answer"]:
        assert isinstance(item[key], list), f"{key} = `{item[key]}` is not a list"
    pred = item[pred_key]
    ans = item["answer"]
    _pred = []
    for p in pred:
        # Sentinel separator so the final piece is always emitted.
        p = p + ";"
        while p:
            left_brackets = 0
            for i in range(len(p)):
                # Split at ";" anywhere, or at "," outside any bracket.
                if p[i] == ";" or (p[i] == "," and left_brackets == 0):
                    _p, p = p[:i].strip(), p[i + 1 :].strip()
                    if _p not in _pred:
                        _pred.append(_p)
                    break
                elif p[i] in "([{":
                    left_brackets += 1
                elif p[i] in ")]}":
                    left_brackets -= 1
    # Keep only as many trailing pieces as there are reference answers.
    pred = _pred[-len(ans) :]
    if len(pred) == len(ans):
        for p, a in zip(pred, ans):
            # Reuse `item` as the carrier for each single comparison.
            item.update(
                {
                    pred_key: p,
                    "answer": a,
                }
            )
            if not is_correct(item, pred_key=pred_key, prec=prec):
                return False
        return True
    else:
        return False
|
124 |
+
|
125 |
+
|
126 |
+
def eval_agieval_gaokao_mathqa(item, pred_key="prediction", prec=1e-3):
    """Evaluate a Gaokao multiple-choice prediction.

    The prediction entries are joined with spaces. Among the letters A-D
    present in that text, the one whose first occurrence is furthest to the
    right is taken as the chosen option and compared with the answer.
    *prec* is unused.
    """
    if pred_key == "program_output" and isinstance(item[pred_key], str):
        item[pred_key] = [item[pred_key]]
    joined = " ".join(item[pred_key])
    chosen, chosen_pos = None, -1
    for letter in "ABCD":
        # find() gives -1 for a missing letter, which can never beat
        # chosen_pos (always >= -1).
        pos = joined.find(letter)
        if pos > chosen_pos:
            chosen, chosen_pos = letter, pos
    return chosen == item["answer"]
|
138 |
+
|
139 |
+
|
140 |
+
def eval_math_sat(item, pred_key="prediction", prec=1e-3):
    """Case-insensitive exact match between prediction and answer.

    *prec* is unused.

    Raises:
        AssertionError: If the prediction or the answer is not a ``str``.
    """
    for key in (pred_key, "answer"):
        assert isinstance(item[key], str), f"{key} = `{item[key]}` is not a str"
    predicted = item[pred_key].lower()
    expected = item["answer"].lower()
    return predicted == expected
|
144 |
+
|
145 |
+
|
146 |
+
def eval_mmlu_stem(item, pred_key="prediction", prec=1e-3):
    """Evaluate an MMLU-STEM prediction by delegating to ``eval_math_sat``."""
    verdict = eval_math_sat(item, pred_key=pred_key, prec=prec)
    return verdict
|
148 |
+
|
149 |
+
|
150 |
+
def eval_ocwcourses(item, pred_key="prediction", prec=1e-3):
    """Evaluate an OCWCourses prediction; returns 1 if correct, else 0.

    The normalizer and the equivalence check are chosen from the reference
    answer: numeric if it parses as a float, equation if it contains "=",
    otherwise a TeX expression. *prec* is unused.

    Raises:
        AssertionError: If the prediction or the answer is not a ``str``.
    """
    INVALID_ANSWER = "[invalidanswer]"
    for key in (pred_key, "answer"):
        assert isinstance(item[key], str), f"{key} = `{item[key]}` is not a str"
    prediction = item[pred_key]
    reference = item["answer"]

    try:
        float(reference)
    except ValueError:
        if "=" in reference:
            # Equation: compare normalized forms for exact equality.
            normalize_fn = normalize_symbolic_equation
            is_equiv = lambda x, y: x == y
        else:
            # Generic TeX expression.
            normalize_fn = SymbolicMathMixin().normalize_tex
            is_equiv = SymbolicMathMixin().is_tex_equiv
    else:
        # Plain number.
        normalize_fn = normalize_numeric
        is_equiv = numeric_equality

    correct_answer = normalize_fn(reference)

    # An empty prediction is mapped to the invalid marker before normalizing.
    raw_prediction = prediction if prediction else INVALID_ANSWER
    model_answer = normalize_fn(raw_prediction)

    if raw_prediction == INVALID_ANSWER or model_answer == INVALID_ANSWER:
        return 0
    return 1 if is_equiv(model_answer, correct_answer) else 0
|
187 |
+
|
188 |
+
|
189 |
+
def eval_minif2f_isabelle(item, pred_key="prediction", prec=1e-3):
    """Always return True; no check is performed on the prediction here.

    All arguments are accepted for signature compatibility and ignored.
    """
    return True
|
eval/eval_utils.py
ADDED
@@ -0,0 +1,400 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import multiprocessing
|
2 |
+
from math import isclose
|
3 |
+
import numpy as np
|
4 |
+
from typing import Union, Any, Dict
|
5 |
+
|
6 |
+
from sympy import simplify, N
|
7 |
+
from sympy.parsing.sympy_parser import parse_expr
|
8 |
+
from sympy.parsing.latex import parse_latex
|
9 |
+
import re
|
10 |
+
import regex
|
11 |
+
|
12 |
+
from data_processing.answer_extraction import (
|
13 |
+
extract_answer,
|
14 |
+
extract_program_output,
|
15 |
+
strip_string,
|
16 |
+
)
|
17 |
+
|
18 |
+
|
19 |
+
def extract_program(result: str, last_only=True):
    """Collect the source inside ```python fenced blocks of *result*.

    Args:
        result: Model output containing zero or more fenced blocks.
        last_only: If true, keep only the last block. Otherwise concatenate
            all blocks, each preceded by a "# ========" separator line.

    Returns:
        The extracted program text; every kept line ends with a newline.
    """
    chunks = []
    inside_block = False
    for line in result.split("\n"):
        if line.startswith("```python"):
            if last_only:
                # A new block starts: forget anything collected so far.
                chunks.clear()
            else:
                chunks.append("\n# ========\n")
            inside_block = True
        elif line.startswith("```"):
            inside_block = False
        elif inside_block:
            chunks.append(line + "\n")
    return "".join(chunks)
|
37 |
+
|
38 |
+
|
39 |
+
def parse_ground_truth(example: Dict[str, Any], data_name):
    """Return ``(gt_cot, gt_ans)`` for *example* from dataset *data_name*.

    If the example already carries "gt_cot", that value and the normalized
    "gt" field are returned directly. Otherwise the fields used depend on
    the dataset name; the chain of thought is converted with ``str`` and
    stripped, and the answer is normalized with ``strip_string``.

    Raises:
        NotImplementedError: For an unknown *data_name*.
    """
    if "gt_cot" in example:
        return example["gt_cot"], strip_string(example["gt"])

    def _tabmwp_answer(raw, ans_type):
        # Numeric tabmwp answers may be fractions, comma-grouped numbers or
        # percentages; convert them to a float. Other types pass through.
        if ans_type not in ["integer_number", "decimal_number"]:
            return raw
        if "/" in raw:
            return int(raw.split("/")[0]) / int(raw.split("/")[1])
        if "," in raw:
            return float(raw.replace(",", ""))
        if "%" in raw:
            return float(raw.split("%")[0]) / 100
        return float(raw)

    if data_name in ["math", "ocw"]:
        cot = example["solution"]
        answer = extract_answer(cot)
    elif data_name == "gsm8k":
        cot, answer = example["answer"].split("####")
    elif data_name == "gsm-hard":
        cot, answer = example["code"], example["target"]
    elif data_name == "svamp":
        cot, answer = example["Equation"], example["Answer"]
    elif data_name == "asdiv":
        cot = example["formula"]
        # Drop parenthesised annotations from the answer.
        answer = re.sub(r"\(.*?\)", "", example["answer"])
    elif data_name == "mawps":
        cot, answer = None, example["target"]
    elif data_name == "tabmwp":
        cot = example["solution"]
        answer = _tabmwp_answer(example["answer"], example["ans_type"])
    elif data_name == "bbh":
        cot, answer = None, example["target"]
    else:
        raise NotImplementedError(data_name)

    # Post-process.
    return str(cot).strip(), strip_string(answer)
|
78 |
+
|
79 |
+
|
80 |
+
def parse_question(example, data_name):
    """Build the question text for *example* from dataset *data_name*.

    asdiv, svamp and tabmwp have dedicated layouts; every other dataset
    uses the first present key among "question", "problem", "Question" and
    "input".

    Raises:
        AssertionError: If no question text could be found.
    """
    question = ""
    if data_name == "asdiv":
        body = example["body"].strip()
        ask = example["question"].strip()
        question = f"{body} {ask}"
    elif data_name == "svamp":
        body = example["Body"].strip()
        if not body.endswith("."):
            body += "."
        ask = example["Question"].strip()
        question = f"{body} {ask}"
    elif data_name == "tabmwp":
        title = example["table_title"]
        title_str = f'regarding "{title}" ' if title else ""
        table = example["table"]
        ask = example["question"]
        question = (
            f"Read the following table {title_str}and answer a question:\n"
            f"{table}\n{ask}"
        )
        choices = example["choices"]
        if choices:
            question += f" Please select from the following options: {choices}"
    else:
        for key in ("question", "problem", "Question", "input"):
            if key in example:
                question = example[key]
                break
    assert question != ""
    return question.strip()
|
106 |
+
|
107 |
+
|
108 |
+
def run_execute(executor, result, prompt_type, execute=False):
    """Turn a raw model *result* into ``(prediction, report)``.

    * Empty result or the string "error": ``(None, None)``.
    * "program_only" in *prompt_type*: prediction is the last ```output
      block.
    * *prompt_type* "pot"/"pal" with *execute* true: the last ```python
      block is run through ``executor.apply``, which supplies both values.
    * Otherwise: prediction comes from ``extract_answer``.

    The prediction is normalized with ``strip_string`` before returning.
    """
    if not result or result == "error":
        return None, None

    report = None
    if "program_only" in prompt_type:
        raw_prediction = extract_program_output(result)
    elif prompt_type in ["pot", "pal"] and execute:
        program = extract_program(result)
        raw_prediction, report = executor.apply(program)
    else:
        raw_prediction = extract_answer(result)

    return strip_string(raw_prediction), report
|
123 |
+
|
124 |
+
|
125 |
+
def parse_digits(num):
    r"""Parse *num* as a float, accepting thousands commas and percentages.

    Accepted formats are plain numbers (``"1,234.5"`` -> ``1234.5``) and
    percentages, optionally LaTeX-escaped (``"50%"`` / ``"50\%"`` ->
    ``0.5``).

    Args:
        num: Value to parse; converted with ``str`` first.

    Returns:
        The parsed float, or ``None`` when *num* is not a number.
    """
    # Literal comma removal; no regex needed.
    text = str(num).replace(",", "")
    try:
        return float(text)
    except ValueError:
        pass

    if not text.endswith("%"):
        return None
    text = text[:-1]
    if text.endswith("\\"):
        # Escaped percent sign: drop the backslash too.
        text = text[:-1]
    try:
        return float(text) / 100
    except ValueError:
        return None
|
140 |
+
|
141 |
+
|
142 |
+
def is_digit(num):
    """Return True when ``parse_digits`` can turn *num* into a number."""
    parsed = parse_digits(num)
    return parsed is not None
|
145 |
+
|
146 |
+
|
147 |
+
def normalize_prediction(prediction):
    """Return a canonical string form of *prediction*.

    Numeric predictions are rounded to 6 decimals; anything else is returned
    as ``str(prediction)``. The symbolic fallback below only runs when the
    numeric step raises (for example ``"50%"``, which ``is_digit`` accepts but
    ``float`` rejects).
    """
    try:  # 1. numerical normalisation
        if is_digit(prediction):
            prediction = np.round(float(str(prediction).replace(",", "")), 6)
        return str(prediction)
    except Exception:
        pass

    # 2. symbolic normalisation
    prediction = str(prediction).strip()

    # Peel matching outer [] / () pairs, remembering them outermost first so
    # they can be restored after the elements have been normalised.
    brackets = []
    while (prediction.startswith("[") and prediction.endswith("]")) or (
        prediction.startswith("(") and prediction.endswith(")")
    ):
        brackets.append(prediction[0])
        prediction = prediction[1:-1]

    if brackets and "," in prediction:
        prediction = ",".join(
            normalize_prediction(part) for part in prediction.split(",")
        )

    for bracket in reversed(brackets):
        if bracket == "[":
            prediction = "[" + prediction + "]"
        else:
            prediction = "(" + prediction + ")"

    def _parse(s):
        for f in [parse_latex, parse_expr]:
            try:
                return f(s)
            except Exception:
                pass
        return s

    # A successful parse yields a sympy object; go back to text so the
    # replacements below are plain string operations.
    prediction = str(_parse(prediction))

    for s in ["{", "}", "(", ")"]:
        prediction = prediction.replace(s, "")

    return prediction
|
193 |
+
|
194 |
+
|
195 |
+
def math_equal(
    prediction: Union[bool, float, str],
    reference: Union[float, str],
    include_percentage: bool = True,
    is_close: bool = True,
    timeout: bool = False,
) -> bool:
    """
    Exact match of math if and only if:
    1. numerical equal: both can convert to float and are equal
    2. symbolic equal: both can convert to sympy expression and are equal

    With ``include_percentage`` the reference is also tried scaled by 1/100
    and 100. With ``is_close`` numbers are compared with ``abs_tol=1e-3``.
    With ``timeout`` every sympy comparison, including those reached through
    element-wise recursion, runs through ``call_with_timeout``.
    """
    if str(prediction) == str(reference):
        return True

    def _recurse(pred, ref):
        # Forward every option so nested comparisons behave like the outer one.
        return math_equal(pred, ref, include_percentage, is_close, timeout)

    def _symbolic(pred, ref):
        if timeout:
            return call_with_timeout(symbolic_equal_process, pred, ref)
        return symbolic_equal(pred, ref)

    try:  # 1. numerical equal
        if is_digit(prediction) and is_digit(reference):
            prediction = parse_digits(prediction)
            reference = parse_digits(reference)
            # number questions
            if include_percentage:
                gt_result = [reference / 100, reference, reference * 100]
            else:
                gt_result = [reference]
            for item in gt_result:
                try:
                    if is_close:
                        if isclose(item, prediction, abs_tol=1e-3):
                            return True
                    elif item == prediction:
                        return True
                except Exception:
                    continue
            return False
    except Exception:
        pass

    if not prediction and prediction not in [0, False]:
        return False

    # 2. symbolic equal
    reference = str(reference).strip()
    prediction = str(prediction).strip()

    # Tuples / intervals / lists: compare element by element.
    bracketed = r"(\(|\[).+(\)|\])"
    if (
        regex.match(bracketed, prediction) is not None
        and regex.match(bracketed, reference) is not None
    ):
        pred_parts = prediction[1:-1].split(",")
        ref_parts = reference[1:-1].split(",")
        if len(pred_parts) == len(ref_parts) and all(
            _recurse(p, r) for p, r in zip(pred_parts, ref_parts)
        ):
            return True

    # LaTeX matrices: compare cell by cell.
    matrix_begin = ("\\begin{pmatrix}", "\\begin{bmatrix}")
    matrix_end = ("\\end{pmatrix}", "\\end{bmatrix}")
    if (
        prediction.startswith(matrix_begin)
        and prediction.endswith(matrix_end)
        and reference.startswith(matrix_begin)
        and reference.endswith(matrix_end)
    ):
        # pmatrix and bmatrix delimiters have identical lengths.
        head, tail = len(matrix_begin[0]), len(matrix_end[0])

        def _rows(matrix):
            return [
                row.strip()
                for row in matrix[head:-tail].split("\\\\")
                if row.strip()
            ]

        pred_rows, ref_rows = _rows(prediction), _rows(reference)
        if len(pred_rows) == len(ref_rows):
            matched = True
            for pred_row, ref_row in zip(pred_rows, ref_rows):
                pred_cells = pred_row.split("&")
                ref_cells = ref_row.split("&")
                if len(pred_cells) != len(ref_cells) or not all(
                    _recurse(p, r) for p, r in zip(pred_cells, ref_cells)
                ):
                    matched = False
                    break
            if matched:
                return True

    if prediction.count("=") == 1 and reference.count("=") == 1:
        # Two equations: compare "lhs - (rhs)" up to an overall sign.
        pred = prediction.split("=")
        pred = f"{pred[0].strip()} - ({pred[1].strip()})"
        ref = reference.split("=")
        ref = f"{ref[0].strip()} - ({ref[1].strip()})"
        if _symbolic(pred, ref) or _symbolic(f"-({pred})", ref):
            return True
    elif (
        prediction.count("=") == 1
        and len(prediction.split("=")[0].strip()) <= 2
        and "=" not in reference
    ):
        # "x = 5" against "5": compare only the right-hand side.
        if _recurse(prediction.split("=")[1], reference):
            return True
    elif (
        reference.count("=") == 1
        and len(reference.split("=")[0].strip()) <= 2
        and "=" not in prediction
    ):
        if _recurse(prediction, reference.split("=")[1]):
            return True

    # symbolic equal with sympy
    if _symbolic(prediction, reference):
        return True

    return False
|
351 |
+
|
352 |
+
|
353 |
+
def math_equal_process(param):
    """Run ``math_equal`` on the last two items of *param* (prediction, reference)."""
    prediction, reference = param[-2], param[-1]
    return math_equal(prediction, reference)
|
355 |
+
|
356 |
+
|
357 |
+
def symbolic_equal(a, b):
    """Return True when *a* and *b* are equal as sympy expressions.

    Each argument is parsed with ``parse_latex`` and then ``parse_expr``; an
    argument neither parser accepts is kept unchanged. Equality is accepted
    when ``simplify(a - b) == 0`` or when the numeric values agree within
    ``abs_tol=1e-3``. Any failure in a step counts as "not equal".
    """

    def _parse(s):
        for f in [parse_latex, parse_expr]:
            try:
                return f(s)
            except Exception:
                pass
        return s

    a = _parse(a)
    b = _parse(b)

    try:
        if simplify(a - b) == 0:
            return True
    except Exception:
        pass

    try:
        if isclose(N(a), N(b), abs_tol=1e-3):
            return True
    except Exception:
        pass
    return False
|
381 |
+
|
382 |
+
|
383 |
+
def symbolic_equal_process(a, b, output_queue):
    """Worker entry point: put ``symbolic_equal(a, b)`` on *output_queue*."""
    output_queue.put(symbolic_equal(a, b))
|
386 |
+
|
387 |
+
|
388 |
+
def call_with_timeout(func, *args, timeout=1, **kwargs):
    """Run ``func(*args, queue, **kwargs)`` in a subprocess with a time limit.

    *func* must put its result on the queue it receives as the last
    positional argument. Returns that result, or ``False`` when the worker
    does not finish within *timeout* seconds, exits with an error, or
    finishes without putting anything on the queue.
    """
    from queue import Empty  # local import: only this function needs it

    output_queue = multiprocessing.Queue()
    process_args = args + (output_queue,)
    process = multiprocessing.Process(target=func, args=process_args, kwargs=kwargs)
    process.start()
    process.join(timeout)

    if process.is_alive():
        process.terminate()
        process.join()
        return False

    if process.exitcode != 0:
        # The worker crashed, so nothing will ever arrive on the queue.
        return False

    try:
        # Bounded wait: an unconditional get() would hang forever if the
        # worker returned without putting a result.
        return output_queue.get(timeout=timeout)
    except Empty:
        return False
|
eval/ocwcourses_eval_utils.py
ADDED
@@ -0,0 +1,266 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
import numpy as np
|
3 |
+
import sympy
|
4 |
+
from sympy.core.sympify import SympifyError
|
5 |
+
from sympy.parsing.latex import parse_latex
|
6 |
+
|
7 |
+
import signal
|
8 |
+
|
9 |
+
INVALID_ANSWER = "[invalidanswer]"
|
10 |
+
|
11 |
+
|
12 |
+
class timeout:
    """Context manager that raises ``TimeoutError`` after *seconds* seconds.

    Implemented with ``SIGALRM``. The handler that was installed before
    entering is put back on exit, so other users of the signal (including
    nested ``timeout`` blocks) keep their handler.
    """

    def __init__(self, seconds=1, error_message="Timeout"):
        self.seconds = seconds
        self.error_message = error_message
        self._previous_handler = None

    def handle_timeout(self, signum, frame):
        raise TimeoutError(self.error_message)

    def __enter__(self):
        # signal.signal returns the handler that was active until now.
        self._previous_handler = signal.signal(signal.SIGALRM, self.handle_timeout)
        signal.alarm(self.seconds)

    def __exit__(self, type, value, traceback):
        signal.alarm(0)
        if self._previous_handler is not None:
            signal.signal(signal.SIGALRM, self._previous_handler)
            self._previous_handler = None
|
26 |
+
|
27 |
+
|
28 |
+
def normalize_numeric(s):
    """Strip known unit suffixes from *s* and convert the remainder to a float.

    Returns ``None`` for ``None`` input, the float value when the cleaned
    text evaluates as a Python expression or parses as a LaTeX number, and
    ``INVALID_ANSWER`` otherwise.
    """
    if s is None:
        return None
    for unit in [
        "eV",
        " \\mathrm{~kg} \\cdot \\mathrm{m} / \\mathrm{s}",
        " kg m/s",
        "kg*m/s",
        "kg",
        "m/s",
        "m / s",
        "m s^{-1}",
        "\\text{ m/s}",
        " \\mathrm{m/s}",
        " \\text{ m/s}",
        "g/mole",
        "g/mol",
        "\\mathrm{~g}",
        "\\mathrm{~g} / \\mathrm{mol}",
        "W",
        "erg/s",
        "years",
        "year",
        "cm",
    ]:
        s = s.replace(unit, "")
        s = s.strip()
    for maybe_unit in ["m", "s", "cm"]:
        s = s.replace("\\mathrm{" + maybe_unit + "}", "")
        s = s.replace("\\mathrm{~" + maybe_unit + "}", "")
        s = s.strip()
    s = s.strip("$")

    try:
        # SECURITY: eval() runs arbitrary Python. NOTE(review): if *s* can be
        # model-generated text, consider a restricted arithmetic parser.
        return float(eval(s))
    except (Exception, SystemExit):
        # SystemExit is included so text such as "exit()" cannot stop the run.
        pass

    try:
        expr = parse_latex(s)
        if expr.is_number:
            return float(expr)
    except Exception:
        pass
    return INVALID_ANSWER
|
70 |
+
|
71 |
+
|
72 |
+
def numeric_equality(n1, n2, threshold=0.01):
    """Return True when *n1* and *n2* are numerically equal.

    ``None`` on either side is never equal. When one value or their
    difference is close to zero, the difference must not exceed *threshold*
    times the magnitude of their mean; otherwise ``np.isclose`` with its
    default tolerances decides.
    """
    if n1 is None or n2 is None:
        return False
    if np.isclose(n1, 0) or np.isclose(n2, 0) or np.isclose(n1 - n2, 0):
        # Use the magnitude of the mean and a non-strict bound so that
        # (0, 0) and equal negative numbers compare equal.
        return bool(np.abs(n1 - n2) <= threshold * np.abs(n1 + n2) / 2)
    return bool(np.isclose(n1, n2))
|
79 |
+
|
80 |
+
|
81 |
+
def normalize_symbolic_equation(s):
    """Parse a LaTeX equation string into a sympy ``Equality``.

    Strips ``\\[ ... \\]`` and ``$`` delimiters, rewrites ``\\left(`` and
    ``\\right)`` as plain parentheses and collapses doubled backslashes
    before parsing. Returns ``INVALID_ANSWER`` for non-string input, for
    text that fails to parse, and for text that parses to something other
    than an equation.
    """
    if not isinstance(s, str):
        return INVALID_ANSWER
    if s.startswith("\\["):
        s = s[2:]
    if s.endswith("\\]"):
        s = s[:-2]
    s = s.replace("\\left(", "(")
    s = s.replace("\\right)", ")")
    s = s.replace("\\\\", "\\")
    if s.startswith("$") or s.endswith("$"):
        s = s.strip("$")
    try:
        maybe_expression = parse_latex(s)
        if not isinstance(maybe_expression, sympy.core.relational.Equality):
            # we have an expression, not an equation
            return INVALID_ANSWER
        return maybe_expression
    except Exception:
        return INVALID_ANSWER
|
102 |
+
|
103 |
+
|
104 |
+
class SymbolicMathMixin:
    """
    Methods useful for parsing mathematical expressions from text and determining equivalence of expressions.
    """

    SUBSTITUTIONS = [  # (before, after) pairs applied in order by normalize_tex
        ("an ", ""),
        ("a ", ""),
        (".$", "$"),
        ("\\$", ""),
        (r"\ ", ""),
        (" ", ""),
        ("mbox", "text"),
        (",\\text{and}", ","),
        ("\\text{and}", ","),
        ("\\text{m}", "\\text{}"),
    ]
    REMOVED_EXPRESSIONS = [  # substrings deleted by normalize_tex
        "square",
        "ways",
        "integers",
        "dollars",
        "mph",
        "inches",
        "ft",
        "hours",
        "km",
        "units",
        "\\ldots",
        "sue",
        "points",
        "feet",
        "minutes",
        "digits",
        "cents",
        "degrees",
        "cm",
        "gm",
        "pounds",
        "meters",
        "meals",
        "edges",
        "students",
        "childrentickets",
        "multiples",
        "\\text{s}",
        "\\text{.}",
        "\\text{\ns}",
        "\\text{}^2",
        "\\text{}^3",
        "\\text{\n}",
        "\\text{}",
        r"\mathrm{th}",
        r"^\circ",
        r"^{\circ}",
        r"\;",
        r",\!",
        "{,}",
        '"',
        "\\dots",
    ]

    def normalize_tex(self, final_answer: str) -> str:
        """
        Normalizes a string representing a mathematical expression.
        Used as a preprocessing step before parsing methods.

        Copied character for character from appendix D of Lewkowycz et al. (2022)
        """
        # Keep only what follows the last "=".
        final_answer = final_answer.split("=")[-1]

        for before, after in self.SUBSTITUTIONS:
            final_answer = final_answer.replace(before, after)
        for expr in self.REMOVED_EXPRESSIONS:
            final_answer = final_answer.replace(expr, "")

        # Extract answer that is in LaTeX math, is bold,
        # is surrounded by a box, etc.
        final_answer = re.sub(r"(.*?)(\$)(.*?)(\$)(.*)", "$\\3$", final_answer)
        final_answer = re.sub(r"(\\text\{)(.*?)(\})", "\\2", final_answer)
        final_answer = re.sub(r"(\\textbf\{)(.*?)(\})", "\\2", final_answer)
        final_answer = re.sub(r"(\\overline\{)(.*?)(\})", "\\2", final_answer)
        final_answer = re.sub(r"(\\boxed\{)(.*)(\})", "\\2", final_answer)

        # Normalize shorthand TeX:
        # \fracab -> \frac{a}{b}
        # \frac{abc}{bef} -> \frac{abc}{bef}
        # \fracabc -> \frac{a}{b}c
        # \sqrta -> \sqrt{a}
        # \sqrtab -> sqrt{a}b
        final_answer = re.sub(r"(frac)([^{])(.)", "frac{\\2}{\\3}", final_answer)
        final_answer = re.sub(r"(sqrt)([^{])", "sqrt{\\2}", final_answer)
        final_answer = final_answer.replace("$", "")

        # Normalize 100,000 -> 100000
        if final_answer.replace(",", "").isdigit():
            final_answer = final_answer.replace(",", "")

        return final_answer

    def parse_tex(self, text: str, time_limit: int = 5) -> "sympy.Basic":
        """
        Wrapper around `parse_latex` that outputs a SymPy expression.
        Typically, you want to apply `normalize_tex` as a preprocessing step.

        Returns None (after printing the error) when parsing fails or takes
        longer than `time_limit` seconds.
        """
        try:
            with timeout(seconds=time_limit):
                parsed = parse_latex(text)
        except (
            # general error handling: there is a long tail of possible sympy/other
            # errors we would like to catch
            Exception
        ) as e:
            print(f"failed to parse {text} with exception {e}")
            return None

        return parsed

    def is_exp_equiv(self, x1: "sympy.Basic", x2: "sympy.Basic", time_limit=5) -> bool:
        """
        Determines whether two sympy expressions are equal.

        True only when `x1 - x2` simplifies to 0 within `time_limit` seconds;
        every failure is printed and reported as False.
        """
        try:
            with timeout(seconds=time_limit):
                try:
                    diff = x1 - x2
                except (SympifyError, ValueError, TypeError) as e:
                    print(f"Couldn't subtract {x1} and {x2} with exception {e}")
                    return False

                try:
                    if sympy.simplify(diff) == 0:
                        return True
                    else:
                        return False
                except (SympifyError, ValueError, TypeError) as e:
                    print(f"Failed to simplify {x1}-{x2} with {e}")
                    return False
        except TimeoutError as e:
            print(f"Timed out comparing {x1} and {x2}")
            return False
        except Exception as e:
            print(f"failed on unrecognized exception {e}")
            return False

    def is_tex_equiv(self, x1: str, x2: str, time_limit=5) -> bool:
        """
        Determines whether two (ideally normalized using `normalize_tex`) TeX expressions are equal.

        Only exact string match is checked. The sympy-equivalence fallback
        described by (Lewkowycz et al. 2022) sat after an unconditional
        return and never ran, so it is not performed here; `time_limit` is
        accepted for interface compatibility but unused.
        """
        return x1 == x2
|
eval/python_executor.py
ADDED
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import io
|
3 |
+
from contextlib import redirect_stdout
|
4 |
+
import pickle
|
5 |
+
import regex
|
6 |
+
import copy
|
7 |
+
from typing import Any, Dict, Optional
|
8 |
+
import multiprocess
|
9 |
+
from pebble import ProcessPool
|
10 |
+
from concurrent.futures import TimeoutError
|
11 |
+
from functools import partial
|
12 |
+
import traceback
|
13 |
+
from timeout_decorator import timeout
|
14 |
+
|
15 |
+
|
16 |
+
class GenericRuntime:
    """Minimal sandbox-less runtime that executes code in its own globals dict.

    SECURITY: ``exec_code`` and ``eval_code`` run arbitrary Python in this
    process. The pattern check in ``exec_code`` only rejects ``input(`` and
    ``os.system(``; it is not a sandbox.
    """

    GLOBAL_DICT = {}  # template copied into every instance's globals
    LOCAL_DICT = None
    HEADERS = []  # code pieces executed once when an instance is created

    def __init__(self):
        self._global_vars = copy.copy(self.GLOBAL_DICT)
        self._local_vars = copy.copy(self.LOCAL_DICT) if self.LOCAL_DICT else None

        for c in self.HEADERS:
            self.exec_code(c)

    def exec_code(self, code_piece: str) -> None:
        """Execute *code_piece* in the runtime globals.

        Raises RuntimeError when the code contains ``input(`` or
        ``os.system(``.
        """
        if regex.search(r"(\s|^)?input\(", code_piece) or regex.search(
            r"(\s|^)?os.system\(", code_piece
        ):
            raise RuntimeError(
                "refusing to execute code that calls input() or os.system()"
            )
        # Keep this statement's text unchanged: PythonExecutor.execute looks
        # for it verbatim in tracebacks.
        exec(code_piece, self._global_vars)

    def eval_code(self, expr: str) -> Any:
        """Evaluate *expr* in the runtime globals and return its value."""
        return eval(expr, self._global_vars)

    def inject(self, var_dict: Dict[str, Any]) -> None:
        """Copy every entry of *var_dict* into the runtime globals."""
        for k, v in var_dict.items():
            self._global_vars[k] = v

    @property
    def answer(self):
        """Value of the global named ``answer`` (KeyError if it is not set)."""
        return self._global_vars["answer"]
|
45 |
+
|
46 |
+
|
47 |
+
class PythonExecutor:
    """Runs generated Python snippets in a process pool and collects their results."""

    def __init__(
        self,
        runtime: Optional[Any] = None,
        get_answer_symbol: Optional[str] = None,
        get_answer_expr: Optional[str] = None,
        get_answer_from_stdout: bool = False,
    ) -> None:
        self.runtime = runtime if runtime else GenericRuntime()
        self.answer_symbol = get_answer_symbol
        self.answer_expr = get_answer_expr
        self.get_answer_from_stdout = get_answer_from_stdout

    def process_generation_to_code(self, gens: list):
        """Split each generation into lines with comments neutralised.

        ``#`` comment lines become ``# comments``, one-line triple-quoted
        strings become ``\"\"\"comments\"\"\"``, and multi-line triple-quoted
        blocks are dropped (their closing line is kept as an empty line).
        Returns one list of lines per generation.
        """
        batch_code = []
        for g in gens:
            multiline_comments = False
            code = []
            for line in g.split("\n"):
                strip_line = line.strip()
                if strip_line.startswith("#"):
                    line = line.split("#", 1)[0] + "# comments"
                elif (
                    not multiline_comments
                    and strip_line.startswith('"""')
                    and strip_line.endswith('"""')
                    and len(strip_line) >= 6
                ):
                    line = line.split('"""', 1)[0] + '"""comments"""'
                elif not multiline_comments and strip_line.startswith('"""'):
                    multiline_comments = True
                elif multiline_comments and strip_line.endswith('"""'):
                    multiline_comments = False
                    line = ""
                if not multiline_comments:
                    code.append(line)
            batch_code.append(code)
        return batch_code

    @staticmethod
    def execute(
        code,
        get_answer_from_stdout=None,
        runtime=None,
        answer_symbol=None,
        answer_expr=None,
        timeout_length=10,
    ):
        """Execute one snippet (a list of lines) and extract its answer.

        The answer comes from captured stdout, from the global named
        *answer_symbol*, from evaluating *answer_expr*, or (by default) from
        evaluating the last line after executing the preceding ones.

        Returns ``(result, concise_exec_info, exec_info)``; on any failure
        ``result`` is ``""`` and the other two carry the traceback's last
        line and the full traceback text.
        """
        try:
            if get_answer_from_stdout:
                program_io = io.StringIO()
                with redirect_stdout(program_io):
                    timeout(timeout_length)(runtime.exec_code)("\n".join(code))
                program_io.seek(0)
                result = "".join(program_io.readlines())  # [-1]
            elif answer_symbol:
                timeout(timeout_length)(runtime.exec_code)("\n".join(code))
                result = runtime._global_vars[answer_symbol]
            elif answer_expr:
                timeout(timeout_length)(runtime.exec_code)("\n".join(code))
                result = timeout(timeout_length)(runtime.eval_code)(answer_expr)
            else:
                timeout(timeout_length)(runtime.exec_code)("\n".join(code[:-1]))
                result = timeout(timeout_length)(runtime.eval_code)(code[-1])
            concise_exec_info = ""
            exec_info = ""
            str(result)
            pickle.dumps(result)  # serialization check
        except BaseException:
            # Deliberately broad: the executed code is untrusted and may raise
            # anything, including SystemExit via exit().
            result = ""
            concise_exec_info = traceback.format_exc().split("\n")[-2]
            exec_info = traceback.format_exc()

        if (
            get_answer_from_stdout
            and "exec(code_piece, self._global_vars)" in exec_info
        ):
            # Keep only the part of the traceback that belongs to the executed
            # snippet and shorten file paths.
            exec_info = exec_info.split("exec(code_piece, self._global_vars)")[
                -1
            ].strip()
            msg = []
            for line in exec_info.split("\n"):
                patt = regex.search(
                    r'(?P<start>.*)File "(?P<file>.*)", line (?P<lno>\d+), (?P<end>.*)',
                    line,
                )
                if patt is not None:
                    if "<module>" in patt.group("end"):
                        continue
                    fname = patt.group("file")
                    if "site-packages" in fname:
                        fname = f"site-packages{fname.split('site-packages', 1)[1]}"
                        line = f'{patt.group("start")}File "{fname}", {patt.group("end")}'
                    else:
                        line = f'{patt.group("start")}{patt.group("end")}'
                else:
                    patt = regex.search(
                        r"(?P<start>.*)(?P<file>/.*site-packages/.*\.py)(?P<end>.*)",
                        line,
                    )
                    if patt is not None:
                        line = f'{patt.group("start")}site-packages{patt.group("file").split("site-packages", 1)[1]}{patt.group("end")}'
                msg.append(line)
            exec_info = "\n".join(msg)
        return result, concise_exec_info, exec_info

    def apply(self, code):
        """Run a single generation; returns its ``(result, metadata)`` pair."""
        return self.batch_apply([code])[0]

    def batch_apply(self, batch_code):
        """Run every generation in *batch_code* in a process pool.

        Returns one ``(result, metadata)`` pair per generation, in order.
        A snippet that exceeds the pool timeout yields an empty result with
        ``"Timeout Error"`` as its execution info.
        """
        all_code_snippets = self.process_generation_to_code(batch_code)
        all_exec_results = []
        executor = partial(
            self.execute,
            get_answer_from_stdout=self.get_answer_from_stdout,
            runtime=self.runtime,
            answer_symbol=self.answer_symbol,
            answer_expr=self.answer_expr,
            timeout_length=10,
        )
        # Never start more workers than there are snippets to run.
        max_workers = max(1, min(len(all_code_snippets), multiprocess.cpu_count()))
        with ProcessPool(max_workers=max_workers) as pool:
            iterator = pool.map(executor, all_code_snippets, timeout=10).result()

            while True:
                try:
                    result = next(iterator)
                    all_exec_results.append(result)
                except StopIteration:
                    break
                except TimeoutError:
                    all_exec_results.append(("", "Timeout Error", "Timeout Error"))
                except Exception as error:
                    # Still aborts the program, but with a message on stderr
                    # and a non-zero exit status.
                    raise SystemExit(
                        f"code execution pool failed: {error!r}"
                    ) from error

        batch_results = []
        for code, (result, concise_exec_info, exec_info) in zip(
            all_code_snippets, all_exec_results
        ):
            metadata = {
                "code": code,
                "exec_result": result,
                "concise_exec_info": concise_exec_info,
                "exec_info": exec_info,
            }
            batch_results.append((result, metadata))
        return batch_results
|
main.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import argparse
|
3 |
+
from huggingface_hub import hf_hub_download
|
4 |
+
import json
|
5 |
+
from build_cache import cache
|
6 |
+
from compute_perp import Evaluator as PPLEvaluator
|
7 |
+
from compute_sc import SCEvaluator
|
8 |
+
from compute_rpc import RPCEvaluator
|
9 |
+
|
10 |
+
# Hugging Face dataset repository that holds the sampled reasoning paths
# for each benchmark.
REPOID = {
    "MATH": "WNJXYK/MATH-Reasoning-Paths",
    "MathOdyssey": "WNJXYK/MathOdyssey-Reasoning-Paths",
    "AIME": "WNJXYK/AIME_1983_2024-Reasoning-Paths",
    "OlympiadBench": "WNJXYK/OlympiadBench-Reasoning-Paths",
}

# Evaluator class for each --method value.
EVALUATOR_MAP = {
    "PPL": PPLEvaluator,
    "SC": SCEvaluator,
    "RPC": RPCEvaluator,
}

parser = argparse.ArgumentParser()
parser.add_argument("--dataset", type=str, choices=["MATH", "MathOdyssey", "AIME", "OlympiadBench"], default="MathOdyssey")
parser.add_argument("--model", type=str, choices=["Deepseek-Math-RL-7B", "InternLM2-Math-Plus-1.8B", "InternLM2-Math-Plus-7B"], default="InternLM2-Math-Plus-7B")
parser.add_argument("--K", type=int, default=128)
parser.add_argument("--method", type=str, default="PPL", choices=["PPL", "SC", "RPC"])
args = parser.parse_args()

repo_id = REPOID[args.dataset]
filename = args.model + ".json"

# Download sampled reasoning paths from Hugging Face
try:
    file_path = hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset")
    with open(file_path, 'r', encoding='utf-8') as f:
        json_file = json.load(f)
    print(f"Load sampled reasoning paths {filename} from {repo_id} successfully!")
except Exception as e:
    # Nothing below can run without the data, so stop here with a clear
    # message instead of failing later with a NameError.
    raise SystemExit(f"Failed to load sampled reasoning paths {filename} from {repo_id}: {e}") from e

# Build cache for checking equality
cache_path = file_path.replace(".json", ".cache.json")
cache(json_file, cache_path)
with open(cache_path, 'r', encoding='utf-8') as f:
    cache_file = json.load(f)

# Run!
results = EVALUATOR_MAP[args.method]().solve(json_file=json_file, cache_file=cache_file, K=args.K)

# Report results
result_str = f"{args.method} {args.dataset} {args.model} {args.K} {results}"
with open("results.txt", "a", encoding="utf-8") as f:
    f.write(result_str + "\n")
print(result_str)
|
metrics.py
ADDED
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
EPS = 1e-10
|
4 |
+
NLLEPS = 1e-6
|
5 |
+
|
6 |
+
def compute_maximum_metrics(predicts, n_bins=10):
    """Calibration metrics of the highest-confidence answer of each problem.

    ``predicts[idx]`` is a sequence of ``(answer, prob, flag)`` triples for
    problem ``idx``; ``flag`` is truthy when the answer is correct. Answers
    whose probability ties for the maximum (within ``EPS``) share the vote
    equally.

    Returns ``(ece, brier, nll), (acc, cnf, siz)`` where the second tuple
    holds per-bin accuracy, confidence and weight for ``n_bins`` equal-width
    bins over (0, 1]. ``ece`` is NaN when no answer falls into any bin.
    """
    acc, cnf, siz = np.zeros(n_bins), np.zeros(n_bins), np.zeros(n_bins)
    brier_score = []
    negative_ll = []

    for answers in predicts:
        # Highest confidence of this problem (NaN never compares greater).
        max_prob = -1e6
        for _, prob, _ in answers:
            if prob > max_prob:
                max_prob = prob
        # Count ties only once the maximum is known, with the same tolerance
        # the voting loop uses, so the tied answers' weights sum to one.
        max_prob_counts = sum(1 for _, prob, _ in answers if prob >= max_prob - EPS)

        # Expected accuracy of the top answer, plus its ECE bin contribution.
        vote_acc = 0
        for _, prob, flag in answers:
            if np.isnan(prob) or prob < max_prob - EPS:
                continue
            if flag:
                vote_acc += 1.0 / max_prob_counts
            # Compute Expected Calibration Error
            for cur in range(n_bins):
                lower, upper = cur / n_bins, (cur + 1) / n_bins
                if lower < max_prob <= upper:
                    if flag:
                        acc[cur] += 1.0 / max_prob_counts
                    cnf[cur] += prob / max_prob_counts
                    siz[cur] += 1.0 / max_prob_counts

        # Compute Brier Score
        brier_score.append((vote_acc - max_prob) ** 2)

        # Compute Negative Log-Likelihood (clipped to keep the logs finite)
        clipped_prob = max(min(max_prob, 1 - NLLEPS), NLLEPS)
        clipped_acc = max(min(vote_acc, 1 - NLLEPS), NLLEPS)
        negative_ll.append(
            -np.log(clipped_prob) * clipped_acc
            - np.log(1 - clipped_prob) * (1 - clipped_acc)
        )

    # Turn the accumulated sums into per-bin means and the final metrics.
    ece = 0
    for cur in range(n_bins):
        if siz[cur] > 0:
            acc[cur] = acc[cur] / siz[cur]
            cnf[cur] = cnf[cur] / siz[cur]
            ece += siz[cur] * np.abs(acc[cur] - cnf[cur])
    total = siz.sum()
    ece = ece / total if total > 0 else float("nan")
    bs = np.mean(brier_score)
    nll = np.mean(negative_ll)

    return (ece, bs, nll), (acc, cnf, siz)
|
67 |
+
|
68 |
+
|
69 |
+
def compute_average_metrics(predicts, n_bins=10):
    """Calibration metrics averaged over every answer of every problem.

    ``predicts[idx]`` is a sequence of ``(answer, prob, flag)`` triples for
    problem ``idx``; ``flag`` is truthy when the answer is correct. Each
    answer carries weight ``1 / len(predicts[idx])``.

    Returns ``(ece, brier, nll), (acc, cnf, siz)`` where the second tuple
    holds per-bin accuracy, confidence and weight for ``n_bins`` equal-width
    bins over (0, 1]. ``ece`` is NaN when no answer falls into any bin.
    """
    acc, cnf, siz = np.zeros(n_bins), np.zeros(n_bins), np.zeros(n_bins)
    brier_score = []
    negative_ll = []

    for answers in predicts:
        m = len(answers)

        problem_brier_score = []
        problem_negative_ll = []
        for _, prob, flag in answers:
            correct = 1 if flag else 0

            # Compute Expected Calibration Error
            for cur in range(n_bins):
                lower, upper = cur / n_bins, (cur + 1) / n_bins
                if lower < prob <= upper:
                    if flag:
                        acc[cur] += 1.0 / m
                    cnf[cur] += prob / m
                    siz[cur] += 1.0 / m

            # Compute Brier Score
            problem_brier_score.append((correct - prob) ** 2)

            # Compute Negative Log-Likelihood (clipped to keep the logs finite)
            clipped_prob = max(min(prob, 1 - NLLEPS), NLLEPS)
            clipped_label = max(min(correct, 1 - NLLEPS), NLLEPS)
            problem_negative_ll.append(
                -np.log(clipped_prob) * clipped_label
                - np.log(1 - clipped_prob) * (1 - clipped_label)
            )

        brier_score.append(np.mean(problem_brier_score))
        negative_ll.append(np.mean(problem_negative_ll))

    # Turn the accumulated sums into per-bin means and the final metrics.
    ece = 0
    for cur in range(n_bins):
        if siz[cur] > 0:
            acc[cur] = acc[cur] / siz[cur]
            cnf[cur] = cnf[cur] / siz[cur]
            ece += siz[cur] * np.abs(acc[cur] - cnf[cur])
    total = siz.sum()
    ece = ece / total if total > 0 else float("nan")
    bs = np.mean(brier_score)
    nll = np.mean(negative_ll)

    return (ece, bs, nll), (acc, cnf, siz)
|
requirements.txt
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
antlr4-python3-runtime==4.13.2
|
2 |
+
datasets==4.1.1
|
3 |
+
Fraction==2.2.0
|
4 |
+
huggingface-hub==0.35.3
|
5 |
+
multiprocess==0.70.16
|
6 |
+
numpy==2.0.2
|
7 |
+
regex==2025.9.18
|
8 |
+
scipy==1.13.1
|
9 |
+
sympy==1.14.0
|
10 |
+
tqdm==4.67.1
|