Spaces:
Build error
Build error
File size: 2,482 Bytes
b1713d3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 |
import logging

from autogluon.multimodal import MultiModalPredictor
from fastapi import FastAPI, HTTPException
import pandas as pd
from pydantic import BaseModel
# ASGI application exposing the /predict endpoint defined below.
app = FastAPI()

# Location of the trained AutoGluon MultiModal predictor. This is a relative
# path, so it is presumably resolved against the process working directory.
# The model is loaded once, at import time, and shared by every request.
MODEL_DIR = "model_ml_dart"
predictor = MultiModalPredictor.load(MODEL_DIR)
# Request body for POST /predict: one observation to classify. Every attribute
# below has no default, so each one is a required JSON field that pydantic
# validates against (and coerces to) the annotated type.
# NOTE: keep the declaration order. predict() turns the parsed fields into a
# one-row DataFrame, and the DataFrame's column order follows this order.
class PredictionInput(BaseModel):
    # Integer-valued inputs. These are presumably age in years, diastolic and
    # systolic blood pressure, and heart rate (TODO: confirm units against the
    # training data).
    anchor_age: int
    dbp: int
    heart_rate: int
    sbp: int
    # Presumably arterial blood-gas results (TODO: confirm units).
    pH: float
    PaCO2: float
    PaO2: float
    HCO3: float
    SaO2: float
    # Presumably ventilator measurements/settings. All except Compliance are
    # renamed by rename_map to training column names containing spaces or
    # punctuation (e.g. Flow_Rate_L_min -> "Flow Rate (L/min)").
    Compliance: float
    Flow_Rate_L_min: float
    Inspired_O2_Fraction: float
    Minute_Volume: float
    Peak_Insp_Pressure: float
    Plateau_Pressure: float
    Resistance_Exp: float
    Resistance_Insp: float
    Respiratory_Rate_Total: float
    Tidal_Volume_observed: float
    Tidal_Volume_set: float
    Total_PEEP_Level: float
    # Free-text feature forwarded to the model under the same column name.
    # Presumably the patient's respiratory diagnoses; the expected format is
    # not visible here (TODO: confirm).
    respiratory_diagnoses: str
# API field name -> column name used when the model was trained.
# Training columns that contain spaces or punctuation differ from the
# (identifier-safe) API field names; they are listed here explicitly.
_TRAINING_COLUMN_OVERRIDES = {
    "Flow_Rate_L_min": "Flow Rate (L/min)",
    "Inspired_O2_Fraction": "Inspired O2 Fraction",
    "Minute_Volume": "Minute Volume",
    "Peak_Insp_Pressure": "Peak Insp. Pressure",
    "Plateau_Pressure": "Plateau Pressure",
    "Resistance_Exp": "Resistance Exp",
    "Resistance_Insp": "Resistance Insp",
    "Respiratory_Rate_Total": "Respiratory Rate (Total)",
    "Tidal_Volume_observed": "Tidal Volume (observed)",
    "Tidal_Volume_set": "Tidal Volume (set)",
    "Total_PEEP_Level": "Total PEEP Level",
}

# Every API field, in request-schema order. Fields without an override keep
# their own name as the training column name.
rename_map = {
    field: _TRAINING_COLUMN_OVERRIDES.get(field, field)
    for field in (
        "anchor_age",
        "dbp",
        "heart_rate",
        "sbp",
        "pH",
        "PaCO2",
        "PaO2",
        "HCO3",
        "SaO2",
        "Compliance",
        "Flow_Rate_L_min",
        "Inspired_O2_Fraction",
        "Minute_Volume",
        "Peak_Insp_Pressure",
        "Plateau_Pressure",
        "Resistance_Exp",
        "Resistance_Insp",
        "Respiratory_Rate_Total",
        "Tidal_Volume_observed",
        "Tidal_Volume_set",
        "Total_PEEP_Level",
        "respiratory_diagnoses",
    )
}
# Predicted class index -> human-readable ventilation-mode label.
# A label's position in this tuple is its class index (APRV is 0, CMV is 1, ...),
# so the order must match the encoding used when the model was trained.
label_map = dict(
    enumerate(("APRV", "CMV", "NIV", "SPECIAL", "PSV", "SIMV", "SPONT"))
)
_logger = logging.getLogger(__name__)


def _dump_model(model):
    """Return a pydantic model's fields as a dict, in declaration order.

    Prefers ``model_dump()`` (pydantic v2). Falls back to ``dict()``, which is
    the only option on pydantic v1 and is deprecated on v2.
    """
    dump = getattr(model, "model_dump", None)
    return dump() if dump is not None else model.dict()


def _to_label(raw_prediction):
    """Translate one raw model prediction into a ventilation-mode label.

    Numeric class indices (ints, numpy integers, numeric strings) are looked
    up in ``label_map``; an index with no entry yields ``"Unknown"``. A
    non-numeric prediction that is already one of the known labels is
    returned unchanged; anything else (e.g. NaN) also yields ``"Unknown"``.
    """
    try:
        class_index = int(raw_prediction)
    except (TypeError, ValueError, OverflowError):
        # NOTE(review): reached if the predictor emits string labels rather
        # than indices, or a non-finite float. Previously this surfaced as 500.
        label = str(raw_prediction)
        return label if label in label_map.values() else "Unknown"
    return label_map.get(class_index, "Unknown")


@app.post("/predict")
def predict(input_data: PredictionInput):
    """Predict the ventilation mode for one set of patient measurements.

    Returns ``{"ventilation_mode": <label>}``.

    Raises:
        HTTPException: status 500 with detail ``"Prediction failed: <error>"``
            if building the input frame or running the model fails. The
            original error is logged with its traceback and chained as the
            cause.
    """
    try:
        # Rename API fields to the column names used at training time; a field
        # with no explicit mapping keeps its own name.
        row = {
            rename_map.get(field, field): value
            for field, value in _dump_model(input_data).items()
        }
        df = pd.DataFrame([row])
        # Single-row input, so the first prediction is the only one.
        raw_prediction = predictor.predict(df)[0]
        ventilation_mode = _to_label(raw_prediction)
    except Exception as e:
        # Top-level boundary: record the traceback server-side, then report
        # the failure to the client in the same format as before.
        _logger.exception("Prediction failed")
        raise HTTPException(
            status_code=500, detail=f"Prediction failed: {e}"
        ) from e
    return {"ventilation_mode": ventilation_mode}
|