"""FastAPI service exposing ensemble-voting operations over saved model predictions."""

import asyncio
import logging
import os
import shutil
import sys
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional

import numpy as np
import pandas as pd
import torch
import uvicorn
from fastapi import BackgroundTasks, FastAPI, File, HTTPException, UploadFile
from fastapi.responses import FileResponse
from pydantic import BaseModel

# Add the directory two levels above this file's directory to the Python path.
# This must run before the project-local imports below.
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))

from voting import perform_voting_ensemble, save_predictions  # noqa: E402
from config import LABEL_COLUMNS, PREDICTIONS_SAVE_DIR  # noqa: E402
from dataset_utils import load_label_encoders  # noqa: E402

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="Ensemble Voting API")

# Create necessary directories
UPLOAD_DIR = Path("uploads")
PREDICTIONS_DIR = Path(PREDICTIONS_SAVE_DIR)
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
PREDICTIONS_DIR.mkdir(parents=True, exist_ok=True)


class EnsembleConfig(BaseModel):
    """Request body for ``POST /ensemble/vote``."""

    # Names of the models passed to perform_voting_ensemble.
    model_names: List[str]
    # Optional per-model weights.
    # NOTE(review): accepted by the schema but not used by any handler in this file.
    weights: Optional[Dict[str, float]] = None


class EnsembleResponse(BaseModel):
    """Response body returned by ``POST /ensemble/vote``."""

    message: str
    # Reports returned by perform_voting_ensemble, passed through unchanged.
    metrics: Dict[str, Any]
    # One {"field", "true_label", "predicted_label"} entry per decoded sample.
    predictions: List[Dict[str, Any]]


class PredictionData(BaseModel):
    """Request body for ``POST /ensemble/save-predictions``."""

    model_name: str
    # Each inner list is converted to its own NumPy array before saving.
    probabilities: List[List[float]]
    true_labels: Optional[List[int]] = None


def _escapes_predictions_dir(model_name: str) -> bool:
    """Return True when *model_name* does not resolve strictly inside PREDICTIONS_DIR."""
    try:
        base = PREDICTIONS_DIR.resolve()
        target = (base / model_name).resolve()
    except (OSError, ValueError):
        # Unresolvable names (e.g. embedded NUL bytes) are treated as unsafe.
        return True
    # ``base`` must be a proper ancestor of the target; this also rejects "" and ".".
    return base not in target.parents


@app.get("/")
async def root():
    """Return a static banner identifying the service."""
    return {"message": "Ensemble Voting API"}


@app.get("/health")
async def health_check():
    """Return a static liveness payload."""
    return {"status": "healthy"}


@app.post("/ensemble/vote")
async def perform_ensemble(
    config: EnsembleConfig
):
    """Perform ensemble voting using specified models.

    Runs ``perform_voting_ensemble`` on ``config.model_names`` and decodes the
    true and predicted labels of every label column through its label encoder.

    Raises:
        HTTPException: status 500 when any step of the voting or decoding fails.
    """
    try:
        # Perform ensemble voting.
        # NOTE(review): config.weights is accepted but not forwarded here.
        ensemble_reports, true_labels, ensemble_predictions = perform_voting_ensemble(config.model_names)

        # Load label encoders for decoding predictions
        label_encoders = load_label_encoders()

        # Format predictions with original labels
        formatted_predictions = []
        for i, (col, preds) in enumerate(zip(LABEL_COLUMNS, ensemble_predictions)):
            # Columns without ground-truth labels are left out of the response.
            if true_labels[i] is None:
                continue

            label_encoder = label_encoders[col]
            # ``tolist()`` turns NumPy scalars into native Python values so the
            # response can be JSON-encoded.
            true_labels_orig = np.asarray(label_encoder.inverse_transform(true_labels[i])).tolist()
            pred_labels_orig = np.asarray(label_encoder.inverse_transform(preds)).tolist()

            formatted_predictions.extend(
                {
                    "field": col,
                    "true_label": true,
                    "predicted_label": pred,
                }
                for true, pred in zip(true_labels_orig, pred_labels_orig)
            )

        return EnsembleResponse(
            message="Ensemble voting completed successfully",
            metrics=ensemble_reports,
            predictions=formatted_predictions,
        )

    except Exception as e:
        # Top-level boundary: log with traceback, then surface as a 500.
        logger.exception("Ensemble voting failed: %s", e)
        raise HTTPException(status_code=500, detail=f"Ensemble voting failed: {str(e)}") from e


@app.post("/ensemble/save-predictions")
async def save_model_predictions(
    prediction_data: PredictionData
):
    """Save predictions from a model for later ensemble voting.

    Raises:
        HTTPException: status 400 when ``model_name`` would resolve outside the
            predictions directory; status 500 when saving fails.
    """
    # The model name comes from the client and is used as a storage key.
    # NOTE(review): assumes save_predictions writes under
    # PREDICTIONS_SAVE_DIR/<model_name> (the layout get_available_models reads);
    # confirm against voting.save_predictions.
    # The check sits outside the ``try`` so the 400 is not rewrapped as a 500.
    if _escapes_predictions_dir(prediction_data.model_name):
        raise HTTPException(status_code=400, detail="Invalid model name")

    try:
        # Convert probabilities to numpy arrays
        all_probs = [np.array(probs) for probs in prediction_data.probabilities]
        # A missing or empty ``true_labels`` is stored as a single ``None`` entry.
        true_labels = [np.array(prediction_data.true_labels) if prediction_data.true_labels else None]

        # Save predictions
        save_predictions(
            prediction_data.model_name,
            all_probs,
            true_labels,
        )

        return {
            "message": f"Predictions saved successfully for model {prediction_data.model_name}",
            "model_name": prediction_data.model_name,
        }

    except Exception as e:
        logger.exception("Failed to save predictions: %s", e)
        raise HTTPException(status_code=500, detail=f"Failed to save predictions: {str(e)}") from e


@app.get("/ensemble/available-models")
async def get_available_models():
    """Get list of models with saved predictions.

    A model is listed when its directory under ``PREDICTIONS_DIR`` contains a
    ``<column>_probs.pkl`` file for every entry of ``LABEL_COLUMNS``.

    Raises:
        HTTPException: status 500 when the predictions directory cannot be read.
    """
    try:
        # Sorted so the response does not depend on directory listing order.
        available_models = sorted(
            entry.name
            for entry in PREDICTIONS_DIR.iterdir()
            if entry.is_dir()
            and all((entry / f"{col}_probs.pkl").exists() for col in LABEL_COLUMNS)
        )

        return {
            "available_models": available_models,
            "total_models": len(available_models),
        }

    except Exception as e:
        logger.exception("Failed to get available models: %s", e)
        raise HTTPException(status_code=500, detail=f"Failed to get available models: {str(e)}") from e


if __name__ == "__main__":
    port = int(os.environ.get("PORT", 7861))
    uvicorn.run(app, host="0.0.0.0", port=port)