from fastapi import FastAPI
from pydantic import BaseModel
from typing import List
from Embedder.E5_Embeddedr import E5_Embeddedr
from Models.Utils import *
from Models.Prompts import *
from OLAP_Conn.DuckConn import DuckConn
from RAG.RAG_Retrival import RAG_Retrival
from contextlib import asynccontextmanager




######################################################
####----------------PARAMETER CLASSES-----------------
######################################################


class Message(BaseModel):
    """A single chat message with the sender's role and its text content."""
    role: str
    content: str

class ModelParameter(BaseModel):
    """Generation settings for the model selected by the client."""
    model: str
    max_token: int
    temperature: float

######################################################
####-------------------DEFINITIONS--------------------
######################################################

PATH_DUCK = "Data.duckdb"

# Shared resources, initialised once at startup in lifespan().
db = None
embedder = None



######################################################
####-----------------STARTUP EVENTS-------------------
######################################################



@asynccontextmanager
async def lifespan(app: FastAPI):
    # Open the DuckDB connection and load the embedder once at startup
    # so every request can reuse them.
    global db, embedder
    db = DuckConn(PATH_DUCK)
    embedder = E5_Embeddedr()
    yield

app = FastAPI(lifespan=lifespan)
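
# A minimal sketch of serving the app; the module name "main" is an
# assumption based on common FastAPI layouts, not something this file pins down:
#   uvicorn main:app --host 0.0.0.0 --port 8000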

####################################################
####--------------------ROUTES----------------------
####################################################

@app.post("/chat")
async def chat(messages: List[Message], parameters: ModelParameter):
    # Build the model and retriever per request so each call can pick its own
    # model name and generation settings.
    model = get_specific_model(parameters.model)
    model.set_config(temperature=parameters.temperature, max_tokens=parameters.max_token)
    rag_retriv = RAG_Retrival(db, model, embedder)

    # Convert the Pydantic objects to plain dicts.
    messages_data = [msg.model_dump() for msg in messages]

    prompt = messages_data[0]['content']

    # Retrieve relevant context and wrap it into the final user message.
    relevant_queries = rag_retriv.query_relevant(prompt)
    relevant_queries = ''.join(relevant_queries)
    final_query = [message_user(final_prompt(prompt, relevant_queries))]
    model_answer = model.send_message(final_query)

    return {"status": "success", "response": model_answer}
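

# A sketch of a request against /chat, assuming the uvicorn command above and
# a model name that get_specific_model() recognises (the name "gpt-4o" here is
# an illustrative assumption). FastAPI expects both body parameters under their
# parameter names:
#
#   curl -X POST http://localhost:8000/chat \
#        -H "Content-Type: application/json" \
#        -d '{"messages": [{"role": "user", "content": "Summarise Q3 sales."}],
#             "parameters": {"model": "gpt-4o", "max_token": 512, "temperature": 0.2}}'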