File size: 3,099 Bytes
0ec61d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
import json

from openai import OpenAI

from project_settings import environment, project_path


def get_args():
    """Parse the command-line options for the Gemini single-choice evaluation.

    Every option is a plain string flag whose default is taken from the
    project settings (API key from the environment, data paths under
    ``project_path / "data"``).

    Returns:
        argparse.Namespace with ``gemini_api_key``, ``model_name``,
        ``eval_data`` and ``eval_result`` attributes.
    """
    parser = argparse.ArgumentParser()
    flag_defaults = (
        ("--gemini_api_key", environment.get(key="GEMINI_API_KEY")),
        # "gemini-2.5-pro" is the heavier alternative.
        ("--model_name", "gemini-2.5-flash"),
        ("--eval_data", (project_path / "data/arc-easy.jsonl").as_posix()),
        # NOTE(review): the result filename says "math" although the data is
        # ARC-Easy -- looks like a copy-paste leftover; confirm before renaming.
        ("--eval_result", (project_path / "data/eval_math_result.jsonl").as_posix()),
    )
    for flag, default_value in flag_defaults:
        parser.add_argument(flag, default=default_value, type=str)
    return parser.parse_args()


def _build_prompt(question, choices, instruct="Complete this single-choice question."):
    """Render the single-choice prompt sent to the model for one question.

    ``choices`` is an iterable of dicts with ``"label"`` and ``"text"`` keys;
    each becomes one "If you think the answer is ... output: ..." line.
    """
    choices_str = "".join(
        f"If you think the answer is `{choice['text']}` output: `{choice['label']}`\n"
        for choice in choices
    )
    # Joined line-by-line so the exact text (including the trailing spaces
    # after "Question:" and "Choices:") stays identical to earlier runs,
    # keeping scores comparable.
    lines = [
        instruct,
        "",
        "Question: ",
        question,
        "",
        "Choices: ",
        choices_str,
        "",
        "Remember to output ONLY the corresponding letter.",
        "Your output is:",
    ]
    return "\n".join(lines).strip()


def _normalize_prediction(text):
    """Reduce a raw model reply (possibly None) to a bare label for comparison."""
    if text is None:
        return ""
    # The prompt shows labels wrapped in backticks, so models often echo them
    # ("`A`"), or add whitespace/newlines or a trailing period.
    return text.strip().strip("`'\".").strip()


def main():
    """Evaluate a Gemini model on single-choice questions and log results.

    Reads one JSON object per line from ``--eval_data`` (keys read: ``id``,
    ``question``, ``choices`` -- a list of ``{"label", "text"}`` dicts -- and
    ``answerkey`` or ``answerKey``). It asks ``--model_name`` for the answer
    label and appends one JSON line per question to ``--eval_result``, with
    the running accuracy. The running score is printed after every row.
    """
    args = get_args()

    # Gemini's OpenAI-compatible API is served under /v1beta/openai/; the bare
    # /v1beta root does not expose /chat/completions.
    client = OpenAI(
        api_key=args.gemini_api_key,
        base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
    )

    total = 0
    total_correct = 0
    # NOTE(review): "a+" appends, so a re-run adds rows whose running totals
    # restart from zero; remove the result file first if a clean log is wanted.
    with open(args.eval_data, "r", encoding="utf-8") as fin, open(args.eval_result, "a+", encoding="utf-8") as fout:
        for line in fin:
            if not line.strip():
                continue  # tolerate blank lines instead of crashing json.loads
            row = json.loads(line)
            idx = row["id"]
            question = row["question"]
            choices = row["choices"]
            # Upstream ARC files spell this key "answerKey"; accept either form.
            answer_key = row["answerkey"] if "answerkey" in row else row["answerKey"]

            prompt = _build_prompt(question, choices)
            response = client.chat.completions.create(
                model=args.model_name,  # previously hard-coded, ignoring --model_name
                messages=[{"role": "user", "content": prompt}],
                stream=False,
                temperature=0.0,
            )

            # Raw reply is logged as-is; only the comparison is normalized.
            prediction = response.choices[0].message.content
            correct = 1 if _normalize_prediction(prediction) == answer_key else 0

            total += 1
            total_correct += correct
            score = total_correct / total

            row_ = {
                "id": idx,
                "question": question,
                "choices": choices,
                "ground_true": answer_key,
                "prediction": prediction,
                "correct": correct,
                "total": total,
                "total_correct": total_correct,
                "score": score,
            }
            row_ = json.dumps(row_, ensure_ascii=False)
            fout.write(f"{row_}\n")
            # Each row costs an API call; flush so a crash does not lose results.
            fout.flush()

            print(f"score: {score}")

    return


if __name__ == "__main__":
    # Script entry point: run the evaluation only when executed directly,
    # not when this module is imported.
    main()