# OpenGeminiAPI / examples/api_eval/eval_gemini_google.py
# Author: HoneyTian (commit 5fcf580, "update")
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
import json
import os
import sys
import time
# Directory containing this script; the repository root two levels above it
# is put on sys.path so that `project_settings` can be imported when the
# script is run directly from examples/api_eval/.
pwd = os.path.dirname(os.path.abspath(__file__))
_repo_root = os.path.join(pwd, "../../")
sys.path.append(_repo_root)
from google import genai
from google.genai import types
from project_settings import environment, project_path
def get_args():
    """Parse command-line options for the Gemini single-choice evaluation.

    Example invocations::

        python3 eval_gemini_google.py --model_name gemini-2.5-pro --eval_result eval_math_result_gemini-2.5-pro.jsonl
        python3 eval_gemini_google.py --model_name gemini-2.5-flash --eval_result eval_math_result_gemini-2.5-flash.jsonl
        python3 eval_gemini_google.py --model_name gemini-2.5-flash-lite-preview-06-17 --eval_result eval_math_result_gemini-2.5-flash-lite-preview-06-17.jsonl

    :return: ``argparse.Namespace`` with the string attributes
        ``google_application_credentials``, ``model_name``, ``eval_data``
        and ``eval_result``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--google_application_credentials",
        # main() writes this value into GOOGLE_APPLICATION_CREDENTIALS, so a
        # purely hard-coded default would overwrite credentials the user has
        # already exported. Prefer the environment (if set and non-empty) and
        # only then fall back to the project's bundled key file.
        default=(
            os.environ.get("GOOGLE_APPLICATION_CREDENTIALS")
            or (project_path / "dotenv/potent-veld-462405-t3-8091a29b2894.json").as_posix()
        ),
        type=str,
        help="path to a Google Cloud service-account JSON key file",
    )
    parser.add_argument(
        "--model_name",
        default="gemini-2.5-flash-lite-preview-06-17",
        type=str,
        help="Gemini model id, e.g. gemini-2.5-pro, gemini-2.5-flash, "
             "gemini-2.5-flash-lite-preview-06-17",
    )
    parser.add_argument(
        "--eval_data",
        default=(project_path / "data/arc-easy.jsonl").as_posix(),
        type=str,
        help="JSONL file of single-choice questions to evaluate",
    )
    parser.add_argument(
        "--eval_result",
        default=(project_path / "data/eval_math_result.jsonl").as_posix(),
        type=str,
        help="JSONL file results are appended to (existing rows are skipped on rerun)",
    )
    args = parser.parse_args()
    return args
def _load_progress(path):
    """Read an existing results file so an interrupted run can resume.

    :param path: JSONL results file previously written by ``main``.
    :return: ``(finished_ids, total, total_correct)``: the set of ids already
        scored plus the running counters stored in the last row, or
        ``(set(), 0, 0)`` when the file does not exist yet.
    """
    finished_ids = set()
    total = 0
    total_correct = 0
    if not os.path.exists(path):
        return finished_ids, total, total_correct
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            # A blank line carries no result; json.loads would reject it.
            if not line.strip():
                continue
            record = json.loads(line)
            finished_ids.add(record["id"])
            # Every row stores running totals, so the last row wins.
            total = record["total"]
            total_correct = record["total_correct"]
    return finished_ids, total, total_correct


def _build_prompt(question, choices):
    """Render one single-choice question into the prompt sent to the model.

    :param question: question text.
    :param choices: iterable of dicts with ``"label"`` and ``"text"`` keys.
    :return: prompt asking the model to reply with only the chosen label.
    """
    instruct = "Complete this single-choice question."
    choices_str = "".join(
        f"If you think the answer is `{choice['text']}` output: `{choice['label']}`\n"
        for choice in choices
    )
    prompt = f"""
{instruct}
Question:
{question}
Choices:
{choices_str}
Remember to output ONLY the corresponding letter.
Your output is:
""".strip()
    return prompt


def _normalize_prediction(text):
    """Strip surrounding whitespace and backticks from a model reply.

    The prompt shows every label wrapped in backticks, so replies such as
    "`A`" or "A\\n" must still match the bare answer key.

    :param text: raw reply text, or ``None``.
    :return: normalized text, or ``None`` when ``text`` is ``None``.
    """
    if text is None:
        return None
    return text.strip().strip("`").strip()


def main():
    """Evaluate a Gemini model (via Vertex AI) on a single-choice JSONL file.

    Each line of ``--eval_data`` is read as a JSON object with ``id``,
    ``question``, ``choices`` (dicts with ``label``/``text``) and
    ``answerkey``. One result row per scored question is appended to
    ``--eval_result`` together with running ``total``/``total_correct``/
    ``score``; ids already present in that file are skipped, so the script
    can be re-run to resume.
    """
    args = get_args()

    # Service-account credentials for the Vertex AI client are supplied via
    # this environment variable (Application Default Credentials).
    os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = args.google_application_credentials

    client = genai.Client(
        vertexai=True,
        project="potent-veld-462405-t3",
        location="global",
    )
    generate_content_config = types.GenerateContentConfig(
        temperature=1,
        top_p=0.95,
        max_output_tokens=8192,
        response_modalities=["TEXT"],
    )

    finished_idx_set, total, total_correct = _load_progress(args.eval_result)
    print(f"finished count: {len(finished_idx_set)}")

    with open(args.eval_data, "r", encoding="utf-8") as fin, open(args.eval_result, "a+", encoding="utf-8") as fout:
        for line in fin:
            # Hard-coded sample cap: stop once more than 20 questions have
            # been scored (i.e. after 21, counting rows from a resumed run).
            if total > 20:
                break
            if not line.strip():
                continue
            row = json.loads(line)
            idx = row["id"]
            if idx in finished_idx_set:
                continue
            finished_idx_set.add(idx)
            question = row["question"]
            choices = row["choices"]
            answer_key = row["answerkey"]

            prompt = _build_prompt(question, choices)
            contents = [
                types.Content(
                    role="user",
                    parts=[types.Part.from_text(text=prompt)],
                )
            ]

            time_begin = time.perf_counter()
            response: types.GenerateContentResponse = client.models.generate_content(
                model=args.model_name,
                contents=contents,
                config=generate_content_config,
            )
            time_cost = time.perf_counter() - time_begin
            print(f"id: {idx}, time_cost: {time_cost:.3f}s")

            try:
                prediction = response.candidates[0].content.parts[0].text
            except (TypeError, AttributeError, IndexError):
                # Some level of candidates/content/parts was None or empty.
                # No row is written, so this question is retried on a re-run.
                print(f"id: {idx}, no candidate returned; skipped")
                continue

            correct = 1 if _normalize_prediction(prediction) == answer_key else 0
            total += 1
            total_correct += correct
            score = total_correct / total
            row_ = {
                "id": idx,
                "question": question,
                "choices": choices,
                "ground_true": answer_key,
                "prediction": prediction,
                "correct": correct,
                "total": total,
                "total_correct": total_correct,
                "score": score,
                "time_cost": time_cost,
            }
            fout.write(json.dumps(row_, ensure_ascii=False) + "\n")
            # Persist each row immediately so a crash or interrupt does not
            # lose results that the resume logic depends on.
            fout.flush()
# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()