Spaces:
Running
Running
"""Evaluation utilities matching standard implementations""" | |
import re | |
from typing import Optional, Union | |
import numpy as np | |
try: | |
import sympy | |
from sympy.parsing.latex import parse_latex | |
SYMPY_AVAILABLE = True | |
except ImportError: | |
SYMPY_AVAILABLE = False | |
def normalize_math_answer(answer: str) -> str:
    """Normalize a mathematical answer string so two answers can be compared.

    The steps are applied in this order:

    1. Keep only the text after the last ``=`` sign.
    2. Strip surrounding whitespace, then surrounding ``$`` characters.
    3. Unwrap ``\\text{...}`` and ``\\textbf{...}``.
    4. Rewrite the shorthand ``\\fracab`` as ``\\frac{a}{b}``.
    5. Remove thousands-separator commas (``1,234`` -> ``1234``).
    6. Delete a fixed list of unit / filler words.
    7. Collapse runs of whitespace to a single space.

    Args:
        answer: Raw answer text. A falsy value (``""`` or ``None``) is accepted.

    Returns:
        The normalized string, or ``""`` for falsy input.
    """
    if not answer:
        return ""

    # Keep only the right-hand side of an equation such as "x = 5".
    if '=' in answer:
        answer = answer.split('=')[-1]

    # Whitespace first, so that "$" delimiters are at the very edges.
    answer = answer.strip().strip('$')

    # Unwrap \text{...} and \textbf{...}, keeping their contents.
    answer = re.sub(r'\\text\{([^}]*)\}', r'\1', answer)
    answer = re.sub(r'\\textbf\{([^}]*)\}', r'\1', answer)

    # Fix \fracab -> \frac{a}{b}
    answer = re.sub(r'\\frac([0-9a-zA-Z])([0-9a-zA-Z])', r'\\frac{\1}{\2}', answer)

    # Remove only commas that act as thousands separators (a digit before,
    # exactly three digits after). Deleting every comma that follows a digit
    # would merge list/tuple answers such as "(1,2)" into "(12)".
    answer = re.sub(r'(?<=\d),(?=\d{3}(?!\d))', '', answer)

    # Unit / filler words are removed as plain substrings, in this order.
    removable_words = (
        'square', 'units', 'integers', 'dollars', 'mph', 'inches', 'feet',
        'minutes', 'cm', 'gm', 'pounds', 'meters', 'meals', 'edges',
        'students', 'childrentickets', 'multiples', 'hours', 'degrees',
        'ounces', 'bits', 'factorization', 'greenmarbles', 'redmarbles',
        'bluemarbles',
    )
    for word in removable_words:
        answer = answer.replace(word, '')

    # split()/join collapses internal whitespace and trims both ends.
    return ' '.join(answer.split())
def extract_answer_gsm8k(response: str) -> Optional[float]:
    """Extract the final numeric answer from a GSM8K-style response.

    The last number that appears in ``response`` is taken as the answer.
    Numbers written with thousands separators (``1,234`` or ``12,345.67``)
    are recognised as a single value, not split at the comma.

    Args:
        response: Model output text. A falsy value is accepted.

    Returns:
        The last number in the text as a float, or ``None`` when the text is
        empty or contains no number.
    """
    if not response:
        return None

    # First alternative: comma-grouped number (1-3 digits, then one or more
    # ",ddd" groups, optional decimals). Second alternative: plain number.
    numbers = re.findall(
        r'[-+]?(?:\d{1,3}(?:,\d{3})+(?!\d)(?:\.\d+)?|\d*\.?\d+)',
        response,
    )
    if not numbers:
        return None

    try:
        return float(numbers[-1].replace(',', ''))
    except ValueError:
        # Not expected for text matched by the pattern; kept as a safety net.
        return None
def extract_answer_mmlu(response: str) -> Optional[str]:
    """Extract a multiple-choice letter (A-D) from an MMLU-style response.

    Strategies are tried in order and the first hit wins:

    1. The whole (stripped) response is a single letter ``A``-``D``.
    2. The response starts with a letter followed by ``)``, ``.`` or
       whitespace, e.g. ``"C) Paris"``.
    3. An ``"answer is X"`` phrase (case-insensitive, optional ``:`` and
       ``(``), where ``X`` is a standalone letter.
    4. The first standalone capital ``A``-``D`` in the text.
    5. The first capital ``A``-``D`` anywhere in the text.

    Args:
        response: Model output text. A falsy value is accepted.

    Returns:
        The upper-case answer letter, or ``None`` if none was found.
    """
    if not response:
        return None

    response = response.strip()

    # Single-letter answer.
    if len(response) == 1 and response in 'ABCD':
        return response

    # Leading letter followed by ")", "." or whitespace.
    match = re.search(r'^([ABCD])[).\s]', response)
    if match:
        return match.group(1)

    # "answer is X": the trailing \b stops the case-insensitive match from
    # capturing the first letter of an ordinary word ("answer is clearly B").
    match = re.search(r'answer is\s*:?\s*\(?([ABCD])\b', response, re.IGNORECASE)
    if match:
        return match.group(1).upper()

    # Prefer a standalone capital letter; only then fall back to the first
    # capital A-D found inside any word.
    match = re.search(r'\b[ABCD]\b', response) or re.search(r'[ABCD]', response)
    if match:
        return match.group(0)

    return None
