Spaces:
Build error
Build error
from fastapi import FastAPI, HTTPException | |
from pydantic import BaseModel | |
import pandas as pd | |
from autogluon.multimodal import MultiModalPredictor | |
# Path handed to MultiModalPredictor.load for the saved model.
_MODEL_DIR = "model_ml_dart"

# The web application object, plus the predictor loaded once at import time
# so it is shared by every call instead of being reloaded per request.
app = FastAPI()
predictor = MultiModalPredictor.load(_MODEL_DIR)
# Input schema
class PredictionInput(BaseModel):
    """Request payload describing one sample to classify.

    Field names are identifier-safe spellings; ``rename_map`` translates them
    to the column names the model was trained on before prediction. The
    declaration order below is kept as-is because the dumped dict is turned
    directly into a one-row DataFrame.
    """

    # NOTE(review): the groupings below are inferred from the field names
    # only; this file does not state units or valid ranges -- confirm against
    # the training data before relying on them.

    # Integer-valued fields
    anchor_age: int
    dbp: int
    heart_rate: int
    sbp: int

    # Float-valued fields that keep their name in the training data
    pH: float
    PaCO2: float
    PaO2: float
    HCO3: float
    SaO2: float
    Compliance: float

    # Float-valued fields whose training column name differs (see rename_map)
    Flow_Rate_L_min: float
    Inspired_O2_Fraction: float
    Minute_Volume: float
    Peak_Insp_Pressure: float
    Plateau_Pressure: float
    Resistance_Exp: float
    Resistance_Insp: float
    Respiratory_Rate_Total: float
    Tidal_Volume_observed: float
    Tidal_Volume_set: float
    Total_PEEP_Level: float

    # Free-text field
    respiratory_diagnoses: str
# Column renaming to match training data.
# Request fields whose name is already the training column name.
_SAME_NAME_FIELDS = (
    "anchor_age",
    "dbp",
    "heart_rate",
    "sbp",
    "pH",
    "PaCO2",
    "PaO2",
    "HCO3",
    "SaO2",
    "Compliance",
)

# Request fields whose training column name contains spaces or punctuation
# and therefore cannot be used directly as a Python attribute name.
_RENAMED_FIELDS = {
    "Flow_Rate_L_min": "Flow Rate (L/min)",
    "Inspired_O2_Fraction": "Inspired O2 Fraction",
    "Minute_Volume": "Minute Volume",
    "Peak_Insp_Pressure": "Peak Insp. Pressure",
    "Plateau_Pressure": "Plateau Pressure",
    "Resistance_Exp": "Resistance Exp",
    "Resistance_Insp": "Resistance Insp",
    "Respiratory_Rate_Total": "Respiratory Rate (Total)",
    "Tidal_Volume_observed": "Tidal Volume (observed)",
    "Tidal_Volume_set": "Tidal Volume (set)",
    "Total_PEEP_Level": "Total PEEP Level",
}

# Full request-field -> training-column mapping, in schema order.
rename_map = {
    **{field: field for field in _SAME_NAME_FIELDS},
    **_RENAMED_FIELDS,
    "respiratory_diagnoses": "respiratory_diagnoses",
}
# Mapping from predicted class index to readable label.
# The position of each name in this tuple is its class index.
_VENTILATION_MODES = ("APRV", "CMV", "NIV", "SPECIAL", "PSV", "SIMV", "SPONT")
label_map = dict(enumerate(_VENTILATION_MODES))
# NOTE(review): no route decorator (e.g. ``@app.post(...)``) is visible on
# this function in this view. Confirm it is registered with ``app``;
# otherwise the endpoint is unreachable.
def predict(input_data: PredictionInput):
    """Predict the ventilation mode for a single input record.

    Args:
        input_data: Validated request payload. Its field names are translated
            to the training column names through ``rename_map``.

    Returns:
        A dict ``{"ventilation_mode": <label>}`` where the label comes from
        ``label_map``, or ``"Unknown"`` when the predicted class index has no
        entry there.

    Raises:
        HTTPException: With status 500 when any step of the prediction fails.
            The original exception is attached as the cause.
    """
    try:
        # Prefer pydantic v2's ``model_dump``; fall back to ``dict`` on v1.
        dump = getattr(input_data, "model_dump", None) or input_data.dict

        # Rename input fields to match training data. A field missing from
        # rename_map raises KeyError on purpose so schema drift is not hidden.
        renamed_input = {rename_map[k]: v for k, v in dump().items()}
        df = pd.DataFrame([renamed_input])

        # Run prediction and take the first (only) result by position, not by
        # index label, so the lookup does not depend on the returned index.
        predictions = predictor.predict(df)
        if hasattr(predictions, "iloc"):
            raw_prediction = predictions.iloc[0]
        else:
            raw_prediction = predictions[0]

        # Convert numeric label to string label
        ventilation_mode = label_map.get(int(raw_prediction), "Unknown")
        return {"ventilation_mode": ventilation_mode}
    except Exception as e:
        # Endpoint boundary: report every failure as a 500 and keep the cause
        # chained for server-side tracebacks.
        raise HTTPException(
            status_code=500, detail=f"Prediction failed: {e}"
        ) from e