# falcon-api / app.py
# ngocminhta
# update model search
# 2fffdc8
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from fastapi.middleware.cors import CORSMiddleware
import torch
from src.text_embedding import TextEmbeddingModel
from src.index import Indexer
import os
import pickle
from infer import infer_3_class, infer_model_specific
import uvicorn
from datasets import disable_caching
# Globally switch off the Hugging Face `datasets` cache for this process.
disable_caching()

app = FastAPI()

# CORS: accept any origin, method and header.
# NOTE(review): a "*" origin list together with allow_credentials=True makes
# Starlette's CORSMiddleware echo the caller's Origin back, so any site can
# issue credentialed requests. Nothing visible here uses cookies or auth;
# consider narrowing `origins` or dropping credentials (confirm with clients).
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
class Opt:
    """Runtime configuration for the embedding model and the vector database.

    ``Opt()`` reproduces the original hard-coded settings. Each one can be
    overridden by keyword, e.g. ``Opt(database_path="other/db")``, without
    editing this class.
    """

    def __init__(
        self,
        *,
        model_name="ngocminhta/faid-v1",
        database_path="core/seen_db",
        embedding_dim=768,
        device_num=1,
    ):
        # Identifier handed to TextEmbeddingModel at startup.
        self.model_name = model_name
        # Directory the Indexer is deserialized from. It also holds the
        # label_dict / is_mixed_dict / write_model_dict pickles.
        self.database_path = database_path
        # Dimension passed to Indexer. Presumably it must match the model's
        # embedding size; confirm.
        self.embedding_dim = embedding_dim
        # NOTE(review): not read by any visible code in this file; confirm
        # whether it is still needed.
        self.device_num = device_num

    def __repr__(self):
        return (
            f"{type(self).__name__}(model_name={self.model_name!r}, "
            f"database_path={self.database_path!r}, "
            f"embedding_dim={self.embedding_dim!r}, "
            f"device_num={self.device_num!r})"
        )


opt = Opt()
def load_pkl(path):
    """Unpickle and return the single object stored in the file at *path*.

    Security: ``pickle.load`` can execute arbitrary code. Only call this on
    trusted files, such as the bundled database pickles.
    """
    with open(path, "rb") as handle:
        obj = pickle.load(handle)
    return obj
@app.on_event("startup")
def load_model_resources():
    """Load the model, tokenizer, vector index and lookup tables once at startup.

    The loaded objects are published as module-level globals so request
    handlers can use them without reloading.
    """
    global model, tokenizer, index, label_dict, is_mixed_dict, write_model_dict

    # Embedding model; its tokenizer is exposed separately for the handlers.
    model = TextEmbeddingModel(opt.model_name)
    tokenizer = model.tokenizer

    # Rebuild the vector index from the on-disk database directory.
    index = Indexer(opt.embedding_dim)
    index.deserialize_from(opt.database_path)

    def _from_db(filename):
        # Load a pickle stored next to the serialized index.
        return load_pkl(os.path.join(opt.database_path, filename))

    label_dict = _from_db('label_dict.pkl')
    is_mixed_dict = _from_db('is_mixed_dict.pkl')
    write_model_dict = _from_db('write_model_dict.pkl')
# Neighbour counts handed to the infer helpers. The values are unchanged from
# the original call sites. Their precise meaning lives in infer.py: presumably
# K neighbours for the class decision and K_model for the per-model
# decision (confirm there).
PREDICT_K = 21
PREDICT_K_MODEL = 9
SUPPORTED_MODES = ("normal", "advanced")


def _bad_request(message):
    """Return a 400 JSON response whose body is ``{"error": message}``."""
    return JSONResponse(status_code=400, content={"error": message})


@app.post("/predict")
async def predict(request: Request):
    """Run detection on the submitted texts using the preloaded model and index.

    Expected JSON object body::

        {"mode": "normal" | "advanced", "text": [str, ...]}

    ``mode`` is case-insensitive and defaults to "normal". ``text`` defaults
    to an empty list; a single string is accepted as a one-item list.

    Returns ``{"results": ...}`` containing whatever the selected infer helper
    produces. Malformed requests get a 400 with ``{"error": ...}`` instead of
    crashing with a 500.
    """
    try:
        data = await request.json()
    except ValueError:
        # json.JSONDecodeError and UnicodeDecodeError both subclass ValueError.
        return _bad_request("Request body must be valid JSON.")
    if not isinstance(data, dict):
        return _bad_request("Request body must be a JSON object.")

    mode = data.get("mode", "normal")
    if not isinstance(mode, str):
        return _bad_request("'mode' must be a string.")
    mode = mode.lower()

    text_list = data.get("text", [])
    if isinstance(text_list, str):
        text_list = [text_list]
    if not isinstance(text_list, list) or not all(
        isinstance(text, str) for text in text_list
    ):
        return _bad_request("'text' must be a string or a list of strings.")

    # NOTE(review): inference runs synchronously inside this coroutine and so
    # blocks the event loop for its duration. If concurrency matters, consider
    # run_in_threadpool after confirming the model/index are thread-safe.
    if mode == "normal":
        results = infer_3_class(model=model,
                                tokenizer=tokenizer,
                                index=index,
                                label_dict=label_dict,
                                is_mixed_dict=is_mixed_dict,
                                text_list=text_list,
                                K=PREDICT_K)
    elif mode == "advanced":
        results = infer_model_specific(model=model,
                                       tokenizer=tokenizer,
                                       index=index,
                                       label_dict=label_dict,
                                       is_mixed_dict=is_mixed_dict,
                                       write_model_dict=write_model_dict,
                                       text_list=text_list,
                                       K=PREDICT_K,
                                       K_model=PREDICT_K_MODEL)
    else:
        # Previously the handler fell through and returned None, which the
        # framework turned into a 500.
        return _bad_request(
            f"Unsupported mode {mode!r}; expected one of: "
            f"{', '.join(SUPPORTED_MODES)}."
        )
    return JSONResponse(content={"results": results})
# Front-end assets directory, resolved against the working directory in the
# same way StaticFiles resolves it.
STATIC_DIR = "static"


@app.get("/", name="index")
def serve_index() -> FileResponse:
    """Serve the front-end entry page at "/".

    This route is registered *before* the static mount below. Starlette
    matches routes in registration order, and a mount at "/" matches every
    path, so a route added after it is never reached. The function is not
    named ``index`` because the startup hook rebinds the module global
    ``index`` to the vector index. The route name stays "index".
    """
    return FileResponse(path=os.path.join(STATIC_DIR, "index.html"),
                        media_type="text/html")


# Catch-all: serve any other path from STATIC_DIR (html=True also maps
# directory paths to their index.html).
app.mount("/", StaticFiles(directory=STATIC_DIR, html=True), name="static")
if __name__ == "__main__":
    # Bind on all interfaces. The PORT environment variable overrides the
    # default port 8000.
    port = int(os.environ.get("PORT", "8000"))
    uvicorn.run(app, host="0.0.0.0", port=port)