def calculate_accuracy_with_confidence(results: list, z: float = 1.96) -> dict:
    """Compute accuracy and a Wilson score confidence interval.

    Args:
        results: List of dict-like records. A record counts as correct when
            ``record.get('is_correct', False)`` is truthy.
        z: Normal quantile used for the interval. The default ``1.96``
            corresponds to a 95% interval.

    Returns:
        A dict with keys ``'accuracy'`` (float), ``'correct'`` (int),
        ``'total'`` (int) and ``'confidence_interval'`` (a ``(lower, upper)``
        tuple of plain Python floats clamped to ``[0.0, 1.0]``). For an empty
        ``results`` list every value is zero.
    """
    total = len(results)
    if total == 0:
        return {
            'accuracy': 0.0,
            'correct': 0,
            'total': 0,
            'confidence_interval': (0.0, 0.0)
        }

    correct = sum(1 for r in results if r.get('is_correct', False))
    accuracy = correct / total

    # Wilson score interval for a binomial proportion.
    z_squared = z ** 2
    denominator = 1 + z_squared / total
    center = (accuracy + z_squared / (2 * total)) / denominator
    # float(...) keeps the result a builtin float rather than np.float64.
    spread = float(
        np.sqrt(accuracy * (1 - accuracy) / total + z_squared / (4 * total ** 2))
    )
    margin = z * spread / denominator

    # Float literals keep the clamped bounds floats (max(0, x) could yield int 0).
    lower = max(0.0, center - margin)
    upper = min(1.0, center + margin)

    return {
        'accuracy': accuracy,
        'correct': correct,
        'total': total,
        'confidence_interval': (lower, upper)
    }
def is_math_equiv(pred: str, gold: str) -> bool:
    """Check whether two math answers are equivalent.

    Both answers are first passed through ``normalize_math_answer``. If the
    normalized strings are identical the answers are equivalent. Otherwise,
    when SymPy is available, both are parsed (LaTeX first, then plain SymPy
    syntax) and their difference is simplified; the answers are equivalent
    only if that difference is provably zero.

    Args:
        pred: Predicted answer text.
        gold: Reference answer text.

    Returns:
        ``True`` if the answers match, ``False`` otherwise. Parse failures
        and undecidable comparisons are ``False``; the result is always a
        real ``bool``.
    """
    # First normalize both answers
    pred_norm = normalize_math_answer(pred)
    gold_norm = normalize_math_answer(gold)

    # Quick string comparison
    if pred_norm == gold_norm:
        return True

    # The strings differ here, so without SymPy there is nothing left to try.
    if not SYMPY_AVAILABLE:
        return False

    # Exactly one side is empty: it cannot be parsed into an expression.
    if not pred_norm or not gold_norm:
        return False

    try:
        try:
            # Try to parse as LaTeX
            pred_expr = parse_latex(pred_norm)
            gold_expr = parse_latex(gold_norm)
        except Exception:
            # NOTE(security): sympy.sympify evaluates its input with eval().
            # pred comes from model output; review before exposing this to
            # untrusted text.
            pred_expr = sympy.sympify(pred_norm)
            gold_expr = sympy.sympify(gold_norm)

        # Check if expressions are equivalent
        diff = sympy.simplify(pred_expr - gold_expr)
        # is_zero is tri-state (True/False/None); "is True" keeps the return
        # value a bool instead of leaking None.
        return bool(diff == 0) or diff.is_zero is True
    except Exception:
        # Parsing or simplification failed. The strings are already known to
        # differ, so the answers are treated as not equivalent.
        return False
def is_gsm8k_correct(
    pred: Optional[Union[str, float]],
    gold: Optional[Union[str, float]],
) -> bool:
    """Check whether a GSM8K prediction matches the reference answer.

    The values match if they compare equal as given, or if both convert to
    numbers that differ by less than ``1e-9``. String inputs may contain
    thousands-separator commas (``"1,234"``); numeric inputs are used as-is.

    Args:
        pred: Predicted answer, as text or as an already-parsed number.
        gold: Reference answer, as text or as a number.

    Returns:
        ``True`` on a match, ``False`` otherwise, including when either value
        cannot be interpreted as a number.
    """
    if pred == gold:
        return True

    def _to_float(value) -> float:
        # Drop thousands-separator commas from text; float() rejects them.
        if isinstance(value, str):
            value = re.sub(r'(?<=\d),(?=\d{3}(?!\d))', '', value.strip())
        return float(value)

    try:
        pred_num = _to_float(pred)
        gold_num = _to_float(gold)
    except (TypeError, ValueError):
        # None or non-numeric text: no numeric comparison is possible.
        return False

    # GSM8K uses exact match, but we allow tiny floating point errors
    return abs(pred_num - gold_num) < 1e-9