Spaces:
Runtime error
Runtime error
from fastapi import FastAPI, HTTPException, BackgroundTasks, UploadFile, File | |
from fastapi.responses import FileResponse | |
from pydantic import BaseModel | |
from typing import Optional, Dict, Any, List | |
import uvicorn | |
import torch | |
import logging | |
import os | |
import asyncio | |
import pandas as pd | |
from datetime import datetime | |
import shutil | |
from pathlib import Path | |
import numpy as np | |
import sys | |
# Add parent directory to Python path | |
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) | |
from voting import perform_voting_ensemble, save_predictions | |
from config import LABEL_COLUMNS, PREDICTIONS_SAVE_DIR | |
from dataset_utils import load_label_encoders | |
# Process-wide logging at INFO level, plus this module's own logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="Ensemble Voting API")

# Working directories: "uploads" relative to the current working directory and
# the configured predictions directory. Both are created eagerly if missing.
UPLOAD_DIR = Path("uploads")
PREDICTIONS_DIR = Path(PREDICTIONS_SAVE_DIR)
for _required_dir in (UPLOAD_DIR, PREDICTIONS_DIR):
    _required_dir.mkdir(parents=True, exist_ok=True)
del _required_dir
class EnsembleConfig(BaseModel):
    """Request body selecting which models take part in an ensemble vote."""

    # Model names handed to perform_voting_ensemble.
    model_names: List[str]
    # Optional float weights, presumably keyed by model name — TODO confirm.
    # NOTE(review): perform_ensemble in this file passes only model_names to
    # perform_voting_ensemble, so these weights appear to be unused — confirm.
    weights: Optional[Dict[str, float]] = None
class EnsembleResponse(BaseModel):
    """Response body returned by the ensemble-voting handler."""

    # Human-readable status message.
    message: str
    # Reports produced by perform_voting_ensemble, passed through as-is.
    metrics: Dict[str, Any]
    # One dict per decoded prediction, with keys "field", "true_label" and
    # "predicted_label" (built by perform_ensemble).
    predictions: List[Dict[str, Any]]
class PredictionData(BaseModel):
    """Request body carrying one model's predictions to be saved for voting."""

    # Name under which save_predictions stores this model's output.
    model_name: str
    # Nested float lists; save_model_predictions turns each inner list into its
    # own NumPy array.
    # NOTE(review): presumably one row of class probabilities per sample, but
    # the rest of the file reads predictions per entry of LABEL_COLUMNS —
    # confirm the intended shape against save_predictions.
    probabilities: List[List[float]]
    # Optional integer ground-truth labels; omitted means "no labels".
    true_labels: Optional[List[int]] = None
async def root():
    """Return a static banner that identifies this service."""
    banner = "Ensemble Voting API"
    return dict(message=banner)
async def health_check():
    """Liveness probe; unconditionally reports the service as healthy."""
    state = "healthy"
    return dict(status=state)
