"""FastAPI service exposing a ventilation-mode classifier.

The AutoGluon ``MultiModalPredictor`` saved under ``ml_model`` is loaded once
at import time and queried by ``POST /predict``.
"""

import logging

import pandas as pd
from autogluon.multimodal import MultiModalPredictor
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

logger = logging.getLogger(__name__)

app = FastAPI()

# Load the model once at import time so every request reuses the same instance.
# NOTE(review): "ml_model" is a relative path, so it resolves against the
# process working directory — confirm the service is launched from there.
predictor = MultiModalPredictor.load("ml_model")


class PredictionInput(BaseModel):
    """Request body for ``/predict``: a single observation to classify.

    Field names are Python identifiers; ``rename_map`` translates them to the
    column names used in the training data.
    """

    anchor_age: int
    dbp: int
    heart_rate: int
    sbp: int
    pH: float
    PaCO2: float
    PaO2: float
    HCO3: float
    SaO2: float
    Compliance: float
    Flow_Rate_L_min: float
    Inspired_O2_Fraction: float
    Minute_Volume: float
    Peak_Insp_Pressure: float
    Plateau_Pressure: float
    Resistance_Exp: float
    Resistance_Insp: float
    Respiratory_Rate_Total: float
    Tidal_Volume_observed: float
    Tidal_Volume_set: float
    Total_PEEP_Level: float
    respiratory_diagnoses: str


# Column renaming to match training data (request field -> training column).
rename_map = {
    "anchor_age": "anchor_age",
    "dbp": "dbp",
    "heart_rate": "heart_rate",
    "sbp": "sbp",
    "pH": "pH",
    "PaCO2": "PaCO2",
    "PaO2": "PaO2",
    "HCO3": "HCO3",
    "SaO2": "SaO2",
    "Compliance": "Compliance",
    "Flow_Rate_L_min": "Flow Rate (L/min)",
    "Inspired_O2_Fraction": "Inspired O2 Fraction",
    "Minute_Volume": "Minute Volume",
    "Peak_Insp_Pressure": "Peak Insp. Pressure",
    "Plateau_Pressure": "Plateau Pressure",
    "Resistance_Exp": "Resistance Exp",
    "Resistance_Insp": "Resistance Insp",
    "Respiratory_Rate_Total": "Respiratory Rate (Total)",
    "Tidal_Volume_observed": "Tidal Volume (observed)",
    "Tidal_Volume_set": "Tidal Volume (set)",
    "Total_PEEP_Level": "Total PEEP Level",
    "respiratory_diagnoses": "respiratory_diagnoses",
}

# Mapping from predicted class index to readable label.
label_map = {
    0: "APRV",
    1: "CMV",
    2: "NIV",
    3: "SPECIAL",
    4: "PSV",
    5: "SIMV",
    6: "SPONT",
}


def _input_to_frame(input_data: PredictionInput) -> pd.DataFrame:
    """Return a one-row DataFrame whose columns use the training-data names."""
    # Pydantic v2 renamed ``.dict()`` to ``.model_dump()``; support both.
    if hasattr(input_data, "model_dump"):
        input_dict = input_data.model_dump()
    else:
        input_dict = input_data.dict()
    # Fields without an explicit rename entry keep their own name.
    renamed_input = {rename_map.get(k, k): v for k, v in input_dict.items()}
    return pd.DataFrame([renamed_input])


def _first_prediction(predictions):
    """Return the first element of *predictions* by position, not by label."""
    # A pandas Series carries an index, so ``[0]`` is a label lookup.
    # ``iloc`` always takes the first row. Plain sequences/arrays fall back
    # to ``[0]``.
    if hasattr(predictions, "iloc"):
        return predictions.iloc[0]
    return predictions[0]


@app.post("/predict")
def predict(input_data: PredictionInput):
    """Classify the ventilation mode for one observation.

    Returns ``{"ventilation_mode": label}``. ``label`` comes from
    ``label_map``, or is ``"Unknown"`` when the predicted class index has no
    entry there.

    Raises:
        HTTPException: status 500 if building the input, running the model or
            decoding its output fails for any reason.
    """
    try:
        df = _input_to_frame(input_data)
        raw_prediction = _first_prediction(predictor.predict(df))
        # Convert numeric label to string label.
        ventilation_mode = label_map.get(int(raw_prediction), "Unknown")
    except Exception as e:
        # Turning the error into an HTTPException would otherwise discard the
        # traceback server-side, so log it before responding.
        logger.exception("Prediction failed")
        # NOTE(review): the exception text is returned to the client as-is;
        # consider a generic message if internal details must not leak.
        raise HTTPException(
            status_code=500, detail=f"Prediction failed: {e}"
        ) from e
    return {"ventilation_mode": ventilation_mode}