# grok4-gpqa-eval / benchmarks/gsm8k_benchmark.py
# Uploaded by TeddyYao ("Upload 38 files"), commit 8474f02 (verified).
from .base_benchmark import BaseBenchmark
from typing import Dict, Any, Optional, Tuple
from datasets import load_dataset
import re
from .prompt_templates import get_gsm8k_cot_prompt
class GSM8KBenchmark(BaseBenchmark):
    """GSM8K (Grade School Math 8K) benchmark.

    Loads the ``main`` config test split and prompts with the chain-of-thought
    template from ``get_gsm8k_cot_prompt``. Scoring compares two normalized
    number strings by exact match: the final number extracted from the model
    response, and the number after ``####`` in the reference solution.
    """

    # Patterns are compiled once at class creation instead of on every call.
    # Strict CoT pattern "The answer is X"; an optional "$" before X is allowed.
    _ANSWER_IS_RE = re.compile(r'The answer is \$?(-?[0-9.,]+)', re.IGNORECASE)
    # GSM8K reference format "... #### X" (also accepted in model responses).
    _HASH_ANSWER_RE = re.compile(r'#### \$?(-?[0-9.,]+)')
    # lm-eval "flexible-extract" pattern; the last valid match is used.
    _FLEXIBLE_NUMBER_RE = re.compile(r'(-?[$0-9.,]{2,})|(-?[0-9]+)')

    def __init__(self):
        super().__init__(name="GSM8K", dataset_name="gsm8k")

    async def load_dataset(self, sample_size: Optional[int] = None, **kwargs):
        """Load the GSM8K ``main`` test split into ``self.dataset``.

        Each entry is a dict with ``question``, ``answer`` (the reference
        solution string) and ``raw_sample`` (the original dataset row). The
        list is shuffled, then truncated to ``sample_size`` entries when
        ``sample_size`` is truthy and smaller than the split.

        Args:
            sample_size: Maximum number of samples to keep; ``None`` or ``0``
                keeps the whole split.
            **kwargs: Optional ``seed``. When given, the shuffle uses a dedicated
                ``random.Random(seed)`` so the selected subset is reproducible.
                Otherwise the global ``random`` state is used.
        """
        import random

        # NOTE: inside this method, ``load_dataset`` resolves to the
        # module-level ``datasets.load_dataset`` import, not to this method
        # (function bodies do not see class scope).
        # NOTE(review): this call is synchronous inside an async method, so it
        # blocks the event loop while the dataset loads.
        dataset = load_dataset(self.dataset_name, 'main', split='test')
        self.dataset = [
            {
                'question': sample['question'],
                'answer': sample['answer'],
                'raw_sample': sample,
            }
            for sample in dataset
        ]

        seed = kwargs.get('seed')
        shuffler = random if seed is None else random.Random(seed)
        shuffler.shuffle(self.dataset)

        if sample_size and len(self.dataset) > sample_size:
            self.dataset = self.dataset[:sample_size]

    @staticmethod
    def _normalize_number(text: str) -> Optional[str]:
        """Return *text* as a canonical number string, or None if it is not one.

        Mirrors lm-eval's GSM8K exact-match clean-up. Dollar signs and thousands
        separators are removed. Trailing periods are stripped; these are
        sentence punctuation swallowed by the greedy ``[0-9.,]`` character
        class, e.g. "72." -> "72". No other rewriting is done, so "18.00"
        still differs from "18", as under lm-eval exact match.
        """
        cleaned = text.replace('$', '').replace(',', '').rstrip('.')
        try:
            # The capture patterns only admit '-', digits, '.', ',' and '$', so
            # float() here rejects exactly the malformed cases ("", "-", "1.2.3").
            float(cleaned)
        except ValueError:
            return None
        return cleaned

    def extract_answer_from_solution(self, solution: str) -> Optional[str]:
        """Extract the reference number from a GSM8K solution string.

        GSM8K solutions end in ``#### <number>``. The number is normalized
        exactly like model answers so the two can be compared as strings.

        Returns:
            The normalized reference answer, or None if no valid
            ``####`` answer is present.
        """
        if not solution:
            return None
        match = self._HASH_ANSWER_RE.search(solution)
        return self._normalize_number(match.group(1)) if match else None

    def extract_number_from_response(self, response: str) -> Optional[str]:
        """Extract the final numerical answer from a model response.

        Strategies are tried in order, and the first one that yields a valid
        number wins:

        1. "The answer is X" (strict CoT pattern). Occurrences are scanned
           front to back, so text the model generates after its answer does
           not override it.
        2. "#### X" (the GSM8K reference format), scanned the same way.
        3. lm-eval style flexible extraction: the last number-like token in
           the response. Tokens that are not numbers (e.g. "..." or "$,") are
           skipped rather than ending the search.

        Returns:
            The normalized answer string, or None if no number was found.
        """
        if not response:
            return None

        for pattern in (self._ANSWER_IS_RE, self._HASH_ANSWER_RE):
            for match in pattern.finditer(response):
                number = self._normalize_number(match.group(1))
                if number is not None:
                    return number

        # findall yields (long_form, short_form) tuples with exactly one
        # non-empty member; walk them from the end to find the last valid one.
        for long_form, short_form in reversed(self._FLEXIBLE_NUMBER_RE.findall(response)):
            number = self._normalize_number(long_form or short_form)
            if number is not None:
                return number
        return None

    def format_prompt(self, sample: Dict[str, Any]) -> str:
        """Format a GSM8K question as a prompt with CoT few-shot examples."""
        # Use the standard CoT prompt from lm-eval
        return get_gsm8k_cot_prompt(sample['question'])

    async def evaluate_sample(self, api, sample: Dict[str, Any], **kwargs) -> Tuple[bool, Dict[str, Any]]:
        """Evaluate a single GSM8K sample.

        Args:
            api: Client exposing an awaitable ``generate_with_retry(prompt, **kwargs)``.
            sample: A dataset entry with ``question`` and ``answer`` keys.
            **kwargs: Forwarded unchanged to ``api.generate_with_retry``.

        Returns:
            ``(is_correct, result)``. On success, ``result`` holds the question,
            both extracted answers, the raw response and the reference solution.
            If generation or extraction raises, ``result`` holds the question
            and the error text, and ``is_correct`` is False.
        """
        prompt = self.format_prompt(sample)
        try:
            response = await api.generate_with_retry(prompt, **kwargs)
            correct_answer = self.extract_answer_from_solution(sample['answer'])
            model_answer = self.extract_number_from_response(response)

            # Both sides are normalized identically, so exact string match is
            # the GSM8K criterion; a missing reference never counts as correct.
            is_correct = correct_answer is not None and correct_answer == model_answer

            result = {
                'question': sample['question'],
                'correct_answer': correct_answer,
                'model_answer': model_answer,
                'model_response': response,
                'is_correct': is_correct,
                'solution': sample['answer']
            }
            return is_correct, result
        except Exception as e:
            # Boundary catch: one failed sample is recorded instead of
            # aborting the whole evaluation run.
            result = {
                'question': sample['question'],
                'error': str(e),
                'is_correct': False
            }
            return False, result