# grok4-gpqa-eval/benchmarks/mmlu_benchmark.py
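"""MMLU (Massive Multitask Language Understanding) benchmark implementation.

Loads the cais/mmlu test split per subject, builds 5-shot prompts from the dev
split, and scores multiple-choice answers extracted from model responses.
"""
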
import random
from typing import Any, Dict, Optional, Tuple

from datasets import load_dataset

from .base_benchmark import BaseBenchmark
from .evaluation_utils import extract_answer_mmlu
from .prompt_templates import get_mmlu_prompt


class MMLUBenchmark(BaseBenchmark):
    """MMLU (Massive Multitask Language Understanding) benchmark."""
SUBJECTS = [
'abstract_algebra', 'anatomy', 'astronomy', 'business_ethics',
'clinical_knowledge', 'college_biology', 'college_chemistry',
'college_computer_science', 'college_mathematics', 'college_medicine',
'college_physics', 'computer_security', 'conceptual_physics',
'econometrics', 'electrical_engineering', 'elementary_mathematics',
'formal_logic', 'global_facts', 'high_school_biology',
'high_school_chemistry', 'high_school_computer_science',
'high_school_european_history', 'high_school_geography',
'high_school_government_and_politics', 'high_school_macroeconomics',
'high_school_mathematics', 'high_school_microeconomics',
'high_school_physics', 'high_school_psychology', 'high_school_statistics',
'high_school_us_history', 'high_school_world_history', 'human_aging',
'human_sexuality', 'international_law', 'jurisprudence',
'logical_fallacies', 'machine_learning', 'management', 'marketing',
'medical_genetics', 'miscellaneous', 'moral_disputes', 'moral_scenarios',
'nutrition', 'philosophy', 'prehistory', 'professional_accounting',
'professional_law', 'professional_medicine', 'professional_psychology',
'public_relations', 'security_studies', 'sociology', 'us_foreign_policy',
'virology', 'world_religions'
]
def __init__(self):
super().__init__(name="MMLU", dataset_name="cais/mmlu")
async def load_dataset(self, sample_size: Optional[int] = None, **kwargs):
"""Load MMLU dataset"""
subjects = kwargs.get('subjects', ['all'])
if 'all' in subjects:
subjects = self.SUBJECTS
else:
subjects = [s for s in subjects if s in self.SUBJECTS]
self.dataset = []
self.few_shot_examples = {} # Store few-shot examples per subject
for subject in subjects:
try:
# Load dev split for few-shot examples
dev_ds = load_dataset(self.dataset_name, subject, split='dev')
# Standard MMLU uses 5-shot
self.few_shot_examples[subject] = [
{
'question': ex['question'],
'choices': ex['choices'],
'answer': ex['answer']
}
for ex in list(dev_ds)[:5]
]
# Load test split for evaluation
test_ds = load_dataset(self.dataset_name, subject, split='test')
for sample in test_ds:
self.dataset.append({
'subject': subject,
'question': sample['question'],
'choices': sample['choices'],
'answer': sample['answer'], # 0-3 index
'raw_sample': sample
})
except Exception as e:
print(f"Error loading {subject}: {e}")
continue
        # Shuffle so the optional sample_size truncation below draws a random mix of subjects
        random.shuffle(self.dataset)
if sample_size and len(self.dataset) > sample_size:
self.dataset = self.dataset[:sample_size]
def format_prompt(self, sample: Dict[str, Any]) -> str:
"""Format MMLU question as prompt with few-shot examples"""
subject = sample['subject']
few_shot_examples = self.few_shot_examples.get(subject, [])
return get_mmlu_prompt(
question=sample['question'],
choices=sample['choices'],
subject=subject.replace('_', ' ').title(),
few_shot_examples=few_shot_examples
)
async def evaluate_sample(self, api, sample: Dict[str, Any], **kwargs) -> Tuple[bool, Dict[str, Any]]:
"""Evaluate a single MMLU sample"""
prompt = self.format_prompt(sample)
try:
response = await api.generate_with_retry(prompt, **kwargs)
            # Extract the predicted letter (A-D) and map it to a 0-3 choice index
            predicted_letter = extract_answer_mmlu(response)
            if predicted_letter:
                predicted_index = ord(predicted_letter) - ord('A')
            else:
                # No parseable answer: use a -1 sentinel, which never matches a 0-3 gold index
                predicted_index = -1
correct_index = sample['answer']
is_correct = predicted_index == correct_index
result = {
'subject': sample['subject'],
'question': sample['question'],
'choices': sample['choices'],
'correct_answer': correct_index,
'predicted_answer': predicted_index,
'model_response': response,
'is_correct': is_correct
}
return is_correct, result
except Exception as e:
result = {
'subject': sample['subject'],
'question': sample['question'],
'error': str(e),
'is_correct': False
}
return False, result
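

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the evaluation pipeline).
# It drives the benchmark end to end with a stub client. `_EchoAPI` and its
# fixed "Answer: A" reply are assumptions made for demonstration; the real
# client elsewhere in this repo must expose the async
# generate_with_retry(prompt, **kwargs) method that evaluate_sample awaits.
# Run as a module so the relative imports resolve, e.g.:
#   python -m benchmarks.mmlu_benchmark
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import asyncio

    class _EchoAPI:
        """Stub client that always answers 'A' (assumption for this sketch)."""

        async def generate_with_retry(self, prompt: str, **kwargs) -> str:
            return "Answer: A"

    async def _demo():
        bench = MMLUBenchmark()
        # Small, single-subject load to keep the demo fast
        await bench.load_dataset(sample_size=3, subjects=['astronomy'])
        for sample in bench.dataset:
            is_correct, result = await bench.evaluate_sample(_EchoAPI(), sample)
            print(f"{result['subject']}: predicted={result.get('predicted_answer')} "
                  f"gold={result.get('correct_answer')} correct={is_correct}")

    asyncio.run(_demo())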