# grok4-gpqa-eval/benchmarks/evaluation_utils.py
# Uploaded by TeddyYao ("Upload 38 files", commit 8474f02, verified)
"""Evaluation utilities matching standard implementations"""
import re
from typing import Optional, Union
import numpy as np
try:
import sympy
from sympy.parsing.latex import parse_latex
SYMPY_AVAILABLE = True
except ImportError:
SYMPY_AVAILABLE = False
def normalize_math_answer(answer: str) -> str:
    """Return a canonical form of a math answer string (lm-eval style).

    The steps, applied in order: keep only the text after the last ``=``,
    trim whitespace and surrounding ``$``, unwrap ``\\text{}`` / ``\\textbf{}``,
    expand ``\\fracab`` to ``\\frac{a}{b}``, drop commas that follow a digit,
    delete a fixed list of unit/filler words, and collapse whitespace.

    Args:
        answer: Raw answer text; a falsy value yields an empty string.

    Returns:
        The normalized answer string.
    """
    if not answer:
        return ""

    # Keep only the right-hand side of the last '=' when one is present.
    cleaned = answer.rpartition('=')[2] if '=' in answer else answer
    cleaned = cleaned.strip().strip('$')

    # Regex rewrites, applied sequentially (order matters).
    rewrites = (
        (r'\\text\{([^}]*)\}', r'\1'),                                # \text{x} -> x
        (r'\\textbf\{([^}]*)\}', r'\1'),                              # \textbf{x} -> x
        (r'\\frac([0-9a-zA-Z])([0-9a-zA-Z])', r'\\frac{\1}{\2}'),     # \fracab -> \frac{a}{b}
        (r'(\d),', r'\1'),                                            # 1,000 -> 1000
    )
    for pattern, replacement in rewrites:
        cleaned = re.sub(pattern, replacement, cleaned)

    # Unit / filler words that are deleted wherever they occur as substrings.
    filler_words = (
        'square', 'units', 'integers', 'dollars', 'mph', 'inches', 'feet',
        'minutes', 'cm', 'gm', 'pounds', 'meters', 'meals', 'edges',
        'students', 'childrentickets', 'multiples', 'hours', 'degrees',
        'ounces', 'bits', 'factorization', 'greenmarbles', 'redmarbles',
        'bluemarbles',
    )
    for filler in filler_words:
        cleaned = cleaned.replace(filler, '')

    # split()/join collapses runs of whitespace and trims both ends.
    return ' '.join(cleaned.split())
def extract_answer_gsm8k(response: str) -> Optional[float]:
    """Extract the final numeric answer from a GSM8K-style response.

    The last number appearing in ``response`` is taken as the answer.
    Numbers written with thousands separators (e.g. ``1,234`` or
    ``1,234,567.5``) are treated as a single value rather than being split
    at the comma.

    Args:
        response: Model output text. Empty or ``None`` yields ``None``.

    Returns:
        The last number in the text as a float, or ``None`` if no number
        is found.
    """
    if not response:
        return None

    # First alternative: properly grouped thousands ("1,234", "12,345.6").
    # The (?!\d) guard rejects malformed groupings such as "1,2345" so they
    # fall through to the plain-number alternative.
    # Second alternative: plain integers / decimals ("42", "3.5", ".5").
    number_pattern = (
        r'[-+]?(?:\d{1,3}(?:,\d{3})+(?!\d)(?:\.\d+)?|\d*\.?\d+)'
    )
    numbers = re.findall(number_pattern, response)
    if not numbers:
        return None

    try:
        return float(numbers[-1].replace(',', ''))
    except ValueError:
        # The pattern should only yield parseable text; stay defensive.
        return None
def extract_answer_mmlu(response: str) -> Optional[str]:
    """Extract a multiple-choice letter (A-D) from an MMLU-style response.

    Heuristics are tried in order, returning on the first hit:

    1. The whole response is a single letter ``A``-``D``.
    2. The response starts with a letter followed by ``)``, ``.`` or
       whitespace (e.g. ``"B) ..."``).
    3. An ``"answer is X"`` phrase (case-insensitive, optional colon and
       opening parenthesis), where ``X`` is a standalone letter.
    4. The first standalone capital ``A``-``D`` anywhere in the text.
    5. As a last resort, the first capital ``A``-``D`` character at all.

    Args:
        response: Model output text. Empty or ``None`` yields ``None``.

    Returns:
        The upper-case answer letter, or ``None`` if nothing matched.
    """
    if not response:
        return None

    response = response.strip()

    if len(response) == 1 and response in 'ABCD':
        return response

    match = re.search(r'^([ABCD])[).\s]', response)
    if match:
        return match.group(1)

    # The trailing \b matters: with IGNORECASE, "answer is definitely B"
    # would otherwise match the "d" of "definitely".
    match = re.search(
        r'answer is\s*:?\s*\(?([ABCD])\b', response, re.IGNORECASE
    )
    if match:
        return match.group(1).upper()

    # Prefer a letter standing on its own over one embedded in a word
    # (e.g. the "D" in "Definitely").
    match = re.search(r'\b([ABCD])\b', response)
    if match:
        return match.group(1)

    # Last resort, kept for backward compatibility: any capital A-D.
    match = re.search(r'[ABCD]', response)
    if match:
        return match.group(0)

    return None
