File size: 6,108 Bytes
48ec4db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
from app.core.models import Reranker, GeminiLLM, GeminiEmbed, Wrapper
from app.settings import settings, BASE_DIR, logger
from app.core.processor import DocumentProcessor
from app.core.database import VectorDatabase
from typing import Any, AsyncGenerator
import aiofiles
import asyncio
import os


class RagSystem:
    """End-to-end retrieval-augmented generation (RAG) pipeline.

    Wires together an embedder, a vector database, a cross-encoder
    reranker, a prompt wrapper and a Gemini LLM: documents are uploaded
    and chunked into a named collection, then questions are answered by
    retrieving, reranking and injecting the most relevant chunks into a
    prompt template.
    """

    # Maximum number of reranked chunks injected into the LLM prompt context.
    CONTEXT_CHUNK_LIMIT = 10

    def __init__(self):
        self.embedder = GeminiEmbed()
        self.reranker = Reranker(model=settings.models.reranker_model)
        # The database shares the embedder so queries and stored chunks
        # live in the same vector space.
        self.db = VectorDatabase(embedder=self.embedder)
        self.llm = GeminiLLM()
        self.wrapper = Wrapper()
        self.processor = DocumentProcessor()

    async def get_general_prompt(self, user_prompt: str, collection_name: str) -> str:
        """Build the full LLM prompt for *user_prompt*.

        Enhances the raw prompt, retrieves the top-30 chunks from
        *collection_name*, reranks them, and renders up to
        ``CONTEXT_CHUNK_LIMIT`` of them (with citations) into the
        ``test2.txt`` prompt template.  Each stage's wall time is logged.
        """
        loop = asyncio.get_event_loop()
        start = loop.time()
        await logger.info(f"Time of initializing - {loop.time() - start}")

        start = loop.time()
        enhanced_prompt = await self.enhance_prompt(user_prompt.strip())
        await logger.info(f"Time of enhancing - {loop.time() - start}")

        start = loop.time()
        relevant_chunks = await self.db.search(collection_name, query=enhanced_prompt, top_k=30)
        await logger.info(f"Time of searching - {loop.time() - start}")

        start = loop.time()
        if relevant_chunks:
            # Reorder the retrieved chunks by cross-encoder relevance;
            # "corpus_id" indexes back into the original candidate list.
            ranks = await self.reranker.rank(query=enhanced_prompt, chunks=relevant_chunks)
            relevant_chunks = [relevant_chunks[rank["corpus_id"]] for rank in ranks]
        else:
            # Search returned None or nothing — proceed with an empty context.
            relevant_chunks = []

        sources = ""
        # Slicing already clamps to the list length, so no min() is needed.
        for chunk in relevant_chunks[: self.CONTEXT_CHUNK_LIMIT]:
            citation = (
                f"[Source: {chunk.filename}, "
                f"Page: {chunk.page_number}, "
                f"Lines: {chunk.start_line}-{chunk.end_line}, "
                f"Start: {chunk.start_index}]\n\n"
            )
            sources += f"Original text:\n{await chunk.get_raw_text()}\nCitation:{citation}"

        await logger.info(f"Time of reranking - {loop.time() - start}")

        start = loop.time()
        async with aiofiles.open(
            os.path.join(BASE_DIR, "app", "prompt_templates", "test2.txt")
        ) as prompt_file:
            prompt = await prompt_file.read()

        prompt += (
            "**QUESTION**: "
            f"{enhanced_prompt}\n"
            "**CONTEXT DOCUMENTS**:\n"
            f"{sources}\n"
        )
        await logger.info(f"Time of preparing prompt - {loop.time() - start}")

        return prompt

    async def enhance_prompt(self, original_prompt: str) -> str:
        """Rewrite *original_prompt* via the wrapper model.

        Substitutes the user's text into the ``wrapper.txt`` template and
        asks the Wrapper model to produce the enhanced query.
        """
        path_to_wrapping_prompt = os.path.join(BASE_DIR, "app", "prompt_templates", "wrapper.txt")
        async with aiofiles.open(path_to_wrapping_prompt, "r") as f:
            enhanced_prompt = (await f.read()).replace("[USERS_PROMPT]", original_prompt)
        return await self.wrapper.wrap(enhanced_prompt)

    async def upload_documents(self, collection_name: str, documents: list[str], split_by: int = 3) -> None:
        """Ingest *documents* into *collection_name* in groups of *split_by*.

        For each group: load the documents, generate chunks, then persist
        the not-yet-saved chunks to the vector database.  Per-stage timings
        are logged when ``settings.debug`` is on.
        """
        loop = asyncio.get_event_loop()
        for i in range(0, len(documents), split_by):

            if settings.debug:
                await logger.info("New document group is taken into processing")

            docs = documents[i : i + split_by]

            if settings.debug:
                await logger.info("Start loading the documents")

            start = loop.time()
            await self.processor.load_documents(documents=docs)
            loading_time = loop.time() - start

            if settings.debug:
                await logger.info("Start loading chunk generation")

            start = loop.time()
            await self.processor.generate_chunks()
            chunk_generating_time = loop.time() - start

            if settings.debug:
                await logger.info("Start saving to db")

            start = loop.time()
            chunks = await self.processor.get_and_save_unsaved_chunks()
            await self.db.store(collection_name, chunks)
            db_saving_time = loop.time() - start

            if settings.debug:
                await logger.info(
                    f"loading time = {loading_time}, chunk generation time = {chunk_generating_time}, saving time = {db_saving_time}\n"
                )

    async def extract_text(self, response) -> str:
        """Pull the plain text out of a Gemini response chunk.

        Returns "" when the expected ``candidates[0].content.parts[0].text``
        path is missing (e.g. an empty or safety-blocked candidate).
        """
        try:
            return response.candidates[0].content.parts[0].text
        except (AttributeError, IndexError, TypeError) as e:
            # Narrowed from a bare `except Exception`: only the failures the
            # attribute/index chain can actually produce are swallowed here.
            print(e)
            return ""

    async def generate_response(self, collection_name: str, user_prompt: str, stream: bool = True) -> str:
        """Answer *user_prompt* against *collection_name* in one shot.

        *stream* is currently unused; it is kept for interface
        compatibility with existing callers.
        """
        general_prompt = await self.get_general_prompt(
            user_prompt=user_prompt, collection_name=collection_name
        )

        # NOTE(review): get_response is returned without `await` — confirm it
        # is a synchronous call on GeminiLLM, otherwise this returns a coroutine.
        return self.llm.get_response(prompt=general_prompt)

    async def generate_response_stream(self, collection_name: str, user_prompt: str, stream: bool = True) -> AsyncGenerator[Any, Any]:
        """Yield the LLM answer to *user_prompt* chunk by chunk.

        *stream* is currently unused; it is kept for interface
        compatibility with existing callers.
        """
        loop = asyncio.get_event_loop()
        start = loop.time()
        general_prompt = await self.get_general_prompt(
            user_prompt=user_prompt, collection_name=collection_name
        )
        # Fix: this call was missing `await`, unlike every other logger call
        # in this module, so the timing line was silently never emitted.
        await logger.info(f"Time of getting prompt message - {loop.time() - start}")

        async for chunk in self.llm.get_streaming_response(
            prompt=general_prompt
        ):
            yield await self.extract_text(chunk)

    async def get_relevant_chunks(self, collection_name: str, query):
        """Return the top-15 chunks for *query*, reranked by relevance.

        Fix: guards against an empty/None search result (consistent with
        get_general_prompt) instead of crashing the reranker.
        """
        relevant_chunks = await self.db.search(collection_name, query=query, top_k=15)
        if not relevant_chunks:
            return []
        ranks = await self.reranker.rank(query=query, chunks=relevant_chunks)
        return [relevant_chunks[ranked["corpus_id"]] for ranked in ranks]

    async def create_new_collection(self, collection_name: str) -> None:
        """Create an empty collection named *collection_name*."""
        await self.db.create_collection(collection_name)

    async def get_collections_names(self) -> list[str]:
        """Return the names of all existing collections."""
        return await self.db.get_collections()