# ml_dart_model / app.py
# Author: joeyaintjoking
# Commit 4aea73a: "initial commit"
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import pandas as pd
from autogluon.multimodal import MultiModalPredictor
app = FastAPI()  # application object that the /predict route below is registered on

# Load the trained AutoGluon multimodal predictor once, at import time, so that
# every request reuses the same in-memory model. "ml_model" is a relative path
# and is therefore resolved against the process's current working directory.
predictor = MultiModalPredictor.load(path="ml_model")
class PredictionInput(BaseModel):
    """Request body for ``POST /predict``: a single record of model features.

    Attribute names are the API-facing field names. Some of them are renamed
    to the model's training-data column names before inference. Declaration
    order determines the column order of the DataFrame built from this model.
    Units and valid ranges are not specified here; confirm them against the
    training data.
    """

    # Integer-valued fields (presumably age and vital signs; names only).
    anchor_age: int
    dbp: int
    heart_rate: int
    sbp: int

    # Float-valued fields (presumably blood-gas measurements).
    pH: float
    PaCO2: float
    PaO2: float
    HCO3: float
    SaO2: float

    # Float-valued fields (presumably ventilator parameters / readings).
    Compliance: float
    Flow_Rate_L_min: float
    Inspired_O2_Fraction: float
    Minute_Volume: float
    Peak_Insp_Pressure: float
    Plateau_Pressure: float
    Resistance_Exp: float
    Resistance_Insp: float
    Respiratory_Rate_Total: float
    Tidal_Volume_observed: float
    Tidal_Volume_set: float
    Total_PEEP_Level: float

    # String-valued diagnosis field, passed to the model unchanged.
    respiratory_diagnoses: str
# API field name -> column name used in the model's training data.
# Only the fields listed here are renamed; their training-data names contain
# spaces, dots or parentheses, so they cannot be Python identifiers.
_TRAINING_COLUMN_OVERRIDES = {
    "Flow_Rate_L_min": "Flow Rate (L/min)",
    "Inspired_O2_Fraction": "Inspired O2 Fraction",
    "Minute_Volume": "Minute Volume",
    "Peak_Insp_Pressure": "Peak Insp. Pressure",
    "Plateau_Pressure": "Plateau Pressure",
    "Resistance_Exp": "Resistance Exp",
    "Resistance_Insp": "Resistance Insp",
    "Respiratory_Rate_Total": "Respiratory Rate (Total)",
    "Tidal_Volume_observed": "Tidal Volume (observed)",
    "Tidal_Volume_set": "Tidal Volume (set)",
    "Total_PEEP_Level": "Total PEEP Level",
}

# Full mapping for every request field, in request-schema order. Fields without
# an override keep their own name.
rename_map = {
    field: _TRAINING_COLUMN_OVERRIDES.get(field, field)
    for field in (
        "anchor_age", "dbp", "heart_rate", "sbp",
        "pH", "PaCO2", "PaO2", "HCO3", "SaO2",
        "Compliance", "Flow_Rate_L_min", "Inspired_O2_Fraction",
        "Minute_Volume", "Peak_Insp_Pressure", "Plateau_Pressure",
        "Resistance_Exp", "Resistance_Insp", "Respiratory_Rate_Total",
        "Tidal_Volume_observed", "Tidal_Volume_set", "Total_PEEP_Level",
        "respiratory_diagnoses",
    )
}
# Predicted class index -> human-readable ventilation-mode label.
# The index is the position in this tuple. It presumably matches the label
# encoding used when the model was trained; verify before reordering.
label_map = dict(
    enumerate(("APRV", "CMV", "NIV", "SPECIAL", "PSV", "SIMV", "SPONT"))
)
@app.post("/predict")
def predict(input_data: PredictionInput):
    """Predict the ventilation mode for a single input record.

    Args:
        input_data: Validated request body holding one value per model feature.

    Returns:
        ``{"ventilation_mode": label}``. ``label`` is the ``label_map`` entry
        for the predicted class index, or ``"Unknown"`` if that index has no
        entry.

    Raises:
        HTTPException: status 500 if feature preparation or inference fails.
    """
    try:
        # Pydantic v2 renamed ``.dict()`` to ``.model_dump()``; the old name
        # still works there but is deprecated. Support both major versions.
        if hasattr(input_data, "model_dump"):
            input_dict = input_data.model_dump()
        else:
            input_dict = input_data.dict()

        # Translate API field names to training-data column names. A field
        # missing from rename_map keeps its own name instead of raising KeyError.
        renamed_input = {rename_map.get(k, k): v for k, v in input_dict.items()}
        df = pd.DataFrame([renamed_input])

        # One input row, so take the first (only) prediction.
        raw_prediction = predictor.predict(df)[0]

        # Convert the numeric class index into its readable label.
        ventilation_mode = label_map.get(int(raw_prediction), "Unknown")
    except Exception as e:
        # Top-level request boundary: report a 500 to the client, and chain the
        # original error so the server-side traceback keeps the root cause.
        raise HTTPException(status_code=500, detail=f"Prediction failed: {e}") from e
    return {"ventilation_mode": ventilation_mode}