def calculate_accuracy_with_confidence(results: list) -> dict:
    """Compute accuracy and a 95% Wilson score interval over ``results``.

    Args:
        results: List of dict-like records; a record counts as correct when
            its ``'is_correct'`` entry is truthy (missing means incorrect).

    Returns:
        A dict with keys ``'accuracy'``, ``'correct'``, ``'total'`` and
        ``'confidence_interval'`` (a ``(lower, upper)`` tuple clamped to
        the range 0..1). An empty input yields all-zero values.
    """
    num_correct = sum(1 for record in results if record.get('is_correct', False))
    num_total = len(results)

    # Guard clause: nothing to score, and avoids dividing by zero below.
    if num_total == 0:
        return {
            'accuracy': 0.0,
            'correct': 0,
            'total': 0,
            'confidence_interval': (0.0, 0.0)
        }

    p = num_correct / num_total
    n = num_total

    # Wilson score interval for a binomial proportion, z = 1.96 (95%).
    z = 1.96
    z_sq = z**2
    scale = 1 + z_sq / n
    center = (p + z_sq / (2*n)) / scale
    half_width = z * np.sqrt(p * (1-p) / n + z_sq / (4*n**2)) / scale

    interval = (max(0, center - half_width), min(1, center + half_width))

    return {
        'accuracy': p,
        'correct': num_correct,
        'total': num_total,
        'confidence_interval': interval
    }
def is_math_equiv(pred: str, gold: str) -> bool:
    """Check whether two math answers are equivalent.

    Both answers are first passed through ``normalize_math_answer``. If the
    normalized strings are equal the answers match. Otherwise, when SymPy
    is available, both are parsed (as LaTeX first, then as plain SymPy
    expressions) and considered equal when their difference simplifies to
    zero.

    Args:
        pred: Predicted answer text.
        gold: Reference answer text.

    Returns:
        ``True`` if the answers are judged equivalent, ``False`` otherwise
        (including when parsing or simplification fails).
    """
    pred_norm = normalize_math_answer(pred)
    gold_norm = normalize_math_answer(gold)

    # Quick string comparison.
    if pred_norm == gold_norm:
        return True

    # From here on the normalized strings are known to differ, so without a
    # symbolic check the answer is simply "not equivalent".
    if not SYMPY_AVAILABLE:
        return False

    try:
        try:
            pred_expr = parse_latex(pred_norm)
            gold_expr = parse_latex(gold_norm)
        except Exception:
            # LaTeX parsing failed (or its backend is missing); try plain
            # SymPy syntax for both sides.
            # NOTE(security): sympy.sympify evaluates its input with eval();
            # model output is untrusted, so this should only be run in a
            # sandboxed evaluation environment.
            pred_expr = sympy.sympify(pred_norm)
            gold_expr = sympy.sympify(gold_norm)

        diff = sympy.simplify(pred_expr - gold_expr)
        # ``is_zero`` may be None when SymPy cannot decide; coerce so the
        # function always returns a real bool.
        return bool(diff == 0 or diff.is_zero)
    except Exception:
        # Parsing or simplification failed; the strings already differ.
        return False
def _gsm8k_to_float(value) -> float:
    """Convert a GSM8K answer (string or number) to float, ignoring commas."""
    if isinstance(value, str):
        # Thousands separators ("1,000") would make float() fail.
        value = value.replace(',', '').strip()
    return float(value)


def is_gsm8k_correct(pred: str, gold: str) -> bool:
    """Check whether a GSM8K prediction matches the reference answer.

    The values match when they are equal as given, or when both can be
    read as numbers that differ by less than ``1e-9``. Commas used as
    thousands separators are ignored, so ``"1,000"`` matches ``"1000"``.

    Args:
        pred: Predicted answer (string, or an already-parsed number).
        gold: Reference answer (string, or an already-parsed number).

    Returns:
        ``True`` on a match, ``False`` otherwise, including when either
        value cannot be interpreted as a number.
    """
    if pred == gold:
        return True

    try:
        pred_num = _gsm8k_to_float(pred)
        gold_num = _gsm8k_to_float(gold)
    except (TypeError, ValueError, OverflowError):
        # Non-numeric or missing value: cannot be a numeric match.
        return False

    # GSM8K uses exact match, but tiny floating point errors are tolerated.
    return abs(pred_num - gold_num) < 1e-9