File size: 2,477 Bytes
4aea73a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import logging
import os

import pandas as pd
from autogluon.multimodal import MultiModalPredictor
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

app = FastAPI()

# Load the model once, at import time, so every request reuses the same
# predictor. The path defaults to "ml_model" (a relative path, so it is
# resolved against the process's working directory) and can be overridden
# with the ML_MODEL_PATH environment variable without editing the code.
predictor = MultiModalPredictor.load(os.environ.get("ML_MODEL_PATH", "ml_model"))

# Input schema
class PredictionInput(BaseModel):
    """Request body for ``POST /predict``: one patient record.

    Field names are identifier-safe versions of the training-data column
    names; ``predict`` maps them back to the original (spaced/punctuated)
    column names through ``rename_map`` before building the model's input
    DataFrame. Because that mapping preserves field order, the declaration
    order below is also the column order of the DataFrame sent to the model.
    """

    # Patient demographics / vitals.
    # NOTE(review): dbp/sbp presumably diastolic/systolic blood pressure —
    # confirm units against the training data.
    anchor_age: int
    dbp: int
    heart_rate: int
    sbp: int
    # Presumably arterial blood-gas values — confirm units against training data.
    pH: float
    PaCO2: float
    PaO2: float
    HCO3: float
    SaO2: float
    # Fields from here through Total_PEEP_Level appear to be ventilator
    # settings/measurements; all except Compliance are renamed to spaced
    # column names (e.g. "Flow Rate (L/min)") via rename_map.
    Compliance: float
    Flow_Rate_L_min: float
    Inspired_O2_Fraction: float
    Minute_Volume: float
    Peak_Insp_Pressure: float
    Plateau_Pressure: float
    Resistance_Exp: float
    Resistance_Insp: float
    Respiratory_Rate_Total: float
    Tidal_Volume_observed: float
    Tidal_Volume_set: float
    Total_PEEP_Level: float
    # Passed to the model as a string column; the expected format/vocabulary
    # is not visible here — TODO confirm against the training data.
    respiratory_diagnoses: str

# Column renaming to match training data
rename_map = {
    # Request fields whose training-data column already has the same name.
    **{
        field_name: field_name
        for field_name in (
            "anchor_age",
            "dbp",
            "heart_rate",
            "sbp",
            "pH",
            "PaCO2",
            "PaO2",
            "HCO3",
            "SaO2",
            "Compliance",
        )
    },
    # Request fields that must be mapped back to the original column names
    # (which contain spaces and punctuation the API schema cannot use).
    "Flow_Rate_L_min": "Flow Rate (L/min)",
    "Inspired_O2_Fraction": "Inspired O2 Fraction",
    "Minute_Volume": "Minute Volume",
    "Peak_Insp_Pressure": "Peak Insp. Pressure",
    "Plateau_Pressure": "Plateau Pressure",
    "Resistance_Exp": "Resistance Exp",
    "Resistance_Insp": "Resistance Insp",
    "Respiratory_Rate_Total": "Respiratory Rate (Total)",
    "Tidal_Volume_observed": "Tidal Volume (observed)",
    "Tidal_Volume_set": "Tidal Volume (set)",
    "Total_PEEP_Level": "Total PEEP Level",
    # Identity entry kept last so the mapping's order mirrors the schema.
    "respiratory_diagnoses": "respiratory_diagnoses",
}

# Mapping from predicted class index to readable label
# Position in this tuple is the class index the model predicts.
# NOTE(review): the order must match the label encoding used at training
# time, which is not visible here — confirm before reordering.
label_map = dict(
    enumerate(("APRV", "CMV", "NIV", "SPECIAL", "PSV", "SIMV", "SPONT"))
)

@app.post("/predict")
def predict(input_data: PredictionInput):
    """Predict the ventilation mode for a single patient record.

    The validated request body is renamed to the training-data column names
    via ``rename_map``, wrapped in a one-row DataFrame and passed to the
    AutoGluon predictor. The predicted class index is then translated to a
    readable name via ``label_map``.

    Args:
        input_data: Validated request body.

    Returns:
        ``{"ventilation_mode": name}``. ``name`` is ``"Unknown"`` when the
        predicted index is not a key of ``label_map``.

    Raises:
        HTTPException: status 500 if preparing the input or running the model
            raises. The underlying error is logged with its traceback and
            chained as the exception's cause.
    """
    try:
        # Pydantic v2 renamed ``.dict()`` to ``.model_dump()`` and deprecates
        # the old name; use whichever this Pydantic version provides.
        dump = getattr(input_data, "model_dump", None) or input_data.dict
        renamed_input = {rename_map[k]: v for k, v in dump().items()}
        df = pd.DataFrame([renamed_input])

        predictions = predictor.predict(df)
        # Take the first (only) row by position. ``predictions[0]`` would be a
        # label lookup on a Series and fails if the index is not 0-based;
        # wrapping in pd.Series also handles an array-like return value.
        raw_prediction = pd.Series(predictions).iloc[0]

        # NOTE(review): assumes the model was trained on integer class
        # indices; a string label would make int() raise and yield a 500.
        ventilation_mode = label_map.get(int(raw_prediction), "Unknown")
    except Exception as e:
        # Once converted to an HTTPException the traceback is not otherwise
        # recorded, so log it here for server-side diagnosis.
        logging.getLogger(__name__).exception("Prediction failed")
        # NOTE(review): echoing the error text to clients may expose
        # internal details; consider a generic message in production.
        raise HTTPException(
            status_code=500, detail=f"Prediction failed: {e}"
        ) from e

    return {"ventilation_mode": ventilation_mode}