async def perform_ensemble(
    config: EnsembleConfig
):
    """Run ensemble voting over the requested models and decode the results.

    Args:
        config: Request body. Only ``config.model_names`` is forwarded to
            ``perform_voting_ensemble``; ``config.weights`` is not used here.

    Returns:
        EnsembleResponse whose ``metrics`` are the reports from
        ``perform_voting_ensemble`` and whose ``predictions`` hold one dict per
        (field, sample) with keys ``field``, ``true_label`` and
        ``predicted_label``. ``true_label`` is ``None`` for fields that have no
        ground truth.

    Raises:
        HTTPException: status 500 if voting, decoding or response construction
            fails for any reason.
    """
    # NOTE(review): no route decorator is visible for this handler — confirm it
    # is registered on ``app``.
    # NOTE(review): perform_voting_ensemble is a synchronous call inside an
    # async handler, so the event loop is blocked while it runs.
    try:
        ensemble_reports, true_labels, ensemble_predictions = perform_voting_ensemble(
            config.model_names
        )

        # Encoders map integer class ids back to the original label values.
        label_encoders = load_label_encoders()

        formatted_predictions = []
        for i, (col, preds) in enumerate(zip(LABEL_COLUMNS, ensemble_predictions)):
            label_encoder = label_encoders[col]
            # .tolist() converts NumPy scalars to plain Python values so the
            # response body is JSON-serialisable.
            pred_labels_orig = np.asarray(label_encoder.inverse_transform(preds)).tolist()

            col_true = true_labels[i] if true_labels is not None else None
            if col_true is not None:
                true_labels_orig = np.asarray(
                    label_encoder.inverse_transform(col_true)
                ).tolist()
            else:
                # No ground truth for this field: still report the predictions
                # instead of silently dropping them.
                true_labels_orig = [None] * len(pred_labels_orig)

            formatted_predictions.extend(
                {"field": col, "true_label": true, "predicted_label": pred}
                for true, pred in zip(true_labels_orig, pred_labels_orig)
            )

        return EnsembleResponse(
            message="Ensemble voting completed successfully",
            metrics=ensemble_reports,
            predictions=formatted_predictions
        )
    except Exception as e:
        # logger.exception keeps the traceback that logger.error would drop.
        logger.exception("Ensemble voting failed: %s", e)
        raise HTTPException(
            status_code=500, detail=f"Ensemble voting failed: {str(e)}"
        ) from e
async def save_model_predictions(
    prediction_data: PredictionData
):
    """Persist one model's predictions so they can be used in a later vote.

    Args:
        prediction_data: Request body with the model name, nested probability
            lists and optional integer ground-truth labels.

    Returns:
        dict with a confirmation ``message`` and the ``model_name``.

    Raises:
        HTTPException: status 500 if array conversion or ``save_predictions``
            fails.
    """
    # NOTE(review): each inner list of ``probabilities`` becomes its own array,
    # and the labels are wrapped as a single-element list. Elsewhere in this
    # file, true labels are indexed per entry of LABEL_COLUMNS and one
    # ``{col}_probs.pkl`` file is expected per column — confirm this layout
    # matches what save_predictions expects.
    try:
        all_probs = [np.array(probs) for probs in prediction_data.probabilities]
        # Missing or empty labels are stored as None.
        labels = prediction_data.true_labels
        true_labels = [np.array(labels) if labels else None]

        save_predictions(
            prediction_data.model_name,
            all_probs,
            true_labels
        )

        return {
            "message": f"Predictions saved successfully for model {prediction_data.model_name}",
            "model_name": prediction_data.model_name
        }
    except Exception as e:
        # logger.exception keeps the traceback that logger.error would drop.
        logger.exception("Failed to save predictions: %s", e)
        raise HTTPException(
            status_code=500, detail=f"Failed to save predictions: {str(e)}"
        ) from e
async def get_available_models():
    """List the models whose saved predictions cover every label column.

    A model is available when ``PREDICTIONS_DIR/<model_name>`` is a directory
    containing ``<col>_probs.pkl`` for every ``col`` in ``LABEL_COLUMNS``.

    Returns:
        dict with ``available_models`` (model names, sorted for a deterministic
        response) and ``total_models`` (their count). If the predictions
        directory does not exist, both are empty/zero.

    Raises:
        HTTPException: status 500 if the directory cannot be scanned.
    """
    try:
        predictions_dir = Path(PREDICTIONS_DIR)
        if not predictions_dir.is_dir():
            # Nothing has been saved yet (or the directory was removed).
            available_models = []
        else:
            available_models = sorted(
                model_dir.name
                for model_dir in predictions_dir.iterdir()
                if model_dir.is_dir()
                and all(
                    (model_dir / f"{col}_probs.pkl").exists()
                    for col in LABEL_COLUMNS
                )
            )

        return {
            "available_models": available_models,
            "total_models": len(available_models)
        }
    except Exception as e:
        # logger.exception keeps the traceback that logger.error would drop.
        logger.exception("Failed to get available models: %s", e)
        raise HTTPException(
            status_code=500, detail=f"Failed to get available models: {str(e)}"
        ) from e
if __name__ == "__main__":
    # Serve on all interfaces; the PORT environment variable overrides 7861.
    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "7861")))