from fastapi import APIRouter, HTTPException
from typing import Dict, Any
import os
from tasks.evaluation_task import EvaluationTask
from huggingface_hub import hf_hub_download
import json
from datetime import datetime
import asyncio

router = APIRouter(tags=["evaluation"])

# Store active evaluation tasks by session_id
active_evaluation_tasks = {}
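# Note: this module-level dict is an in-memory, per-process store. Tasks are
# lost on restart and are not shared across workers; a persistent or shared
# store would be needed for a multi-worker deployment.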

async def evaluate_benchmark(data: Dict[str, Any]):
    """
    Start the evaluation of a benchmark for a given session

    Args:
        data: Dictionary containing session_id

    Returns:
        Dictionary with status and initial logs
    """
    session_id = data.get("session_id")

    if not session_id:
        return {"error": "Session ID missing or invalid"}

    # Check if an evaluation is already in progress for this session
    if session_id in active_evaluation_tasks:
        evaluation_task = active_evaluation_tasks[session_id]
        # If the previous evaluation has completed, a new one can be started
        if evaluation_task.is_task_completed():
            # Delete the old task
            del active_evaluation_tasks[session_id]
        else:
            # An evaluation is already in progress
            return {
                "status": "already_running",
                "message": "An evaluation is already in progress for this session",
                "logs": evaluation_task.get_logs()
            }

    try:
        # Dataset name based on session ID
        dataset_name = f"yourbench/yourbench_{session_id}"

        # Create and start a new evaluation task
        evaluation_task = EvaluationTask(session_uid=session_id, dataset_name=dataset_name)
        active_evaluation_tasks[session_id] = evaluation_task

        # Start the evaluation asynchronously
        asyncio.create_task(evaluation_task.run())

        # Get initial logs
        initial_logs = evaluation_task.get_logs()

        return {
            "status": "started",
            "message": f"Evaluation started for benchmark {dataset_name}",
            "logs": initial_logs
        }
    except Exception as e:
        return {
            "status": "error",
            "error": str(e),
            "message": f"Error starting evaluation: {str(e)}"
        }

async def get_evaluation_logs(session_id: str):
    """
    Retrieve the logs of a running evaluation

    Args:
        session_id: Session ID to retrieve logs for

    Returns:
        Dictionary with logs and completion status
    """
    if session_id not in active_evaluation_tasks:
        raise HTTPException(status_code=404, detail="Evaluation task not found")

    evaluation_task = active_evaluation_tasks[session_id]
    logs = evaluation_task.get_logs()
    is_completed = evaluation_task.is_task_completed()

    # Get results if available and evaluation is completed
    results = None
    if is_completed and hasattr(evaluation_task, 'results') and evaluation_task.results:
        results = evaluation_task.results

    # Get step information
    progress = evaluation_task.get_progress()

    return {
        "logs": logs,
        "is_completed": is_completed,
        "results": results,
        "current_step": progress["current_step"],
        "completed_steps": progress["completed_steps"]
    }
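# Usage note (assumption, inferred from the returned fields): a client is
# expected to poll get_evaluation_logs periodically, appending new log lines,
# and stop once "is_completed" is True, at which point "results" may be set.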

async def get_evaluation_results(session_id: str):
    """
    Retrieve results of a completed evaluation

    Args:
        session_id: Session ID to retrieve results for

    Returns:
        Dictionary with evaluation results
    """
    try:
        # Get organization from environment
        organization = os.getenv("HF_ORGANIZATION", "yourbench")
        dataset_name = f"{organization}/yourbench_{session_id}"

        # Try to load results from the Hub
        try:
            results_file = hf_hub_download(
                repo_id=dataset_name,
                repo_type="dataset",
                filename="lighteval_results.json"
            )

            with open(results_file) as f:
                results_data = json.load(f)

            # Check if results are in the new format or old format
            if "results" in results_data and isinstance(results_data["results"], list):
                # New format: { "metadata": ..., "results": [...] }
                results_list = results_data["results"]
                metadata = results_data.get("metadata", {})
            else:
                # Old format: [...] (list directly)
                results_list = results_data
                metadata = {}

            # Format results to match the expected format
            formatted_results = {
                "metadata": {
                    "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                    "session_id": metadata.get("session_id", session_id),
                    "total_models_tested": len(results_list),
                    "successful_tests": len([r for r in results_list if r.get("status") == "success"])
                },
                "models_comparison": [
                    {
                        "model_name": result["model"],
                        "provider": result["provider"],
                        "success": result.get("status") == "success",
                        "accuracy": result["accuracy"],
                        "evaluation_time": result["execution_time"],
                        "error": result.get("status") if result.get("status") != "success" else None
                    }
                    for result in results_list
                ]
            }

            return {
                "success": True,
                "results": formatted_results
            }
        except Exception as e:
            return {
                "success": False,
                "message": f"Failed to load results from Hub: {str(e)}"
            }
    except Exception as e:
        return {
            "success": False,
            "message": f"Error retrieving evaluation results: {str(e)}"
        }
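
# Route registration sketch (assumption): this excerpt creates `router` but does
# not show how the handlers above are attached to it; registration may happen
# elsewhere in the project. The paths below are hypothetical placeholders, not
# the application's actual routes.
#
# router.add_api_route("/evaluate-benchmark", evaluate_benchmark, methods=["POST"])
# router.add_api_route("/evaluation-logs/{session_id}", get_evaluation_logs, methods=["GET"])
# router.add_api_route("/evaluation-results/{session_id}", get_evaluation_results, methods=["GET"])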