Spaces:
Sleeping
Sleeping
from fastapi import FastAPI, Request | |
from fastapi.responses import JSONResponse | |
from fastapi.staticfiles import StaticFiles | |
from fastapi.responses import FileResponse | |
from fastapi.middleware.cors import CORSMiddleware | |
import torch | |
from src.text_embedding import TextEmbeddingModel | |
from src.index import Indexer | |
import os | |
import pickle | |
from infer import infer_3_class, infer_model_specific | |
import uvicorn | |
from datasets import disable_caching | |
# Turn off the `datasets` library's caching for this process (presumably so
# inference does not write/reuse cache files on disk — confirm intent).
disable_caching()

app = FastAPI()

# Any origin may call the API. Credentials are deliberately NOT allowed: the
# CORS spec forbids combining a wildcard origin with credentials, and
# Starlette works around it by echoing back *any* requesting origin with
# credentials enabled. Nothing in this app reads cookies or auth headers, so
# credentialed cross-origin requests are not needed.
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
class Opt:
    """Runtime configuration for the embedding model and the retrieval database.

    Every setting is a keyword-only argument whose default is the value this
    app was built with, so ``Opt()`` keeps working unchanged while tests or
    other deployments can override individual settings.
    """

    def __init__(
        self,
        *,
        model_name="ngocminhta/faid-v1",
        database_path="core/seen_db",
        embedding_dim=768,
        device_num=1,
    ):
        # Model identifier handed to TextEmbeddingModel (appears to be a
        # Hugging Face Hub id).
        self.model_name = model_name
        # Directory holding the serialized vector index and the label pickles.
        self.database_path = database_path
        # Dimensionality the Indexer is constructed with; presumably must match
        # the embedding model's output size — confirm.
        self.embedding_dim = embedding_dim
        # NOTE(review): not read anywhere in the visible code of this file.
        self.device_num = device_num


# Module-wide configuration instance read by the resource loader.
opt = Opt()
def load_pkl(path):
    """Return the object unpickled from the file at ``path``.

    Security: unpickling can execute arbitrary code, so only call this on
    trusted files such as the database artifacts under ``opt.database_path``.
    """
    with open(path, mode="rb") as stream:
        obj = pickle.load(stream)
    return obj
def load_model_resources():
    """Load the embedding model, the vector index and the label lookup tables.

    Everything is published as module-level globals (``model``, ``tokenizer``,
    ``index``, ``label_dict``, ``is_mixed_dict``, ``write_model_dict``) that
    the prediction handler reads.
    """
    global model, tokenizer, index, label_dict, is_mixed_dict, write_model_dict
    # NOTE(review): nothing visible in this file calls this function and no
    # decorator is shown; presumably it was registered as a startup hook — confirm.

    def _db_file(filename):
        # Each lookup table is a pickle stored next to the serialized index.
        return os.path.join(opt.database_path, filename)

    model = TextEmbeddingModel(opt.model_name)
    tokenizer = model.tokenizer

    # NOTE(review): ``global index`` rebinds the same module-level name that the
    # front-end route function ``index`` is defined under later in this file.
    index = Indexer(opt.embedding_dim)
    index.deserialize_from(opt.database_path)

    label_dict = load_pkl(_db_file("label_dict.pkl"))
    is_mixed_dict = load_pkl(_db_file("is_mixed_dict.pkl"))
    write_model_dict = load_pkl(_db_file("write_model_dict.pkl"))
async def predict(request: Request):
    """Classify the texts in a JSON request body.

    Expected body (a JSON object):
        text: the texts to classify; defaults to ``[]``. A bare string is
            wrapped into a one-element list.
        mode: ``"normal"`` (default, case-insensitive) runs ``infer_3_class``;
            ``"advanced"`` runs ``infer_model_specific``, which additionally
            uses ``write_model_dict``.

    Returns:
        ``JSONResponse`` with ``{"results": ...}`` holding whatever the selected
        inference helper returns, or a 400 ``JSONResponse`` with
        ``{"error": ...}`` when the body is not valid JSON, is not a JSON
        object, or names an unknown ``mode``.
    """
    # NOTE(review): no route decorator is visible for this handler in this
    # copy of the file; confirm how it is registered.
    try:
        data = await request.json()
    except ValueError:  # malformed JSON or undecodable body bytes
        return JSONResponse(
            status_code=400,
            content={"error": "Request body must be valid JSON."},
        )
    if not isinstance(data, dict):
        return JSONResponse(
            status_code=400,
            content={"error": "Request body must be a JSON object."},
        )

    # str() guards against non-string values (e.g. null) crashing on .lower().
    mode = str(data.get("mode", "normal")).lower()
    text_list = data.get("text", [])
    if isinstance(text_list, str):
        text_list = [text_list]

    # Arguments shared by both modes. K / K_model are presumably neighbour
    # counts for the retrieval step — confirm in infer.py.
    shared = dict(
        model=model,
        tokenizer=tokenizer,
        index=index,
        label_dict=label_dict,
        is_mixed_dict=is_mixed_dict,
        text_list=text_list,
        K=21,
    )
    # NOTE(review): the inference helpers are plain (non-awaited) calls, so
    # they block the event loop for as long as they run.
    if mode == "normal":
        results = infer_3_class(**shared)
    elif mode == "advanced":
        results = infer_model_specific(
            write_model_dict=write_model_dict,
            K_model=9,
            **shared,
        )
    else:
        return JSONResponse(
            status_code=400,
            content={"error": f"Unknown mode {mode!r}; expected 'normal' or 'advanced'."},
        )
    return JSONResponse(content={"results": results})
# Serve the ./static directory at the site root; html=True makes directory
# requests return their index.html.
# NOTE(review): a Mount at "/" matches every path, so any route added to `app`
# after this line is shadowed by it.
app.mount(path="/", app=StaticFiles(directory="static", html=True), name="static")
def index() -> FileResponse:
    """Serve the front-end page ``static/index.html``.

    The path is resolved from the same relative ``static`` directory that the
    ``StaticFiles(directory="static")`` mount serves. The previous hard-coded
    absolute path ``/app/static/index.html`` only resolved when the project
    lived under ``/app``.
    """
    # NOTE(review): no route decorator is visible here. If this is registered
    # on "/" after the catch-all ``app.mount("/")``, the mount matches first and
    # this handler is never reached.
    # NOTE(review): the module-level name ``index`` is later rebound to the
    # vector index by ``global index`` in the resource loader.
    return FileResponse(
        path=os.path.join("static", "index.html"),
        media_type="text/html",
    )
if __name__ == "__main__":
    # Listen on all interfaces; the port comes from $PORT when set, else 8000.
    port = int(os.environ.get("PORT", "8000"))
    uvicorn.run(app, host="0.0.0.0", port=port)