from fastapi import FastAPI, Request
from fastapi.responses import FileResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
import torch
from src.text_embedding import TextEmbeddingModel
from src.index import Indexer
import os
import pickle
from infer import infer_3_class, infer_model_specific
import uvicorn
from datasets import disable_caching

disable_caching()
app = FastAPI()

origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
class Opt:
    """Configuration: embedding model, on-disk index location, and embedding size."""

    def __init__(self):
        self.model_name = "ngocminhta/faid-v1"
        self.database_path = "core/seen_db"
        self.embedding_dim = 768
        self.device_num = 1

opt = Opt()
def load_pkl(path):
    with open(path, "rb") as f:
        return pickle.load(f)
@app.on_event("startup")
def load_model_resources():
    # Load the embedding model, the vector index, and the metadata pickles
    # once at startup, and share them across requests via module-level globals.
    global model, tokenizer, index, label_dict, is_mixed_dict, write_model_dict
    model = TextEmbeddingModel(opt.model_name)
    tokenizer = model.tokenizer
    index = Indexer(opt.embedding_dim)
    index.deserialize_from(opt.database_path)
    label_dict = load_pkl(os.path.join(opt.database_path, "label_dict.pkl"))
    is_mixed_dict = load_pkl(os.path.join(opt.database_path, "is_mixed_dict.pkl"))
    write_model_dict = load_pkl(os.path.join(opt.database_path, "write_model_dict.pkl"))
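# Inferred from the loading code above: opt.database_path is expected to hold
# the serialized index (exact file names depend on src.index.Indexer) plus
# label_dict.pkl, is_mixed_dict.pkl, and write_model_dict.pkl.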
@app.post("/predict")
async def predict(request: Request):
    data = await request.json()
    mode = data.get("mode", "normal").lower()
    text_list = data.get("text", [])
    if mode == "normal":
        # Three-class prediction over the K nearest neighbours in the index.
        results = infer_3_class(
            model=model,
            tokenizer=tokenizer,
            index=index,
            label_dict=label_dict,
            is_mixed_dict=is_mixed_dict,
            text_list=text_list,
            K=21,
        )
        return JSONResponse(content={"results": results})
    elif mode == "advanced":
        # Additionally attributes machine-written text to a specific model.
        results = infer_model_specific(
            model=model,
            tokenizer=tokenizer,
            index=index,
            label_dict=label_dict,
            is_mixed_dict=is_mixed_dict,
            write_model_dict=write_model_dict,
            text_list=text_list,
            K=21,
            K_model=9,
        )
        return JSONResponse(content={"results": results})
    return JSONResponse(status_code=400, content={"error": f"Unknown mode: {mode}"})
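# Example request for the /predict endpoint above (hypothetical input text and
# port; the shape of each entry in "results" is determined by the infer
# functions and is not reproduced here):
#   curl -X POST http://localhost:8000/predict \
#        -H "Content-Type: application/json" \
#        -d '{"mode": "normal", "text": ["Some passage to classify."]}'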
@app.get("/")
def serve_index() -> FileResponse:
    return FileResponse(path="/app/static/index.html", media_type="text/html")

# Mounted last so the static catch-all does not shadow the API routes above.
app.mount("/", StaticFiles(directory="static", html=True), name="static")
if __name__ == "__main__":
    port = int(os.getenv("PORT", 8000))
    uvicorn.run(app, host="0.0.0.0", port=port)
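# Local run, assuming this file is saved as app.py and the model weights plus
# core/seen_db are present (the file name is an assumption, not given here):
#   PORT=8000 python app.py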