Spaces:
Running
Running
import random
import re
from typing import Dict, Any, Optional, Tuple

from datasets import load_dataset

from .base_benchmark import BaseBenchmark
from .prompt_templates import get_gsm8k_cot_prompt
class GSM8KBenchmark(BaseBenchmark):
    """GSM8K (Grade School Math 8K) benchmark.

    Prompts the model with the standard chain-of-thought few-shot prompt and
    scores exact-match accuracy on the extracted numeric answer, mirroring
    the lm-evaluation-harness protocol (strict-match, then flexible-extract).
    """

    def __init__(self):
        super().__init__(name="GSM8K", dataset_name="gsm8k")

    async def load_dataset(self, sample_size: Optional[int] = None, **kwargs):
        """Load the GSM8K test split, shuffle it, and optionally truncate.

        Args:
            sample_size: If truthy, keep at most this many samples
                (selected after shuffling).
        """
        dataset = load_dataset(self.dataset_name, 'main', split='test')
        self.dataset = [
            {
                'question': sample['question'],
                'answer': sample['answer'],
                'raw_sample': sample,
            }
            for sample in dataset
        ]
        # Shuffle so a truncated run evaluates a random subset rather than
        # the first N rows.  NOTE(review): unseeded, so the subset is not
        # reproducible across runs — seed here if reproducibility matters.
        random.shuffle(self.dataset)
        if sample_size and len(self.dataset) > sample_size:
            self.dataset = self.dataset[:sample_size]

    def extract_answer_from_solution(self, solution: str) -> Optional[str]:
        """Extract the gold numeric answer from a GSM8K solution string.

        GSM8K gold solutions end with "#### <number>".

        Returns:
            The number with thousands separators removed, or None if the
            "####" marker is absent.
        """
        match = re.search(r'#### ([\-0-9\.\,]+)', solution)
        if match:
            return match.group(1).replace(',', '')
        return None

    def extract_number_from_response(self, response: str) -> Optional[str]:
        """Extract the model's final numeric answer from its response text.

        Tries, in order (following lm-eval's filter pipeline):
          1. the CoT convention "The answer is X";
          2. the GSM8K "#### X" marker, if the model mimics the format;
          3. flexible extraction: the last number-like token in the text.

        Returns:
            The normalized number string ($ and , removed, trailing periods
            stripped), or None if nothing number-like is found.
        """
        # 1. "The answer is X" (CoT standard).  The greedy character class
        # includes '.', so a sentence-ending period would be captured
        # ("The answer is 42." -> "42.") and break the exact-string match
        # against the gold "42" — hence the rstrip('.') normalization.
        match = re.search(r'The answer is ([\-0-9\.\,]+)\.?', response, re.IGNORECASE)
        if match:
            return match.group(1).replace(',', '').rstrip('.')
        # 2. "#### X" format.
        match = re.search(r'#### ([\-0-9\.\,]+)', response)
        if match:
            return match.group(1).replace(',', '').rstrip('.')
        # 3. Flexible extraction: find all numbers and take the last one.
        # This matches lm-eval's flexible-extract with group_select: -1.
        numbers = re.findall(r'(-?[$0-9.,]{2,})|(-?[0-9]+)', response)
        if numbers:
            # Each findall hit is a 2-tuple (one group per alternative);
            # flatten and keep the non-empty captures.
            flat_numbers = [n for group in numbers for n in group if n]
            if flat_numbers:
                cleaned = flat_numbers[-1].replace('$', '').replace(',', '').rstrip('.')
                try:
                    float(cleaned)  # validate it parses as a real number
                except ValueError:
                    pass  # last token wasn't numeric after cleaning; give up
                else:
                    return cleaned
        return None

    def format_prompt(self, sample: Dict[str, Any]) -> str:
        """Wrap the question in the standard lm-eval CoT few-shot prompt."""
        return get_gsm8k_cot_prompt(sample['question'])

    async def evaluate_sample(self, api, sample: Dict[str, Any], **kwargs) -> Tuple[bool, Dict[str, Any]]:
        """Evaluate one sample: generate, extract both answers, exact-match.

        Args:
            api: Client exposing ``generate_with_retry(prompt, **kwargs)``.
            sample: One entry of ``self.dataset``.

        Returns:
            ``(is_correct, result_dict)``.  Generation errors are caught and
            reported as an incorrect result instead of aborting the run.
        """
        prompt = self.format_prompt(sample)
        try:
            response = await api.generate_with_retry(prompt, **kwargs)
        except Exception as e:
            # Best-effort: a failed generation counts as incorrect but must
            # not take down the whole benchmark run.
            return False, {
                'question': sample['question'],
                'error': str(e),
                'is_correct': False,
            }
        correct_answer = self.extract_answer_from_solution(sample['answer'])
        model_answer = self.extract_number_from_response(response)
        # GSM8K scores exact string match on the normalized number strings.
        is_correct = (
            correct_answer is not None
            and model_answer is not None
            and correct_answer == model_answer
        )
        return is_correct, {
            'question': sample['question'],
            'correct_answer': correct_answer,
            'model_answer': model_answer,
            'model_response': response,
            'is_correct': is_correct,
            'solution': sample['answer'],
        }