|
from fastapi import FastAPI |
|
from pydantic import BaseModel |
|
import joblib |
|
import os |
|
import uvicorn |
|
from fastapi.staticfiles import StaticFiles |
|
from fastapi.middleware.cors import CORSMiddleware |
|
from fastapi.responses import FileResponse |
|
|
|
# The ASGI application; the routes, middleware and static mount below are
# registered on this object.
app: FastAPI = FastAPI()
|
|
|
|
|
|
|
|
|
|
|
# Browser origins allowed to call this API cross-origin. The server listens on
# port 7860 (see the ``uvicorn.run`` call at the bottom of the file), so the UI
# may be reached through any of these host spellings. Each origin is listed once.
# NOTE(review): "null" is the Origin browsers send for pages opened from
# file:// and for sandboxed iframes; paired with allow_credentials=True it is
# quite permissive — confirm it is still required.
origins = [
    "http://localhost",
    "http://127.0.0.1",
    "null",
    "http://localhost:7860",
    "http://127.0.0.1:7860",
    "http://0.0.0.0:7860",
]
|
|
|
# Enable CORS so the browser UI may call the API from the origins in
# ``origins``, using any HTTP method or request header, with credentials allowed.
app.add_middleware(
    CORSMiddleware,
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
    allow_origins=origins,
)
|
|
|
# Load the trained classifier once, at import time, so every request reuses the
# same object. The path is anchored to this source file's directory; ``abspath``
# guards against a relative ``__file__`` silently making the lookup depend on
# the current working directory, and the path is built from separate components
# rather than an embedded "/" separator.
# NOTE: joblib.load unpickles the file, so only ever point this at a trusted,
# bundled artifact.
model_path = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), "model", "model.joblib"
)
model = joblib.load(model_path)
|
|
|
# Serve the static front end under /ui. The directory is resolved relative to
# this source file (the same way the model file is located) instead of the
# process's working directory, so starting the server from another directory
# does not break the mount.
app.mount(
    "/ui",
    StaticFiles(
        directory=os.path.join(os.path.dirname(os.path.abspath(__file__)), "ui"),
        html=True,
    ),
    name="ui",
)
|
|
|
class Iris(BaseModel):
    """Request body for ``POST /predict``: the one feature the model is fed."""

    # Sepal length of the flower, validated by pydantic as a float.
    sepal_length: float
|
|
|
@app.post("/predict")
def predict(data: Iris):
    """Classify an iris from its sepal length.

    The loaded model is queried with a single one-feature sample. Class code
    0 is reported as "setosa", 1 as "versicolor", and any other value falls
    through to "virginica".

    Returns:
        A dict of the form ``{"prediction": <species name>}``.
    """
    label = model.predict([[data.sepal_length]])[0]
    # Compare in the same order as an if/elif chain: 0 first, then 1.
    for code, species in ((0, "setosa"), (1, "versicolor")):
        if label == code:
            return {"prediction": species}
    return {"prediction": "virginica"}
|
|
|
@app.get("/")
def serve_index():
    """Serve the UI's ``index.html`` at the site root.

    The file is resolved relative to this source file rather than the
    process's working directory (matching how the model path is built), so
    the route works regardless of where the server was launched from.
    """
    index_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "ui", "index.html"
    )
    return FileResponse(index_path)
|
|
|
if __name__ == "__main__":
    # Running this file directly starts uvicorn on every interface, port 7860.
    uvicorn.run(app, port=7860, host="0.0.0.0")
|
|