Spaces:
Running
Running
import os
import random
import re
from typing import Any, Dict, Optional, Tuple

from datasets import load_dataset

from .base_benchmark import BaseBenchmark
from .evaluation_utils import extract_answer_mmlu
class GPQABenchmark(BaseBenchmark):
    """GPQA (Graduate-Level Google-Proof Q&A) benchmark.

    Each raw record supplies one correct and three incorrect answers. The
    options are shuffled per question and the new position of the correct
    answer is stored as ``correct_index`` so responses can be graded by letter.
    """

    # Subset loaded by default, and the one retried when another subset fails.
    _DEFAULT_SUBSET = 'gpqa_main'
    # Raw answer columns; the correct answer is first (index 0 before shuffling).
    _ANSWER_FIELDS = (
        'Correct Answer',
        'Incorrect Answer 1',
        'Incorrect Answer 2',
        'Incorrect Answer 3',
    )

    def __init__(self):
        super().__init__(name="GPQA", dataset_name="Idavidrein/gpqa")

    async def load_dataset(self, sample_size: Optional[int] = None, **kwargs) -> None:
        """Load GPQA into ``self.dataset`` as a list of normalized samples.

        Args:
            sample_size: When truthy and smaller than the dataset, keep only
                the first ``sample_size`` samples after shuffling.
            **kwargs:
                subset: GPQA config name: 'gpqa_main' (default),
                    'gpqa_diamond' or 'gpqa_extended'.
                seed: Optional seed. When given, a private
                    ``random.Random(seed)`` drives both the choice shuffle and
                    the sample shuffle, so runs are reproducible and the global
                    RNG is left untouched. When omitted, the module-level
                    ``random`` is used as before.

        Raises:
            Exception: If the dataset requires authentication, or if neither
                the requested subset nor the default subset can be loaded.
        """
        subset = kwargs.get('subset', self._DEFAULT_SUBSET)
        seed = kwargs.get('seed')
        # Falling back to the global module keeps callers that seed `random`
        # themselves getting the same ordering as before.
        rng = random if seed is None else random.Random(seed)

        # NOTE: datasets.load_dataset is synchronous (not awaited), so the
        # download blocks the event loop while it runs.
        raw_dataset = self._load_raw(subset)

        # Build all samples first (one choice shuffle each), then shuffle the
        # sample order: same RNG call sequence as the original loop.
        self.dataset = [self._to_sample(raw, rng) for raw in raw_dataset]
        rng.shuffle(self.dataset)
        if sample_size and len(self.dataset) > sample_size:
            self.dataset = self.dataset[:sample_size]

    def _load_raw(self, subset: str):
        """Fetch the 'train' split of ``subset``, retrying the default subset on failure."""
        hf_token = os.getenv('HF_TOKEN') or os.getenv('HUGGING_FACE_HUB_TOKEN')
        # Only pass `token` when one is configured; otherwise leave the
        # library's own default untouched.
        load_kwargs = {'token': hf_token} if hf_token else {}
        try:
            return load_dataset(self.dataset_name, subset, split='train', **load_kwargs)
        except Exception as e:
            message = str(e).lower()
            if "gated dataset" in message or "authentication" in message:
                raise Exception(
                    "GPQA dataset requires authentication. Please:\n"
                    "1. Set HF_TOKEN environment variable\n"
                    "2. Request access at https://huggingface.co/datasets/Idavidrein/gpqa\n"
                    f"Original error: {e}"
                ) from e
            if subset == self._DEFAULT_SUBSET:
                # Retrying the very same config cannot help; report the failure.
                raise
            try:
                # Reuse the token: an anonymous retry would fail on a gated
                # dataset even when the fallback config exists.
                return load_dataset(
                    self.dataset_name, self._DEFAULT_SUBSET, split='train', **load_kwargs
                )
            except Exception:
                # Surface the original error; the fallback failure becomes its
                # __context__.
                raise e

    def _to_sample(self, raw: Dict[str, Any], rng) -> Dict[str, Any]:
        """Convert one raw GPQA record into a shuffled multiple-choice sample."""
        choices = [self._clean_text(raw.get(field)) for field in self._ANSWER_FIELDS]
        order = list(range(len(choices)))
        rng.shuffle(order)
        return {
            'question': self._clean_text(raw['Question']),
            'choices': [choices[i] for i in order],
            # Position 0 held the correct answer before shuffling; find where it went.
            'correct_index': order.index(0),
            'subject': raw.get('Subdomain', 'Unknown'),
            'raw_sample': raw,
        }

    @staticmethod
    def _clean_text(value: Any) -> str:
        """Return ``value`` as a whitespace-stripped string; ``None`` becomes ''."""
        # Stripping keeps stray leading/trailing newlines in raw fields from
        # breaking the one-option-per-line prompt layout.
        return '' if value is None else str(value).strip()

    def format_prompt(self, sample: Dict[str, Any]) -> str:
        """Format a GPQA question as a lettered multiple-choice prompt.

        Produces:
            What is the correct answer to this question: <question>
            Choices:
            (A) <choice 0>
            ...
            Answer:

        Args:
            sample: A sample built by ``load_dataset`` (needs 'question' and
                'choices').

        Returns:
            The prompt string, ending in "Answer:" with no trailing newline.
        """
        lines = [
            f"What is the correct answer to this question: {sample['question']}",
            "Choices:",
        ]
        lines.extend(
            f"({chr(ord('A') + i)}) {choice}" for i, choice in enumerate(sample['choices'])
        )
        lines.append("Answer:")
        return "\n".join(lines)

    @staticmethod
    def _letter_to_index(letter: Any, num_choices: int) -> int:
        """Map an extracted answer letter to a 0-based choice index, or -1 if invalid."""
        if not isinstance(letter, str):
            return -1
        letter = letter.strip().upper()
        if len(letter) != 1:
            return -1
        index = ord(letter) - ord('A')
        return index if 0 <= index < num_choices else -1

    async def evaluate_sample(self, api, sample: Dict[str, Any], **kwargs) -> Tuple[bool, Dict[str, Any]]:
        """Query the model on one GPQA sample and grade its answer letter.

        Args:
            api: Object exposing an awaitable ``generate_with_retry(prompt, **kwargs)``.
            sample: A sample built by ``load_dataset``.
            **kwargs: Forwarded to ``api.generate_with_retry``.

        Returns:
            ``(is_correct, result)``. ``result['predicted_answer']`` is -1 when
            no valid letter could be extracted. Any exception during generation
            or grading is caught and reported as an incorrect result carrying
            an 'error' entry.
        """
        prompt = self.format_prompt(sample)
        try:
            response = await api.generate_with_retry(prompt, **kwargs)
            predicted_index = self._letter_to_index(
                extract_answer_mmlu(response), len(sample['choices'])
            )
            correct_index = sample['correct_index']
            is_correct = predicted_index == correct_index
            result = {
                'question': sample['question'],
                'choices': sample['choices'],
                'correct_answer': correct_index,
                'predicted_answer': predicted_index,
                'model_response': response,
                'is_correct': is_correct,
                'subject': sample['subject'],
            }
            return is_correct, result
        except Exception as e:
            # Keep 'subject'/'correct_answer' so failed samples carry the same
            # identifying fields as successful ones.
            result = {
                'question': sample['question'],
                'error': str(e),
                'is_correct': False,
                'subject': sample.get('subject', 'Unknown'),
                'correct_answer': sample.get('correct_index'),
            }
            return False, result