from .base_benchmark import BaseBenchmark
from typing import Dict, Any, Optional, Tuple, List
from datasets import load_dataset
import re
import random
from .prompt_templates import get_mmlu_prompt
from .evaluation_utils import extract_answer_mmlu


class MMLUBenchmark(BaseBenchmark):
    """MMLU (Massive Multitask Language Understanding) benchmark.

    Loads per-subject dev/test splits from the ``cais/mmlu`` dataset,
    builds 5-shot prompts, and scores model answers (A-D) against the
    gold choice index.
    """

    # The 57 canonical MMLU subject configs exposed by cais/mmlu.
    SUBJECTS = [
        'abstract_algebra', 'anatomy', 'astronomy', 'business_ethics',
        'clinical_knowledge', 'college_biology', 'college_chemistry',
        'college_computer_science', 'college_mathematics', 'college_medicine',
        'college_physics', 'computer_security', 'conceptual_physics',
        'econometrics', 'electrical_engineering', 'elementary_mathematics',
        'formal_logic', 'global_facts', 'high_school_biology',
        'high_school_chemistry', 'high_school_computer_science',
        'high_school_european_history', 'high_school_geography',
        'high_school_government_and_politics', 'high_school_macroeconomics',
        'high_school_mathematics', 'high_school_microeconomics',
        'high_school_physics', 'high_school_psychology',
        'high_school_statistics', 'high_school_us_history',
        'high_school_world_history', 'human_aging', 'human_sexuality',
        'international_law', 'jurisprudence', 'logical_fallacies',
        'machine_learning', 'management', 'marketing', 'medical_genetics',
        'miscellaneous', 'moral_disputes', 'moral_scenarios', 'nutrition',
        'philosophy', 'prehistory', 'professional_accounting',
        'professional_law', 'professional_medicine', 'professional_psychology',
        'public_relations', 'security_studies', 'sociology',
        'us_foreign_policy', 'virology', 'world_religions'
    ]

    def __init__(self):
        super().__init__(name="MMLU", dataset_name="cais/mmlu")

    async def load_dataset(self, sample_size: Optional[int] = None, **kwargs):
        """Load MMLU dev (few-shot) and test (evaluation) splits.

        Keyword Args:
            subjects: List of subject names, or ['all'] (default) for every
                subject in SUBJECTS. Unknown names are silently dropped.
            seed: Optional int. When given, the post-load shuffle is
                reproducible; when omitted, the global RNG is used as before.

        Side effects:
            Populates ``self.dataset`` (flat list of sample dicts) and
            ``self.few_shot_examples`` (subject -> up to 5 dev examples).

        NOTE(review): declared ``async`` but the ``datasets`` calls below are
        blocking — presumably fine for this harness; confirm callers don't
        expect true concurrency here.
        """
        subjects = kwargs.get('subjects', ['all'])
        if 'all' in subjects:
            subjects = self.SUBJECTS
        else:
            # Silently filter out unrecognized subject names.
            subjects = [s for s in subjects if s in self.SUBJECTS]

        self.dataset = []
        self.few_shot_examples = {}  # Store few-shot examples per subject

        for subject in subjects:
            try:
                # Load dev split for few-shot examples.
                dev_ds = load_dataset(self.dataset_name, subject, split='dev')
                # Standard MMLU uses 5-shot.
                self.few_shot_examples[subject] = [
                    {
                        'question': ex['question'],
                        'choices': ex['choices'],
                        'answer': ex['answer']
                    }
                    for ex in list(dev_ds)[:5]
                ]

                # Load test split for evaluation.
                test_ds = load_dataset(self.dataset_name, subject, split='test')
                for sample in test_ds:
                    self.dataset.append({
                        'subject': subject,
                        'question': sample['question'],
                        'choices': sample['choices'],
                        'answer': sample['answer'],  # 0-3 index
                        'raw_sample': sample
                    })
            except Exception as e:
                # Best-effort: skip subjects that fail to download/parse.
                print(f"Error loading {subject}: {e}")
                continue

        # Shuffle so a truncated sample is not biased toward early subjects.
        # A seed makes the subset reproducible; without one, behavior matches
        # the original global-RNG shuffle.
        seed = kwargs.get('seed')
        if seed is None:
            random.shuffle(self.dataset)
        else:
            random.Random(seed).shuffle(self.dataset)

        if sample_size and len(self.dataset) > sample_size:
            self.dataset = self.dataset[:sample_size]

    def format_prompt(self, sample: Dict[str, Any]) -> str:
        """Format an MMLU question as a prompt with few-shot examples."""
        subject = sample['subject']
        few_shot_examples = self.few_shot_examples.get(subject, [])

        return get_mmlu_prompt(
            question=sample['question'],
            choices=sample['choices'],
            subject=subject.replace('_', ' ').title(),
            few_shot_examples=few_shot_examples
        )

    async def evaluate_sample(self, api, sample: Dict[str, Any],
                              **kwargs) -> Tuple[bool, Dict[str, Any]]:
        """Evaluate a single MMLU sample against the model.

        Args:
            api: Client exposing ``generate_with_retry(prompt, **kwargs)``.
            sample: One entry from ``self.dataset``.

        Returns:
            (is_correct, result_dict). On API failure, returns
            (False, {...'error': str(e)...}) instead of raising.
        """
        prompt = self.format_prompt(sample)

        try:
            response = await api.generate_with_retry(prompt, **kwargs)

            # Extract answer from response using standard extraction.
            predicted_letter = extract_answer_mmlu(response)

            if predicted_letter:
                # Normalize case defensively, then bounds-check: a letter
                # outside the available choices (e.g. 'E' on a 4-choice
                # question) maps to the -1 "no answer" sentinel rather than
                # a nonsensical index.
                predicted_index = ord(predicted_letter.upper()) - ord('A')
                if not 0 <= predicted_index < len(sample['choices']):
                    predicted_index = -1
            else:
                # If no clear answer, mark as incorrect.
                predicted_index = -1

            correct_index = sample['answer']
            is_correct = predicted_index == correct_index

            result = {
                'subject': sample['subject'],
                'question': sample['question'],
                'choices': sample['choices'],
                'correct_answer': correct_index,
                'predicted_answer': predicted_index,
                'model_response': response,
                'is_correct': is_correct
            }
            return is_correct, result
        except Exception as e:
            result = {
                'subject': sample['subject'],
                'question': sample['question'],
                'error': str(e),
                'is_correct': False
            }
            return False, result