Spaces:
Running
Running
File size: 4,934 Bytes
8474f02 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 |
"""Evaluation utilities matching standard implementations"""
import re
from typing import Optional, Union
import numpy as np
try:
import sympy
from sympy.parsing.latex import parse_latex
SYMPY_AVAILABLE = True
except ImportError:
SYMPY_AVAILABLE = False
def normalize_math_answer(answer: str) -> str:
    """Reduce a math answer string to a canonical form for comparison.

    The steps, in order: keep only the text after the last '=', trim
    surrounding whitespace and then surrounding dollar signs, unwrap LaTeX
    text/textbf groups, brace the arguments of a two-character frac
    shorthand, drop commas that follow a digit, delete a fixed list of unit
    words, and collapse whitespace runs to single spaces.

    Args:
        answer: Raw answer text. Any falsy value yields "".

    Returns:
        The normalized answer string.
    """
    if not answer:
        return ""

    # Text after the last '=' (the whole string when there is none),
    # trimmed of outer whitespace first and outer '$' second.
    cleaned = answer.rsplit('=', 1)[-1].strip().strip('$')

    # Regex passes are order-sensitive: unwrap \text{...}, then \textbf{...},
    # then expand \fracab -> \frac{a}{b}, then drop commas after digits.
    regex_passes = (
        (r'\\text\{([^}]*)\}', r'\1'),
        (r'\\textbf\{([^}]*)\}', r'\1'),
        (r'\\frac([0-9a-zA-Z])([0-9a-zA-Z])', r'\\frac{\1}{\2}'),
        (r'(\d),', r'\1'),
    )
    for pattern, replacement in regex_passes:
        cleaned = re.sub(pattern, replacement, cleaned)

    # Unit and filler words are removed as plain substrings, in this order.
    unit_words = (
        'square', 'units', 'integers', 'dollars', 'mph', 'inches', 'feet',
        'minutes', 'cm', 'gm', 'pounds', 'meters', 'meals', 'edges',
        'students', 'childrentickets', 'multiples', 'hours', 'degrees',
        'ounces', 'bits', 'factorization', 'greenmarbles', 'redmarbles',
        'bluemarbles',
    )
    for unit in unit_words:
        cleaned = cleaned.replace(unit, '')

    # split() with no argument discards leading/trailing whitespace, so the
    # joined result is already stripped.
    return ' '.join(cleaned.split())
def extract_answer_gsm8k(response: str) -> Optional[float]:
    """Return the last number that appears in a GSM8K-style response.

    Thousands separators are removed before scanning, so "1,234" is read as
    1234 rather than as the trailing fragment 234. A comma is only treated
    as a separator when it sits between a digit and a group of exactly three
    digits, so enumerations such as "3, 4" or "3,4" are left alone.

    Args:
        response: Model output text. ``None`` or an empty string yields
            ``None``.

    Returns:
        The last number found, as a float, or ``None`` when the text
        contains no number.
    """
    if not response:
        return None

    # Remove thousands separators: "1,234,567.5" -> "1234567.5".
    without_separators = re.sub(r'(?<=\d),(?=\d{3}(?!\d))', '', response)

    # Optional sign, optional integer part, optional decimal point, digits.
    numbers = re.findall(r'[-+]?\d*\.?\d+', without_separators)
    if not numbers:
        return None

    try:
        return float(numbers[-1])
    except ValueError:
        # Not expected for text matched by the pattern above; kept so the
        # function stays best-effort rather than raising.
        return None
def extract_answer_mmlu(response: str) -> Optional[str]:
    """Extract a multiple-choice letter (A-D) from a model response.

    Patterns are tried in order and the first hit wins:

    1. The whole (stripped) response is a single letter A-D.
    2. The response starts with a letter followed by ')', '.' or whitespace.
    3. The phrase "answer is X" (case-insensitive), where X is a standalone
       letter.
    4. The first standalone uppercase A-D anywhere in the text.

    Steps 3 and 4 require the letter to stand alone as a word. Without that,
    "Answer: C" would yield 'A' (from "Answer") and "the answer is blue"
    would yield 'B'.

    Args:
        response: Model output text. ``None`` or an empty string yields
            ``None``.

    Returns:
        One of 'A', 'B', 'C', 'D', or ``None`` when no choice is found.
    """
    if not response:
        return None

    response = response.strip()

    # Exact single-letter answer.
    if len(response) == 1 and response in 'ABCD':
        return response

    # Leading letter such as "B)", "B." or "B because ...".
    match = re.search(r'^([ABCD])[).\s]', response)
    if match:
        return match.group(1)

    # "answer is X"; the trailing \b rejects words that merely start with a-d.
    match = re.search(r'answer is ([ABCD])\b', response, re.IGNORECASE)
    if match:
        return match.group(1).upper()

    # Last resort: first uppercase A-D that is not part of a longer word.
    match = re.search(r'\b[ABCD]\b', response)
    if match:
        return match.group(0)

    return None
