from .base_benchmark import BaseBenchmark
from typing import Dict, Any, Optional, Tuple
from datasets import load_dataset
import random
import re
from .evaluation_utils import normalize_math_answer, is_math_equiv


class MATHBenchmark(BaseBenchmark):
    """MATH benchmark for competition-level mathematics problems."""

    LEVELS = ['Level 1', 'Level 2', 'Level 3', 'Level 4', 'Level 5']
    TYPES = ['Algebra', 'Counting & Probability', 'Geometry', 'Intermediate Algebra',
             'Number Theory', 'Prealgebra', 'Precalculus']

    def __init__(self):
        super().__init__(name="MATH", dataset_name="hendrycks/competition_math")

    async def load_dataset(self, sample_size: Optional[int] = None, **kwargs):
        """Load the MATH test split, optionally filtered by difficulty level."""
        dataset = load_dataset(self.dataset_name, split='test')

        # Filter by difficulty level if specified (e.g. difficulty=['Level 5'])
        difficulty_levels = kwargs.get('difficulty', ['all'])
        if 'all' not in difficulty_levels:
            dataset = dataset.filter(lambda x: x['level'] in difficulty_levels)

        self.dataset = []
        for sample in dataset:
            self.dataset.append({
                'problem': sample['problem'],
                'solution': sample['solution'],
                'level': sample['level'],
                'type': sample['type'],
                'raw_sample': sample
            })

        # Shuffle so that sample_size selects a random subset.
        # Note: unseeded, so the subset varies between runs.
        random.shuffle(self.dataset)

        if sample_size and len(self.dataset) > sample_size:
            self.dataset = self.dataset[:sample_size]

    def extract_answer(self, solution: str) -> Optional[str]:
        """Extract the final answer from a MATH solution (last \\boxed{} or \\fbox{} content)."""
        # Find all boxed/fboxed content. These regexes do not handle nested
        # braces (e.g. \boxed{\frac{1}{2}}); a brace-balancing alternative is
        # sketched at the bottom of this module.
        boxed_matches = re.findall(r'\\boxed\{([^{}]*)\}', solution)
        fbox_matches = re.findall(r'\\fbox\{([^{}]*)\}', solution)

        all_matches = boxed_matches + fbox_matches
        if all_matches:
            # Return the last boxed answer
            return all_matches[-1].strip()

        return None

    def extract_model_answer(self, response: str) -> Optional[str]:
        """Extract the final answer from a model response."""
        # Try to find a boxed answer first
        answer = self.extract_answer(response)
        if answer:
            return answer

        # If there is no boxed answer, fall back to common phrasings.
        # "The answer is X"
        match = re.search(r'answer is[\s:]*([^.\n]+)', response, re.IGNORECASE)
        if match:
            return match.group(1).strip()

        # "Therefore, X"
        match = re.search(r'therefore[,\s]+([^.\n]+)', response, re.IGNORECASE)
        if match:
            return match.group(1).strip()

        return None

    def format_prompt(self, sample: Dict[str, Any]) -> str:
        """Format a MATH problem as a prompt."""
        prompt = f"""Solve the following mathematics problem step by step. Show all your work and put your final answer in the format \\boxed{{answer}}.

Problem: {sample['problem']}

Solution:"""
        return prompt

    async def evaluate_sample(self, api, sample: Dict[str, Any], **kwargs) -> Tuple[bool, Dict[str, Any]]:
        """Evaluate a single MATH sample and return (is_correct, result dict)."""
        prompt = self.format_prompt(sample)

        try:
            response = await api.generate_with_retry(prompt, **kwargs)

            # Extract the reference answer from the dataset solution
            correct_answer = self.extract_answer(sample['solution'])

            # Extract the model's answer from its response
            model_answer = self.extract_model_answer(response)

            # Compare answers using mathematical equivalence
            is_correct = False
            if correct_answer and model_answer:
                # Use the official equivalence checking
                is_correct = is_math_equiv(model_answer, correct_answer)

            result = {
                'problem': sample['problem'],
                'level': sample['level'],
                'type': sample['type'],
                'correct_answer': correct_answer,
                'model_answer': model_answer,
                'model_response': response,
                'is_correct': is_correct
            }

            return is_correct, result

        except Exception as e:
            result = {
                'problem': sample['problem'],
                'level': sample['level'],
                'type': sample['type'],
                'error': str(e),
                'is_correct': False
            }
            return False, result
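

# --- Optional helper (sketch): brace-balanced \boxed{} extraction ---
# The regexes in MATHBenchmark.extract_answer stop at the first closing brace,
# so answers with nested braces such as \boxed{\frac{1}{2}} are not captured.
# The helper below is a sketch of a brace-balancing alternative (in the spirit
# of the MATH/lm-eval last_boxed_only_string approach). The name
# extract_last_boxed is a suggestion, not part of the original module, and it
# is not wired into the class.
def extract_last_boxed(text: str) -> Optional[str]:
    """Return the content of the last balanced \\boxed{...} or \\fbox{...} in text."""
    last: Optional[Tuple[int, str]] = None
    for marker in ('\\boxed{', '\\fbox{'):
        start = 0
        while True:
            idx = text.find(marker, start)
            if idx == -1:
                break
            # Walk forward, balancing braces, to find the matching '}'
            depth = 1
            i = idx + len(marker)
            while i < len(text) and depth > 0:
                if text[i] == '{':
                    depth += 1
                elif text[i] == '}':
                    depth -= 1
                i += 1
            # Keep whichever balanced match starts latest in the string
            if depth == 0 and (last is None or idx > last[0]):
                last = (idx, text[idx + len(marker):i - 1].strip())
            start = idx + len(marker)
    return last[1] if last else None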