# fastapi_last / main.py
# sujal7102003's picture
# Upload folder using huggingface_hub
# effe919 verified
from fastapi import FastAPI, Request
from pydantic import BaseModel
from typing import List
import joblib
import tensorflow as tf
import numpy as np
from catboost import CatBoostClassifier
from fastapi.templating import Jinja2Templates
from fastapi.responses import HTMLResponse
# App setup: create the FastAPI application object that the route decorators
# below attach to.
app = FastAPI()
# Template loader pointed at the "templates" directory.
# NOTE(review): the path is relative — presumably resolved against the process
# working directory; confirm how the app is launched.
templates = Jinja2Templates(directory="templates")
# Empty CatBoost classifier instance; its trained state is read from
# "catboost_model.cbm" in the model-loading section further down.
catboost_model = CatBoostClassifier()
@app.get("/", response_class=HTMLResponse)
def read_index(request: Request):
    """Render the ``index.html`` template for the site root.

    Args:
        request: The incoming request, handed to the template under the
            ``"request"`` context key.

    Returns:
        The template response produced by ``templates.TemplateResponse``.
    """
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
# Input model
class PredictionInput(BaseModel):
    """Request body accepted by the ``/predict/*`` endpoints."""

    # Feature values for a single sample. Only the element type is declared;
    # no length constraint is applied at this level.
    features: List[float]
# Load models once, at import time, so the endpoint functions below reuse the
# same in-memory model objects instead of reloading per request.
# NOTE(review): every path here is relative — presumably resolved against the
# process working directory; confirm how the app is launched.
# NOTE(review): joblib.load unpickles the file contents, so these artifacts
# must only ever come from a trusted source.
ann_model = tf.keras.models.load_model("ann_model.keras")
xgb_model = joblib.load("xgboost.joblib")
voting_model = joblib.load("voting_classifier.joblib")
svm_model = joblib.load("svm.joblib")
rf_model = joblib.load("random_forest.joblib")
lr_model = joblib.load("logistic_regression (1).joblib")
# catboost_model is the CatBoostClassifier instance created during app setup;
# load_model fills it from the .cbm file.
catboost_model.load_model("catboost_model.cbm")
# Prediction endpoints (No auth)
@app.post("/predict/ann")
def predict_ann(input_data: PredictionInput):
    """Run the Keras model on one feature vector and return its output.

    Args:
        input_data: Request body whose ``features`` list is treated as a
            single sample.

    Returns:
        A dict with the label ``"ANN"`` under ``"model"`` and the model
        output, converted to nested Python lists, under ``"prediction"``.
    """
    # Wrap the features in an outer list so the array has one row per sample.
    batch = np.array([input_data.features])
    # verbose=0 silences the progress bar Keras would otherwise print on
    # every request.
    prediction = ann_model.predict(batch, verbose=0)
    return {"model": "ANN", "prediction": prediction.tolist()}
@app.post("/predict/xgboost")
def predict_xgboost(input_data: PredictionInput):
    """Return the prediction of the model loaded from ``xgboost.joblib``.

    Args:
        input_data: Request body whose ``features`` list is treated as a
            single sample.

    Returns:
        A dict with the label ``"XGBoost"`` under ``"model"`` and the
        prediction, converted to a Python list, under ``"prediction"``.
    """
    # The model is given a one-row batch: the features wrapped in a list.
    single_row = [input_data.features]
    predicted = xgb_model.predict(single_row)
    return {
        "model": "XGBoost",
        "prediction": predicted.tolist(),
    }
@app.post("/predict/voting")
def predict_voting(input_data: PredictionInput):
    """Return the prediction of the model loaded from ``voting_classifier.joblib``.

    Args:
        input_data: Request body whose ``features`` list is treated as a
            single sample.

    Returns:
        A dict with the label ``"VotingClassifier"`` under ``"model"`` and
        the prediction, converted to a Python list, under ``"prediction"``.
    """
    # The model is given a one-row batch: the features wrapped in a list.
    single_row = [input_data.features]
    predicted = voting_model.predict(single_row)
    return {
        "model": "VotingClassifier",
        "prediction": predicted.tolist(),
    }
@app.post("/predict/svm")
def predict_svm(input_data: PredictionInput):
    """Return the prediction of the model loaded from ``svm.joblib``.

    Args:
        input_data: Request body whose ``features`` list is treated as a
            single sample.

    Returns:
        A dict with the label ``"SVM"`` under ``"model"`` and the
        prediction, converted to a Python list, under ``"prediction"``.
    """
    # The model is given a one-row batch: the features wrapped in a list.
    single_row = [input_data.features]
    predicted = svm_model.predict(single_row)
    return {
        "model": "SVM",
        "prediction": predicted.tolist(),
    }
@app.post("/predict/randomforest")
def predict_rf(input_data: PredictionInput):
    """Return the prediction of the model loaded from ``random_forest.joblib``.

    Args:
        input_data: Request body whose ``features`` list is treated as a
            single sample.

    Returns:
        A dict with the label ``"RandomForest"`` under ``"model"`` and the
        prediction, converted to a Python list, under ``"prediction"``.
    """
    # The model is given a one-row batch: the features wrapped in a list.
    single_row = [input_data.features]
    predicted = rf_model.predict(single_row)
    return {
        "model": "RandomForest",
        "prediction": predicted.tolist(),
    }
@app.post("/predict/logistic")
def predict_lr(input_data: PredictionInput):
    """Return the prediction of the model loaded from ``logistic_regression (1).joblib``.

    Args:
        input_data: Request body whose ``features`` list is treated as a
            single sample.

    Returns:
        A dict with the label ``"LogisticRegression"`` under ``"model"`` and
        the prediction, converted to a Python list, under ``"prediction"``.
    """
    # The model is given a one-row batch: the features wrapped in a list.
    single_row = [input_data.features]
    predicted = lr_model.predict(single_row)
    return {
        "model": "LogisticRegression",
        "prediction": predicted.tolist(),
    }
@app.post("/predict/catboost")
def predict_catboost(input_data: PredictionInput):
    """Return the prediction of the CatBoost model loaded from ``catboost_model.cbm``.

    Args:
        input_data: Request body whose ``features`` list is treated as a
            single sample.

    Returns:
        A dict with the label ``"CatBoost"`` under ``"model"`` and the
        prediction, converted to a Python list, under ``"prediction"``.
    """
    # The model is given a one-row batch: the features wrapped in a list.
    single_row = [input_data.features]
    predicted = catboost_model.predict(single_row)
    return {
        "model": "CatBoost",
        "prediction": predicted.tolist(),
    }