AidMateLLM / app.py
taha454's picture
Update app.py
39751d8 verified
from fastapi import FastAPI
from pydantic import BaseModel
from typing import List, Dict
from Embedder.E5_Embeddedr import E5_Embeddedr
from Models.Utils import *
from Models.Prompts import *
from OLAP_Conn.DuckConn import DuckConn
from RAG.RAG_Retrival import RAG_Retrival
import sys
from contextlib import asynccontextmanager
######################################################
####----------------PARAMETER CLASSES-----------------
######################################################
class Message(BaseModel):
role: str
content: str
class ModelParameter(BaseModel):
model: str
max_token: int
temperature: float
######################################################
####-------------------DEFINATIONS--------------------
######################################################
PATH_DUCK = "Data.duckdb"
db = None
model = None
embedder =None
rag_retriv = None
######################################################
####-----------------STARTUP EVENTS-------------------
######################################################
@asynccontextmanager
async def lifespan(app: FastAPI):
global db, embedder
db = DuckConn(PATH_DUCK)
embedder = E5_Embeddedr()
yield
app = FastAPI(lifespan=lifespan)
####################################################
####--------------------ROUTES----------------------
####################################################
@app.post("/chat")
async def chat(messages: List[Message],parameters:ModelParameter):
model = get_specific_model(parameters.model)
model.set_config(temperature=parameters.temperature,max_tokens=parameters.max_token)
rag_retriv = RAG_Retrival(db, model, embedder)
# Convert Pydantic objects to dict
messages_data = [msg.model_dump() for msg in messages]
prompt = messages_data[0]['content']
relevant_queures= rag_retriv.query_relevant(prompt)
relevant_queures = ''.join(relevant_queures)
final_queury = [message_user(final_prompt(prompt,relevant_queures))]
model_answer = model.send_message(final_queury)
return {"status": "success", "response": model_answer}