# grok4-gpqa-eval/benchmarks/base_benchmark.py
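"""Shared scaffolding for benchmark implementations.

Defines the BenchmarkResult container and the abstract BaseBenchmark class:
subclasses implement dataset loading, prompt formatting, and per-sample
evaluation, while run_benchmark() drives concurrent evaluation with a
progress bar and periodic progress saving.
"""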
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional, Tuple
import asyncio
import json
import os
import time
from dataclasses import dataclass

from tqdm import tqdm


@dataclass
class BenchmarkResult:
    """Container for benchmark results"""
    benchmark_name: str
    model_name: str
    total_questions: int
    correct: int
    accuracy: float
    avg_response_time: float
    raw_results: List[Dict[str, Any]]


class BaseBenchmark(ABC):
    """Base class for all benchmark implementations"""

    def __init__(self, name: str, dataset_name: Optional[str] = None):
        self.name = name
        self.dataset_name = dataset_name or name
        self.dataset = None
        self.results = []

    @abstractmethod
    async def load_dataset(self, sample_size: Optional[int] = None, **kwargs):
        """Load the benchmark dataset"""
        pass

    @abstractmethod
    async def evaluate_sample(self, api, sample: Dict[str, Any], **kwargs) -> Tuple[bool, Dict[str, Any]]:
        """Evaluate a single sample and return (is_correct, result_dict)"""
        pass

    @abstractmethod
    def format_prompt(self, sample: Dict[str, Any]) -> str:
        """Format the prompt for the model"""
        pass

    async def run_benchmark(self, api, sample_size: Optional[int] = None, **kwargs) -> BenchmarkResult:
        """Run the benchmark on the given API"""
        print(f"Running {self.name} benchmark on {api.model_name}...")

        # Load dataset
        await self.load_dataset(sample_size, **kwargs)
        if not self.dataset:
            raise ValueError(f"No dataset loaded for {self.name}")

        # Prepare samples
        samples = self.dataset if sample_size is None else self.dataset[:sample_size]
        total_samples = len(samples)

        # Run evaluation
        correct_count = 0
        response_times = []
        raw_results = []

        # Use an async semaphore to cap the number of concurrent requests
        concurrent_limit = kwargs.get('concurrent_requests', 5)
        semaphore = asyncio.Semaphore(concurrent_limit)

        async def evaluate_with_semaphore(sample, idx):
            async with semaphore:
                start_time = time.time()
                is_correct, result = await self.evaluate_sample(api, sample, **kwargs)
                end_time = time.time()
                result['response_time'] = end_time - start_time
                result['index'] = idx
                return is_correct, result

        # Create tasks for all samples
        tasks = [evaluate_with_semaphore(sample, idx) for idx, sample in enumerate(samples)]

        # Run with a progress bar; asyncio.as_completed yields results in
        # completion order, so raw_results is re-sorted by index at the end.
        with tqdm(total=total_samples, desc=f"{self.name}") as pbar:
            for coro in asyncio.as_completed(tasks):
                is_correct, result = await coro
                if is_correct:
                    correct_count += 1
                response_times.append(result['response_time'])
                raw_results.append(result)
                pbar.update(1)

                # --- START: REAL-TIME PROGRESS SAVING ---
                # Every 10 samples, save the partial results to a JSON file
                if pbar.n > 0 and pbar.n % 10 == 0:
                    # Ensure the results directory exists
                    results_dir = kwargs.get('output_dir', 'results')
                    os.makedirs(results_dir, exist_ok=True)
                    progress_path = os.path.join(results_dir, f'{self.name}_progress.json')

                    # Sort results by index before saving
                    sorted_progress = sorted(raw_results, key=lambda x: x['index'])
                    try:
                        with open(progress_path, 'w') as f:
                            json.dump(sorted_progress, f, indent=2)
                    except Exception as e:
                        print(f"Error saving progress: {e}")
                # --- END: REAL-TIME PROGRESS SAVING ---

        # Calculate metrics
        accuracy = correct_count / total_samples if total_samples > 0 else 0
        avg_response_time = sum(response_times) / len(response_times) if response_times else 0

        return BenchmarkResult(
            benchmark_name=self.name,
            model_name=api.model_name,
            total_questions=total_samples,
            correct=correct_count,
            accuracy=accuracy,
            avg_response_time=avg_response_time,
            raw_results=sorted(raw_results, key=lambda x: x['index'])
        )
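

# --- Illustrative usage sketch (assumption, not part of the original file) ---
# A minimal subclass showing how the abstract hooks fit together. The one-item
# in-memory dataset and the `api.generate(prompt)` coroutine are hypothetical
# stand-ins for a real dataset loader and model client.
class ToyBenchmark(BaseBenchmark):
    async def load_dataset(self, sample_size: Optional[int] = None, **kwargs):
        # A real benchmark would load `self.dataset_name` from disk or a hub.
        self.dataset = [{"question": "2 + 2 = ?", "answer": "4"}]

    def format_prompt(self, sample: Dict[str, Any]) -> str:
        return f"Answer with a single number: {sample['question']}"

    async def evaluate_sample(self, api, sample: Dict[str, Any], **kwargs) -> Tuple[bool, Dict[str, Any]]:
        # `api.generate` is a hypothetical async call; the real API wrapper may differ.
        response = await api.generate(self.format_prompt(sample))
        is_correct = sample["answer"] in response
        return is_correct, {"question": sample["question"], "response": response}

# Usage (given some `api` object exposing `model_name` and an async `generate`):
#     result = asyncio.run(ToyBenchmark("toy").run_benchmark(api))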