HoneyTian's picture
update
0ec61d2
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
import json
from openai import OpenAI
from project_settings import environment, project_path
def get_args():
    """Parse the command-line options of the evaluation script.

    Every option is a string. Defaults: the API key comes from the
    ``GEMINI_API_KEY`` entry of ``environment``, the model is
    ``gemini-2.5-flash``, and both data paths are resolved under
    ``project_path / "data"``.

    Returns:
        argparse.Namespace with ``gemini_api_key``, ``model_name``,
        ``eval_data`` and ``eval_result``.
    """
    # NOTE(review): the result file is named "eval_math_result.jsonl" while the
    # input is ARC-Easy; it is opened in append mode by main(), so confirm this
    # name is intended and does not collide with another evaluation's output.
    option_defaults = (
        ("--gemini_api_key", environment.get(key="GEMINI_API_KEY")),
        ("--model_name", "gemini-2.5-flash"),  # alternative: "gemini-2.5-pro"
        ("--eval_data", (project_path / "data/arc-easy.jsonl").as_posix()),
        ("--eval_result", (project_path / "data/eval_math_result.jsonl").as_posix()),
    )

    cli = argparse.ArgumentParser()
    for flag, default_value in option_defaults:
        cli.add_argument(flag, default=default_value, type=str)
    return cli.parse_args()
def _build_prompt(question, choices):
    """Render the single-choice prompt for one question.

    ``choices`` is iterated as a sequence of mappings with ``"label"`` and
    ``"text"`` keys; each becomes one "If you think the answer is ..." line.
    """
    choices_str = "".join(
        f"If you think the answer is `{choice['text']}` output: `{choice['label']}`\n"
        for choice in choices
    )
    return (
        "Complete this single-choice question.\n"
        "Question:\n"
        f"{question}\n"
        "Choices:\n"
        f"{choices_str}\n"
        "Remember to output ONLY the corresponding letter.\n"
        "Your output is:"
    )


def _normalize_prediction(content):
    """Reduce raw completion text to a bare answer label.

    The prompt shows labels wrapped in backticks, and chat replies may carry
    surrounding whitespace or a trailing newline; either would make an exact
    comparison with the answer key fail. ``None`` (no text returned) maps to "".
    """
    if content is None:
        return ""
    return content.strip().strip("`").strip()


def main():
    """Evaluate a Gemini model on a single-choice JSONL dataset.

    Each non-blank line of ``--eval_data`` is parsed as a JSON object with
    ``id``, ``question``, ``choices`` and ``answerkey``. The model named by
    ``--model_name`` is asked for the answer label. For every row, a record
    including the running accuracy is appended to ``--eval_result``, and the
    running score is printed.
    """
    args = get_args()

    # Gemini's OpenAI-compatible API is served under /v1beta/openai/; the bare
    # /v1beta root does not expose /chat/completions.
    client = OpenAI(
        api_key=args.gemini_api_key,
        base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
    )

    total = 0
    total_correct = 0
    with open(args.eval_data, "r", encoding="utf-8") as fin, \
            open(args.eval_result, "a+", encoding="utf-8") as fout:
        for line in fin:
            if not line.strip():
                continue  # tolerate blank / trailing lines in the JSONL file
            row = json.loads(line)
            idx = row["id"]
            question = row["question"]
            choices = row["choices"]
            # NOTE(review): upstream ARC data names this field "answerKey";
            # confirm the local JSONL really uses lowercase "answerkey".
            answer_key = row["answerkey"]

            response = client.chat.completions.create(
                model=args.model_name,  # honour the CLI-selected model
                messages=[{"role": "user", "content": _build_prompt(question, choices)}],
                stream=False,
                temperature=0.0,
            )
            prediction = _normalize_prediction(response.choices[0].message.content)
            correct = 1 if prediction == answer_key else 0

            total += 1
            total_correct += correct
            score = total_correct / total

            row_ = {
                "id": idx,
                "question": question,
                "choices": choices,
                "ground_true": answer_key,
                "prediction": prediction,
                "correct": correct,
                "total": total,
                "total_correct": total_correct,
                "score": score,
            }
            fout.write(f"{json.dumps(row_, ensure_ascii=False)}\n")
            fout.flush()  # keep finished rows on disk if a later API call fails
            print(f"score: {score}")
    return
# Script entry point: run the evaluation only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()