File size: 2,712 Bytes
3fef185
 
34cfd03
 
3fef185
 
 
 
 
 
617c3f7
3fef185
d6fa263
 
3fef185
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1578895
3fef185
 
 
 
 
 
 
 
 
 
 
 
617c3f7
3fef185
 
 
 
 
 
 
 
617c3f7
3fef185
 
 
 
 
 
 
 
 
5b2f797
 
 
 
 
 
667fbf3
3fef185
 
617c3f7
 
 
 
 
 
 
2fffdc8
 
617c3f7
3fef185
b0b3c8b
 
 
 
 
 
3fef185
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from fastapi.middleware.cors import CORSMiddleware
import torch
from src.text_embedding import TextEmbeddingModel
from src.index import Indexer
import os
import pickle
from infer import infer_3_class, infer_model_specific
import uvicorn
from datasets import disable_caching
disable_caching()

app = FastAPI()

# NOTE(review): wildcard origins combined with allow_credentials=True is
# rejected by browsers for credentialed cross-origin requests and is overly
# permissive — consider pinning concrete origins for production.
origins = ["*"]

# Allow cross-origin requests from the web front end.
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

class Opt:
    """Static runtime configuration for the embedding model and vector DB."""

    def __init__(self):
        # HuggingFace model id used for text embeddings.
        self.model_name = "ngocminhta/faid-v1"
        # Directory holding the serialized index and the pickled lookup tables.
        self.database_path = "core/seen_db"
        # Dimensionality of the vectors stored in the index.
        self.embedding_dim = 768
        # GPU device ordinal — presumably consumed by the model/inference code;
        # not referenced in this file. TODO(review): confirm against src/infer.
        self.device_num = 1


opt = Opt()

def load_pkl(path):
    """Deserialize and return the pickled object stored at ``path``."""
    with open(path, "rb") as handle:
        return pickle.load(handle)
    
# NOTE(review): on_event("startup") is deprecated in recent FastAPI releases
# in favor of lifespan handlers — worth migrating when touching this area.
@app.on_event("startup")
def load_model_resources():
    """Populate module-level globals with the model, index, and lookup tables.

    Runs once at application startup so request handlers (``predict``) can
    reuse the loaded resources without re-initialising them per request.
    """
    global model, tokenizer, index, label_dict, is_mixed_dict, write_model_dict

    # Embedding model and its tokenizer.
    model = TextEmbeddingModel(opt.model_name)
    tokenizer = model.tokenizer

    # Vector index, restored from the on-disk database directory.
    index = Indexer(opt.embedding_dim)
    index.deserialize_from(opt.database_path)

    # Pickled lookup tables live alongside the serialized index.
    label_dict = load_pkl(os.path.join(opt.database_path, 'label_dict.pkl'))
    is_mixed_dict = load_pkl(os.path.join(opt.database_path, 'is_mixed_dict.pkl'))
    write_model_dict = load_pkl(os.path.join(opt.database_path, 'write_model_dict.pkl'))
    
    
# Fixed: was registered via the deprecated Starlette-style
# app.route('/predict', methods=['POST']); @app.post is the FastAPI idiom
# and keeps the endpoint visible in the OpenAPI schema.
@app.post('/predict')
async def predict(request: Request):
    """Run inference over a list of texts and return per-text results.

    Expects a JSON body of the form::

        {"mode": "normal" | "advanced", "text": ["...", ...]}

    * ``normal``   — 3-class inference via ``infer_3_class``.
    * ``advanced`` — model-specific inference via ``infer_model_specific``.

    Returns a JSON object ``{"results": ...}`` on success, or a 400 response
    for an unrecognized ``mode``. Relies on the globals loaded by
    ``load_model_resources`` at startup.
    """
    data = await request.json()
    mode = data.get("mode", "normal").lower()
    text_list = data.get("text", [])

    if mode == "normal":
        results = infer_3_class(model=model,
            tokenizer=tokenizer,
            index=index,
            label_dict=label_dict,
            is_mixed_dict=is_mixed_dict,
            text_list=text_list,
            K=21)
        return JSONResponse(content={"results": results})
    elif mode == "advanced":
        results = infer_model_specific(model=model,
            tokenizer=tokenizer,
            index=index,
            label_dict=label_dict,
            is_mixed_dict=is_mixed_dict,
            write_model_dict=write_model_dict,
            text_list=text_list,
            K=21,
            K_model=9)
        return JSONResponse(content={"results": results})
    # Fixed: an unknown mode previously fell off the end of the function and
    # returned None, which surfaced as a server error; reject it explicitly.
    return JSONResponse(status_code=400,
                        content={"error": f"unknown mode: {mode!r}"})

# NOTE(review): this mount is registered before the route below, so requests
# to "/" are answered by StaticFiles(html=True) (which serves index.html);
# the handler below only acts as a fallback.
app.mount("/", StaticFiles(directory="static", html=True), name="static")


# Fixed: the handler was named `index`, shadowing the module-level `index`
# global that the startup hook assigns (`global index` rebinds the name to
# the Indexer after import) — a fragile name collision. Renamed; the HTTP
# interface (GET /) is unchanged.
@app.get("/")
def serve_index() -> FileResponse:
    """Serve the front-end entry page (fallback behind the static mount)."""
    return FileResponse(path="/app/static/index.html", media_type="text/html")

if __name__ == "__main__":
    # Honor a platform-provided PORT (e.g. container runtimes); default 8000.
    listen_port = int(os.environ.get("PORT", "8000"))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)