def calculate_accuracy_with_confidence(results: list, z: float = 1.96) -> dict:
    """Compute accuracy and a Wilson score confidence interval.

    Args:
        results: Sequence of dict-like records. A record counts as correct
            when ``record.get('is_correct', False)`` is truthy.
        z: Standard-normal quantile that sets the confidence level. The
            default 1.96 corresponds to a 95% interval. Must be >= 0.

    Returns:
        A dict with keys ``'accuracy'`` (float), ``'correct'`` (int),
        ``'total'`` (int) and ``'confidence_interval'`` (a ``(lower, upper)``
        tuple of plain floats clamped to [0, 1]). For empty ``results`` every
        value is zero.

    Raises:
        ValueError: If ``z`` is negative.
    """
    if z < 0:
        raise ValueError("z must be non-negative")

    total = len(results)
    if total == 0:
        return {
            'accuracy': 0.0,
            'correct': 0,
            'total': 0,
            'confidence_interval': (0.0, 0.0)
        }

    correct = sum(1 for r in results if r.get('is_correct', False))
    accuracy = correct / total

    # Wilson score interval for a binomial proportion.
    z_sq = z * z
    denominator = 1 + z_sq / total
    center = (accuracy + z_sq / (2 * total)) / denominator
    margin = z * np.sqrt(
        accuracy * (1 - accuracy) / total + z_sq / (4 * total ** 2)
    ) / denominator

    # Convert to built-in floats and clamp; rounding can push the bounds a
    # hair outside [0, 1] when accuracy is exactly 0 or 1.
    lower = max(0.0, float(center - margin))
    upper = min(1.0, float(center + margin))

    return {
        'accuracy': accuracy,
        'correct': correct,
        'total': total,
        'confidence_interval': (lower, upper)
    }
def is_math_equiv(pred: str, gold: str) -> bool:
    """Decide whether two math answers are equivalent.

    Both inputs are passed through ``normalize_math_answer``. Identical
    normalized strings are equivalent. Otherwise, when SymPy is available,
    both are parsed (as LaTeX first, then as plain SymPy expressions) and
    are equivalent when their difference simplifies to zero.

    Args:
        pred: Predicted answer text.
        gold: Reference answer text.

    Returns:
        ``True`` when the answers are judged equivalent, ``False`` otherwise,
        including when SymPy is unavailable or either answer cannot be
        parsed or compared. Always a real ``bool``.
    """
    pred_norm = normalize_math_answer(pred)
    gold_norm = normalize_math_answer(gold)

    # Cheap path: identical after normalization.
    if pred_norm == gold_norm:
        return True

    # Past this point the normalized strings are known to differ, so every
    # string-comparison fallback reduces to False.
    if not SYMPY_AVAILABLE:
        return False

    try:
        try:
            pred_expr = parse_latex(pred_norm)
            gold_expr = parse_latex(gold_norm)
        except Exception:
            # LaTeX parsing failed (or its backend is missing); retry both
            # sides as plain SymPy expressions.
            # NOTE(security): sympify evaluates its input with eval(). The
            # strings here may be untrusted model output - consider a
            # restricted parser before exposing this to hostile input.
            pred_expr = sympy.sympify(pred_norm)
            gold_expr = sympy.sympify(gold_norm)

        diff = sympy.simplify(pred_expr - gold_expr)
        # is_zero is three-valued (True / False / None); coerce so callers
        # never receive None.
        return bool(diff == 0 or diff.is_zero)
    except Exception:
        # Unparseable or non-comparable expressions count as not equivalent.
        return False
def is_gsm8k_correct(pred: Union[str, float, None], gold: Union[str, float, None]) -> bool:
    """Check whether a GSM8K prediction matches the reference answer.

    The values match when they compare equal as given, or when both convert
    to floats that differ by less than 1e-9. String inputs have commas and
    surrounding whitespace removed before conversion, so "1,234" matches
    "1234".

    Args:
        pred: Predicted answer, as text or as an already-parsed number.
        gold: Reference answer, as text or as a number.

    Returns:
        ``True`` on a match. ``False`` when the values differ or when either
        one cannot be converted to a number (for example ``None`` or
        non-numeric text).
    """
    if pred == gold:
        return True

    def _as_number(value):
        # Strip thousands separators from text before float conversion.
        if isinstance(value, str):
            value = value.replace(',', '').strip()
        return float(value)

    try:
        # Exact match in spirit; the tolerance only absorbs float round-off.
        return abs(_as_number(pred) - _as_number(gold)) < 1e-9
    except (TypeError, ValueError, OverflowError):
        return False