"""Evaluation utilities matching standard implementations""" import re from typing import Optional, Union import numpy as np try: import sympy from sympy.parsing.latex import parse_latex SYMPY_AVAILABLE = True except ImportError: SYMPY_AVAILABLE = False def normalize_math_answer(answer: str) -> str: """Normalize mathematical answers following lm-eval's approach""" if not answer: return "" # Extract content after equals sign if '=' in answer: answer = answer.split('=')[-1] # Remove dollar signs and spaces answer = answer.strip() answer = answer.strip('$') # Remove text{} and textbf{} answer = re.sub(r'\\text\{([^}]*)\}', r'\1', answer) answer = re.sub(r'\\textbf\{([^}]*)\}', r'\1', answer) # Fix \fracab -> \frac{a}{b} answer = re.sub(r'\\frac([0-9a-zA-Z])([0-9a-zA-Z])', r'\\frac{\1}{\2}', answer) # Remove commas from numbers answer = re.sub(r'(\d),', r'\1', answer) # Remove specific words for word in ['square', 'units', 'integers', 'dollars', 'mph', 'inches', 'feet', 'minutes', 'cm', 'gm', 'pounds', 'meters', 'meals', 'edges', 'students', 'childrentickets', 'multiples', 'hours', 'degrees', 'ounces', 'bits', 'factorization', 'greenmarbles', 'redmarbles', 'bluemarbles']: answer = answer.replace(word, '') # Remove extra spaces answer = ' '.join(answer.split()) return answer.strip() def extract_answer_gsm8k(response: str) -> Optional[float]: """Extract answer from GSM8K response following official format""" # Look for the last number in the response numbers = re.findall(r'[-+]?\d*\.?\d+', response) if numbers: try: return float(numbers[-1]) except: pass return None def extract_answer_mmlu(response: str) -> Optional[str]: """Extract MMLU answer following official format""" # Clean response response = response.strip() # Look for single letter answer if len(response) == 1 and response in 'ABCD': return response # Look for letter followed by parenthesis or period match = re.search(r'^([ABCD])[).\s]', response) if match: return match.group(1) # Look for "answer is X" pattern match = re.search(r'answer is ([ABCD])', response, re.IGNORECASE) if match: return match.group(1).upper() # Look for first occurrence of A, B, C, or D match = re.search(r'[ABCD]', response) if match: return match.group(0) return None def calculate_accuracy_with_confidence(results: list) -> dict: """Calculate accuracy with confidence intervals""" correct = sum(1 for r in results if r.get('is_correct', False)) total = len(results) if total == 0: return { 'accuracy': 0.0, 'correct': 0, 'total': 0, 'confidence_interval': (0.0, 0.0) } accuracy = correct / total # Wilson score interval for binomial proportion z = 1.96 # 95% confidence n = total p = accuracy denominator = 1 + z**2 / n center = (p + z**2 / (2*n)) / denominator margin = z * np.sqrt(p * (1-p) / n + z**2 / (4*n**2)) / denominator lower = max(0, center - margin) upper = min(1, center + margin) return { 'accuracy': accuracy, 'correct': correct, 'total': total, 'confidence_interval': (lower, upper) } def is_math_equiv(pred: str, gold: str) -> bool: """Check mathematical equivalence using SymPy (matching lm-eval)""" # First normalize both answers pred_norm = normalize_math_answer(pred) gold_norm = normalize_math_answer(gold) # Quick string comparison if pred_norm == gold_norm: return True if not SYMPY_AVAILABLE: # Fallback to string comparison return pred_norm == gold_norm try: # Try to parse as LaTeX try: pred_expr = parse_latex(pred_norm) gold_expr = parse_latex(gold_norm) except: # Try parsing as regular SymPy expression pred_expr = sympy.sympify(pred_norm) gold_expr = sympy.sympify(gold_norm) # Check if expressions are equivalent diff = sympy.simplify(pred_expr - gold_expr) return diff == 0 or diff.is_zero except Exception: # If parsing fails, fall back to string comparison return pred_norm == gold_norm def is_gsm8k_correct(pred: str, gold: str) -> bool: """Check GSM8K answer correctness""" if pred == gold: return True try: # Try numeric comparison pred_num = float(pred) gold_num = float(gold) # GSM8K uses exact match, but we allow tiny floating point errors return abs(pred_num - gold_num) < 1e-9 except: return False