from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional, Tuple
from dataclasses import dataclass
import asyncio
import json
import os
import time

from tqdm import tqdm

@dataclass
class BenchmarkResult:
    """Container for benchmark results"""
    benchmark_name: str
    model_name: str
    total_questions: int
    correct: int
    accuracy: float
    avg_response_time: float
    raw_results: List[Dict[str, Any]]
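
# Illustrative helper (a sketch, not part of the original module): one way to turn
# a BenchmarkResult into a one-line log message. The name `summarize_result` is
# hypothetical; the field names match the dataclass above.
def summarize_result(result: BenchmarkResult) -> str:
    """Return a compact, human-readable summary of a benchmark run."""
    return (
        f"{result.benchmark_name} on {result.model_name}: "
        f"{result.correct}/{result.total_questions} correct "
        f"({result.accuracy:.1%}), avg {result.avg_response_time:.2f}s per response"
    )
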
class BaseBenchmark(ABC):
    """Base class for all benchmark implementations"""

    def __init__(self, name: str, dataset_name: Optional[str] = None):
        self.name = name
        self.dataset_name = dataset_name or name
        self.dataset = None
        self.results = []

    @abstractmethod
    async def load_dataset(self, sample_size: Optional[int] = None, **kwargs):
        """Load the benchmark dataset"""

    @abstractmethod
    async def evaluate_sample(self, api, sample: Dict[str, Any], **kwargs) -> Tuple[bool, Dict[str, Any]]:
        """Evaluate a single sample and return (is_correct, result_dict)"""

    @abstractmethod
    def format_prompt(self, sample: Dict[str, Any]) -> str:
        """Format the prompt for the model"""
    async def run_benchmark(self, api, sample_size: Optional[int] = None, **kwargs) -> BenchmarkResult:
        """Run the benchmark on the given API and return aggregate metrics"""
        print(f"Running {self.name} benchmark on {api.model_name}...")

        # Load dataset
        await self.load_dataset(sample_size, **kwargs)
        if not self.dataset:
            raise ValueError(f"No dataset loaded for {self.name}")

        # Prepare samples (slice defensively in case load_dataset ignored sample_size)
        samples = self.dataset if sample_size is None else self.dataset[:sample_size]
        total_samples = len(samples)

        # Run evaluation
        correct_count = 0
        response_times = []
        raw_results = []

        # Use an async semaphore to cap the number of concurrent requests
        concurrent_limit = kwargs.get('concurrent_requests', 5)
        semaphore = asyncio.Semaphore(concurrent_limit)

        async def evaluate_with_semaphore(sample, idx):
            async with semaphore:
                start_time = time.time()
                is_correct, result = await self.evaluate_sample(api, sample, **kwargs)
                end_time = time.time()
                result['response_time'] = end_time - start_time
                result['index'] = idx
                return is_correct, result

        # Create tasks for all samples
        tasks = [evaluate_with_semaphore(sample, idx) for idx, sample in enumerate(samples)]

        # Run with a progress bar
        with tqdm(total=total_samples, desc=f"{self.name}") as pbar:
            for coro in asyncio.as_completed(tasks):
                is_correct, result = await coro
                if is_correct:
                    correct_count += 1
                response_times.append(result['response_time'])
                raw_results.append(result)
                pbar.update(1)

                # --- START: REAL-TIME PROGRESS SAVING ---
                # Every 10 samples, save the partial results to a file
                if pbar.n > 0 and pbar.n % 10 == 0:
                    # Ensure the results directory exists
                    results_dir = kwargs.get('output_dir', 'results')
                    os.makedirs(results_dir, exist_ok=True)
                    progress_path = os.path.join(results_dir, f'{self.name}_progress.json')
                    # Sort results by index before saving
                    sorted_progress = sorted(raw_results, key=lambda x: x['index'])
                    try:
                        with open(progress_path, 'w') as f:
                            json.dump(sorted_progress, f, indent=2)
                    except Exception as e:
                        print(f"Error saving progress: {e}")
                # --- END: REAL-TIME PROGRESS SAVING ---

        # Calculate metrics
        accuracy = correct_count / total_samples if total_samples > 0 else 0
        avg_response_time = sum(response_times) / len(response_times) if response_times else 0

        return BenchmarkResult(
            benchmark_name=self.name,
            model_name=api.model_name,
            total_questions=total_samples,
            correct=correct_count,
            accuracy=accuracy,
            avg_response_time=avg_response_time,
            raw_results=sorted(raw_results, key=lambda x: x['index'])
        )
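

# --- Example usage (sketch) ---
# Not part of the original module: a minimal concrete benchmark and a dummy API
# client to show how the abstract methods fit together. `DummyAPI` and its
# `generate` coroutine are hypothetical stand-ins for a real model client; any
# object exposing `model_name` and an awaitable generation call would work.
class EchoBenchmark(BaseBenchmark):
    """Toy benchmark: a response is 'correct' if it contains the expected answer."""

    async def load_dataset(self, sample_size: Optional[int] = None, **kwargs):
        self.dataset = [{"question": "2 + 2 = ?", "answer": "4"}] * (sample_size or 3)

    def format_prompt(self, sample: Dict[str, Any]) -> str:
        return f"Answer with a single number: {sample['question']}"

    async def evaluate_sample(self, api, sample: Dict[str, Any], **kwargs) -> Tuple[bool, Dict[str, Any]]:
        prompt = self.format_prompt(sample)
        response = await api.generate(prompt)
        return sample["answer"] in response, {"prompt": prompt, "response": response}


class DummyAPI:
    """Hypothetical stand-in for a real model API client."""
    model_name = "dummy-model"

    async def generate(self, prompt: str) -> str:
        return "4"


if __name__ == "__main__":
    result = asyncio.run(EchoBenchmark("echo").run_benchmark(DummyAPI(), sample_size=3))
    print(f"accuracy={result.accuracy:.2f}, avg_response_time={result.avg_response_time:.3f}s")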