from fastapi import FastAPI
from pydantic import BaseModel
from typing import List
from contextlib import asynccontextmanager
from Embedder.E5_Embeddedr import E5_Embeddedr
from Models.Utils import *
from Models.Prompts import *
from OLAP_Conn.DuckConn import DuckConn
from RAG.RAG_Retrival import RAG_Retrival
######################################################
####----------------PARAMETER CLASSES-----------------
######################################################
class Message(BaseModel):
    """A single chat message with its sender role."""
    role: str
    content: str

class ModelParameter(BaseModel):
    """Generation settings forwarded to the selected model."""
    model: str
    max_token: int
    temperature: float
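# Example request body the models above validate (illustrative sketch; the
# accepted values for "model" depend on get_specific_model and are assumed here):
# {
#     "messages":   [{"role": "user", "content": "How many rows match ...?"}],
#     "parameters": {"model": "<model-name>", "max_token": 512, "temperature": 0.2}
# }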
######################################################
####-------------------DEFINITIONS--------------------
######################################################
PATH_DUCK = "Data.duckdb"

# Shared resources initialized once at startup (see lifespan below).
db = None
embedder = None
######################################################
####-----------------STARTUP EVENTS-------------------
######################################################
@asynccontextmanager
async def lifespan(app: FastAPI):
    # Open the DuckDB connection and load the embedder once, before the app
    # starts serving requests; both are reused across all requests.
    global db, embedder
    db = DuckConn(PATH_DUCK)
    embedder = E5_Embeddedr()
    yield

app = FastAPI(lifespan=lifespan)
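# Minimal sketch of serving the app (the module name "main" and port 8000 are
# assumptions, not taken from the repo):
#   uvicorn main:app --host 0.0.0.0 --port 8000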
####################################################
####--------------------ROUTES----------------------
####################################################
@app.post("/chat")  # the "/chat" path is assumed; it is not specified in the source
async def chat(messages: List[Message], parameters: ModelParameter):
    # Configure the requested model with the caller's generation settings.
    model = get_specific_model(parameters.model)
    model.set_config(temperature=parameters.temperature, max_tokens=parameters.max_token)
    rag_retriv = RAG_Retrival(db, model, embedder)
    # Convert Pydantic objects to dicts; the first message carries the user prompt.
    messages_data = [msg.model_dump() for msg in messages]
    prompt = messages_data[0]['content']
    # Retrieve context relevant to the prompt and fold it into the final prompt.
    relevant_queries = rag_retriv.query_relevant(prompt)
    relevant_queries = ''.join(relevant_queries)
    final_query = [message_user(final_prompt(prompt, relevant_queries))]
    model_answer = model.send_message(final_query)
    return {"status": "success", "response": model_answer}