File size: 3,916 Bytes
e53c2d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
from app.models import LocalLLM, Embedder, Reranker, Gemini
from app.processor import DocumentProcessor
from app.database import VectorDatabase
import time
import os
from app.settings import reranker_model, embedder_model, base_path, use_gemini


# TODO: write a better prompt
# TODO: wrap original(user's) prompt with LLM's one
#
class RagSystem:
    def __init__(self):
        self.embedder = Embedder(model=embedder_model)
        self.reranker = Reranker(model=reranker_model)
        self.processor = DocumentProcessor()
        self.db = VectorDatabase(embedder=self.embedder)
        self.llm = Gemini() if use_gemini else LocalLLM()

    '''
    Provides a prompt with substituted context from chunks

    TODO: add template to prompt without docs
    '''

    def get_prompt_template(self, user_prompt: str, chunks: list) -> str:
        sources = ""
        prompt = ""

        for chunk in chunks:
            citation = (f"[Source: {chunk.filename}, "
                        f"Page: {chunk.page_number}, "
                        f"Lines: {chunk.start_line}-{chunk.end_line}, "
                        f"Start: {chunk.start_index}]\n\n")
            sources += f"Original text:\n{chunk.get_raw_text()}\nCitation:{citation}"

        with open(os.path.join(base_path, "prompt_templates", "test2.txt")) as f:
            prompt = f.read()

        prompt += (
            "**QUESTION**: "
            f"{user_prompt.strip()}\n"
            "**CONTEXT DOCUMENTS**:\n"
            f"{sources}\n"
        )

        return prompt

    '''
    Splits the list of documents into groups with 'split_by' docs (done to avoid qdrant_client connection error handling), loads them,
    splits into chunks, and saves to db
    '''

    def upload_documents(self, documents: list[str], split_by: int = 3, debug_mode: bool = True) -> None:

        for i in range(0, len(documents), split_by):

            if debug_mode:
                print("<" + "-" * 10 + "New document group is taken into processing" + "-" * 10 + ">")

            docs = documents[i: i + split_by]

            loading_time = 0
            chunk_generating_time = 0
            db_saving_time = 0

            print("Start loading the documents")
            start = time.time()
            self.processor.load_documents(documents=docs, add_to_unprocessed=True)
            loading_time = time.time() - start

            print("Start loading chunk generation")
            start = time.time()
            self.processor.generate_chunks()
            chunk_generating_time = time.time() - start

            print("Start saving to db")
            start = time.time()
            self.db.store(self.processor.get_and_save_unsaved_chunks())
            db_saving_time = time.time() - start

            if debug_mode:
                print(
                    f"loading time = {loading_time}, chunk generation time = {chunk_generating_time}, saving time = {db_saving_time}\n")

    '''
    Produces answer to user's request. First, finds the most relevant chunks, generates prompt with them, and asks llm
    '''

    def generate_response(self, user_prompt: str) -> str:
        relevant_chunks = self.db.search(query=user_prompt, top_k=15)
        relevant_chunks = [relevant_chunks[ranked["corpus_id"]]
                           for ranked in self.reranker.rank(query=user_prompt, chunks=relevant_chunks)[:3]]

        general_prompt = self.get_prompt_template(user_prompt=user_prompt, chunks=relevant_chunks)
        return self.llm.get_response(prompt=general_prompt)

    '''
    Produces the list of the most relevant chunkВs
    '''

    def get_relevant_chunks(self, query):
        relevant_chunks = self.db.search(query=query, top_k=15)
        relevant_chunks = [relevant_chunks[ranked["corpus_id"]]
                           for ranked in self.reranker.rank(query=query, chunks=relevant_chunks)]
        return relevant_chunks