# grok4-gpqa-eval/benchmarks/gpqa_benchmark.py
import os
import random
import re
from typing import Dict, Any, Optional, Tuple

from datasets import load_dataset

from .base_benchmark import BaseBenchmark
from .evaluation_utils import extract_answer_mmlu


class GPQABenchmark(BaseBenchmark):
    """GPQA (Graduate-Level Google-Proof Q&A) benchmark."""

    def __init__(self):
        super().__init__(name="GPQA", dataset_name="Idavidrein/gpqa")

    async def load_dataset(self, sample_size: Optional[int] = None, **kwargs):
        """Load the GPQA dataset from the Hugging Face Hub."""
        # GPQA ships several subsets: gpqa_main, gpqa_diamond, gpqa_extended
        subset = kwargs.get('subset', 'gpqa_main')

        try:
            # GPQA is gated, so pass an HF token if one is available
            hf_token = os.getenv('HF_TOKEN') or os.getenv('HUGGING_FACE_HUB_TOKEN')
            if hf_token:
                dataset = load_dataset(self.dataset_name, subset, split='train', token=hf_token)
            else:
                dataset = load_dataset(self.dataset_name, subset, split='train')
        except Exception as e:
            if "gated dataset" in str(e) or "authentication" in str(e).lower():
                raise Exception(
                    "GPQA dataset requires authentication. Please:\n"
                    "1. Set the HF_TOKEN environment variable\n"
                    "2. Request access at https://huggingface.co/datasets/Idavidrein/gpqa\n"
                    f"Original error: {e}"
                )
            # Fall back to gpqa_main if the requested subset was not found
            try:
                dataset = load_dataset(self.dataset_name, 'gpqa_main', split='train')
            except Exception:
                raise e

        self.dataset = []
        for sample in dataset:
            # GPQA rows carry: Question, Correct Answer, Incorrect Answer 1-3
            choices = [
                sample.get('Correct Answer', ''),
                sample.get('Incorrect Answer 1', ''),
                sample.get('Incorrect Answer 2', ''),
                sample.get('Incorrect Answer 3', '')
            ]

            # Shuffle the choices and record where the correct answer landed
            indices = list(range(4))
            random.shuffle(indices)
            shuffled_choices = [choices[i] for i in indices]
            correct_index = indices.index(0)  # the correct answer started at position 0

            self.dataset.append({
                'question': sample['Question'],
                'choices': shuffled_choices,
                'correct_index': correct_index,
                'subject': sample.get('Subdomain', 'Unknown'),
                'raw_sample': sample
            })

        # Shuffle the question order, then optionally subsample
        random.shuffle(self.dataset)
        if sample_size and len(self.dataset) > sample_size:
            self.dataset = self.dataset[:sample_size]
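
        # Each normalized entry now has the shape sketched below (values are
        # purely illustrative, not taken from the dataset):
        # {
        #     'question': 'Which reagent ...?',
        #     'choices': ['...', '...', '...', '...'],  # shuffled A-D options
        #     'correct_index': 2,                        # where 'Correct Answer' landed
        #     'subject': 'Organic Chemistry',
        #     'raw_sample': {...}                        # original GPQA row
        # }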

    def format_prompt(self, sample: Dict[str, Any]) -> str:
        """Format a GPQA question as a prompt matching the official format."""
        question = sample['question']
        choices = sample['choices']

        # GPQA uses this simple zero-shot format in lm-eval
        prompt = f"""What is the correct answer to this question: {question}
Choices:
(A) {choices[0]}
(B) {choices[1]}
(C) {choices[2]}
(D) {choices[3]}
Answer:"""
        return prompt
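
    # Example call (illustrative sample dict, not real GPQA content):
    #   format_prompt({'question': 'What is 2 + 2?',
    #                  'choices': ['3', '4', '5', '6'],
    #                  'correct_index': 1, ...})
    # returns a prompt ending in "Answer:", whose (A)-(D) letters are matched
    # against the model response by extract_answer_mmlu in evaluate_sample below.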

    async def evaluate_sample(self, api, sample: Dict[str, Any], **kwargs) -> Tuple[bool, Dict[str, Any]]:
        """Evaluate a single GPQA sample."""
        prompt = self.format_prompt(sample)

        try:
            response = await api.generate_with_retry(prompt, **kwargs)

            # Extract the answer letter using the shared MMLU-style extractor
            predicted_letter = extract_answer_mmlu(response)
            if predicted_letter:
                predicted_index = ord(predicted_letter) - ord('A')
            else:
                # No clear answer in the response: mark as incorrect
                predicted_index = -1

            correct_index = sample['correct_index']
            is_correct = predicted_index == correct_index

            result = {
                'question': sample['question'],
                'choices': sample['choices'],
                'correct_answer': correct_index,
                'predicted_answer': predicted_index,
                'model_response': response,
                'is_correct': is_correct,
                'subject': sample['subject']
            }
            return is_correct, result

        except Exception as e:
            result = {
                'question': sample['question'],
                'error': str(e),
                'is_correct': False
            }
            return False, result
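

# ---------------------------------------------------------------------------
# Illustrative usage sketch (commented out; not part of the module's API).
# It assumes only what this file already defines: load_dataset() and
# evaluate_sample() above, plus an API object exposing an async
# generate_with_retry(prompt, **kwargs) method. The DummyAPI stub is
# hypothetical and stands in for the real model client provided elsewhere.
#
# import asyncio
#
# class DummyAPI:
#     async def generate_with_retry(self, prompt, **kwargs):
#         # A real client would query the model here; this always answers A.
#         return "Answer: A"
#
# async def _demo():
#     bench = GPQABenchmark()
#     await bench.load_dataset(sample_size=5, subset='gpqa_diamond')
#     correct = 0
#     for sample in bench.dataset:
#         is_correct, result = await bench.evaluate_sample(DummyAPI(), sample)
#         correct += int(is_correct)
#     print(f"accuracy: {correct}/{len(bench.dataset)}")
#
# # asyncio.run(_demo())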