import random
import re
from typing import Any, Dict, Optional, Tuple

from datasets import load_dataset

from .base_benchmark import BaseBenchmark
from .prompt_templates import get_gsm8k_cot_prompt


class GSM8KBenchmark(BaseBenchmark):
    """GSM8K (Grade School Math 8K) benchmark.

    Loads the ``main`` configuration's ``test`` split, prompts the model with
    the chain-of-thought template from ``prompt_templates``, and scores each
    sample by exact string match between the normalized gold answer and the
    normalized number extracted from the model response.
    """

    # "The answer is X" (CoT convention). The character class is greedy and
    # contains '.', so a sentence-ending period lands inside the capture and
    # must be stripped by _normalize_number. An optional leading '$' is
    # tolerated so currency answers are caught here rather than by the
    # last-number fallback.
    _ANSWER_PHRASE_RE = re.compile(
        r'The answer is \$?([\-0-9\.\,]+)', re.IGNORECASE
    )
    # Native GSM8K solution terminator: "... #### number".
    _HASH_ANSWER_RE = re.compile(r'#### ([\-0-9\.\,]+)')
    # Flexible fallback: any run that looks like a (possibly formatted) number.
    _FLEXIBLE_NUMBER_RE = re.compile(r'(-?[$0-9.,]{2,})|(-?[0-9]+)')

    def __init__(self):
        """Register the benchmark under the name "GSM8K" / dataset "gsm8k"."""
        super().__init__(name="GSM8K", dataset_name="gsm8k")

    @staticmethod
    def _normalize_number(raw: str) -> Optional[str]:
        """Return ``raw`` as a canonical number string, or None if invalid.

        Removes '$' and thousands separators and strips trailing periods
        (sentence punctuation captured by the greedy regexes), then checks
        that the remainder parses as a float. The cleaned *string* is
        returned, not the float, because scoring is an exact string match.
        """
        cleaned = raw.replace('$', '').replace(',', '').rstrip('.')
        try:
            float(cleaned)
        except ValueError:
            return None
        return cleaned

    async def load_dataset(self, sample_size: Optional[int] = None, **kwargs):
        """Load and shuffle the GSM8K test split into ``self.dataset``.

        Args:
            sample_size: If truthy and smaller than the dataset, keep only
                the first ``sample_size`` samples after shuffling.
            **kwargs: Accepted for interface compatibility; unused here.

        The shuffle uses the module-level ``random`` generator, so callers
        who need a reproducible subset can call ``random.seed(...)`` first.
        """
        # Inside a method body this name resolves to the module-level
        # `datasets.load_dataset`, not to this method.
        hf_split = load_dataset(self.dataset_name, 'main', split='test')

        self.dataset = [
            {
                'question': sample['question'],
                'answer': sample['answer'],
                'raw_sample': sample,
            }
            for sample in hf_split
        ]

        random.shuffle(self.dataset)

        if sample_size and len(self.dataset) > sample_size:
            self.dataset = self.dataset[:sample_size]

    def extract_answer_from_solution(self, solution: str) -> Optional[str]:
        """Extract the gold numerical answer from a GSM8K solution string.

        GSM8K solutions end with ``#### number``. Returns the normalized
        number string, or None when no valid number follows the marker.
        """
        match = self._HASH_ANSWER_RE.search(solution)
        if match is None:
            return None
        return self._normalize_number(match.group(1))

    def extract_number_from_response(self, response: str) -> Optional[str]:
        """Extract the final numerical answer from a model response.

        Strategies, tried in order:

        1. "The answer is X" (case-insensitive, optional leading '$').
        2. The GSM8K ``#### X`` marker.
        3. Flexible extraction: the last number-like token in the response.

        Every candidate is normalized (no '$', no ',', no trailing '.') and
        must parse as a float; a candidate that fails falls through to the
        next strategy. Returns None when nothing usable is found.
        """
        for pattern in (self._ANSWER_PHRASE_RE, self._HASH_ANSWER_RE):
            match = pattern.search(response)
            if match:
                answer = self._normalize_number(match.group(1))
                if answer is not None:
                    return answer

        # Walk candidates from the end so trailing punctuation runs such as
        # "..." (which the pattern also matches) do not hide a genuine number
        # that appears just before them.
        candidates = [
            m.group(0) for m in self._FLEXIBLE_NUMBER_RE.finditer(response)
        ]
        for candidate in reversed(candidates):
            answer = self._normalize_number(candidate)
            if answer is not None:
                return answer

        return None

    def format_prompt(self, sample: Dict[str, Any]) -> str:
        """Format the sample's question with the chain-of-thought template."""
        return get_gsm8k_cot_prompt(sample['question'])

    async def evaluate_sample(
        self, api, sample: Dict[str, Any], **kwargs
    ) -> Tuple[bool, Dict[str, Any]]:
        """Evaluate a single GSM8K sample.

        Args:
            api: Object exposing an awaitable
                ``generate_with_retry(prompt, **kwargs)``.
            sample: Dict with at least ``'question'`` and ``'answer'`` keys.
            **kwargs: Forwarded unchanged to ``api.generate_with_retry``.

        Returns:
            ``(is_correct, result)``. On success ``result`` holds the
            question, both extracted answers, the raw model response and the
            reference solution. If generation or extraction raises, the
            sample is scored incorrect and ``result`` carries the error text
            instead, so one failing sample does not abort the whole run.
        """
        prompt = self.format_prompt(sample)

        try:
            response = await api.generate_with_retry(prompt, **kwargs)

            correct_answer = self.extract_answer_from_solution(sample['answer'])
            model_answer = self.extract_number_from_response(response)

            # Exact match on normalized strings; a missing gold or model
            # answer is never counted as correct.
            is_correct = (
                correct_answer is not None
                and model_answer is not None
                and correct_answer == model_answer
            )

            result = {
                'question': sample['question'],
                'correct_answer': correct_answer,
                'model_answer': model_answer,
                'model_response': response,
                'is_correct': is_correct,
                'solution': sample['answer'],
            }
            return is_correct, result

        except Exception as e:
            result = {
                'question': sample['question'],
                'error': str(e),
                'is_correct': False,
            }
            return False, result