"""HTTP service for the inference helpers in ``infer``.

Exposes ``POST /predict`` (backed by ``infer_3_class`` and
``infer_model_specific``) and serves the front end from ``static/``.
"""

import os
import pickle

import torch
import uvicorn
from datasets import disable_caching
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, JSONResponse
from fastapi.staticfiles import StaticFiles

from infer import infer_3_class, infer_model_specific
from src.index import Indexer
from src.text_embedding import TextEmbeddingModel

disable_caching()

# Values passed as ``K`` / ``K_model`` to the inference helpers.
TOP_K = 21
TOP_K_MODEL = 9

# Modes accepted by ``POST /predict``.
SUPPORTED_MODES = ("normal", "advanced")

app = FastAPI()

origins = ["*"]

# NOTE(review): a wildcard origin combined with ``allow_credentials=True`` is
# discouraged by the FastAPI CORS documentation; consider listing the explicit
# front-end origins here. Left unchanged to avoid breaking existing clients.
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class Opt:
    """Static configuration for the embedding model and the on-disk database."""

    def __init__(self):
        # Name handed to ``TextEmbeddingModel`` at startup.
        self.model_name = "ngocminhta/faid-v1"
        # Directory holding the serialized index and the ``*.pkl`` dictionaries.
        self.database_path = "core/seen_db"
        # Dimension handed to ``Indexer`` at startup.
        self.embedding_dim = 768
        # Not read anywhere in this file.
        self.device_num = 1


opt = Opt()

# Populated by ``load_model_resources`` when the application starts.
model = None
tokenizer = None
index = None
label_dict = None
is_mixed_dict = None
write_model_dict = None


def load_pkl(path):
    """Return the object unpickled from the file at *path*.

    NOTE(review): ``pickle.load`` can execute arbitrary code; only use this on
    database files that come from a trusted source.
    """
    with open(path, "rb") as f:
        return pickle.load(f)


@app.on_event("startup")
def load_model_resources():
    """Load the model, tokenizer, index and lookup dictionaries into globals."""
    global model, tokenizer, index, label_dict, is_mixed_dict, write_model_dict
    model = TextEmbeddingModel(opt.model_name)
    tokenizer = model.tokenizer
    index = Indexer(opt.embedding_dim)
    index.deserialize_from(opt.database_path)
    label_dict = load_pkl(os.path.join(opt.database_path, "label_dict.pkl"))
    is_mixed_dict = load_pkl(os.path.join(opt.database_path, "is_mixed_dict.pkl"))
    write_model_dict = load_pkl(
        os.path.join(opt.database_path, "write_model_dict.pkl")
    )


def _bad_request(message):
    """Build a 400 JSON response carrying *message* under the ``error`` key."""
    return JSONResponse(status_code=400, content={"error": message})


@app.post("/predict")
async def predict(request: Request) -> JSONResponse:
    """Run inference on the texts in the JSON request body.

    The body is a JSON object with:
      * ``mode``: ``"normal"`` (default) or ``"advanced"``, case-insensitive.
      * ``text``: a list of strings (default ``[]``); a single string is
        accepted and treated as a one-element list.

    Returns ``{"results": ...}`` with whatever the selected inference helper
    returns, or a 400 response with ``{"error": ...}`` when the body is not
    valid JSON, is not an object, or carries an unsupported ``mode``/``text``.
    """
    try:
        data = await request.json()
    except ValueError:
        # json.JSONDecodeError and UnicodeDecodeError both derive from ValueError.
        return _bad_request("Request body must be valid JSON.")
    if not isinstance(data, dict):
        return _bad_request("Request body must be a JSON object.")

    mode = data.get("mode", "normal")
    if not isinstance(mode, str) or mode.lower() not in SUPPORTED_MODES:
        # Previously an unknown mode fell through and returned None (a 500).
        return _bad_request('"mode" must be "normal" or "advanced".')
    mode = mode.lower()

    text_list = data.get("text", [])
    if isinstance(text_list, str):
        # Avoid handing a bare string to helpers that expect a list of texts.
        text_list = [text_list]
    if not isinstance(text_list, list) or not all(
        isinstance(item, str) for item in text_list
    ):
        return _bad_request('"text" must be a string or a list of strings.')

    if mode == "normal":
        results = infer_3_class(
            model=model,
            tokenizer=tokenizer,
            index=index,
            label_dict=label_dict,
            is_mixed_dict=is_mixed_dict,
            text_list=text_list,
            K=TOP_K,
        )
    else:
        results = infer_model_specific(
            model=model,
            tokenizer=tokenizer,
            index=index,
            label_dict=label_dict,
            is_mixed_dict=is_mixed_dict,
            write_model_dict=write_model_dict,
            text_list=text_list,
            K=TOP_K,
            K_model=TOP_K_MODEL,
        )
    return JSONResponse(content={"results": results})


# NOTE(review): this mount is registered at "/" before the GET "/" route below.
# Routes are matched in registration order, so the mount (with ``html=True``)
# presumably answers "/" itself and the handler below looks unreachable --
# confirm, then either drop the handler or register it before the mount.
app.mount("/", StaticFiles(directory="static", html=True), name="static")


@app.get("/")
def serve_index() -> FileResponse:
    """Return the front-end entry page.

    Named ``serve_index`` rather than ``index`` so it does not share a name
    with the module-level ``index`` object assigned in ``load_model_resources``.
    """
    return FileResponse(path="/app/static/index.html", media_type="text/html")


if __name__ == "__main__":
    port = int(os.getenv("PORT", 8000))
    uvicorn.run(app, host="0.0.0.0", port=port)