import asyncio
import os
from collections.abc import AsyncGenerator

import aiofiles

from app.core.database import VectorDatabase
from app.core.models import GeminiEmbed, GeminiLLM, Reranker, Wrapper
from app.core.processor import DocumentProcessor
from app.settings import BASE_DIR, logger, settings


class RagSystem:
    """End-to-end RAG pipeline: embed, store, retrieve, rerank, and generate."""

    def __init__(self):
        self.embedder = GeminiEmbed()
        self.reranker = Reranker(model=settings.models.reranker_model)
        self.db = VectorDatabase(embedder=self.embedder)
        self.llm = GeminiLLM()
        self.wrapper = Wrapper()
        self.processor = DocumentProcessor()
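
    # get_general_prompt drives the read path of the pipeline: enhance the
    # user's prompt, vector-search the collection, rerank the hits, then
    # splice the top chunks and their citations into the prompt template.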
    async def get_general_prompt(self, user_prompt: str, collection_name: str) -> str:
        """Build the full LLM prompt: enhanced question plus reranked context."""
        loop = asyncio.get_running_loop()

        start = loop.time()
        enhanced_prompt = await self.enhance_prompt(user_prompt.strip())
        await logger.info(f"Time of enhancing - {loop.time() - start}")

        start = loop.time()
        relevant_chunks = await self.db.search(collection_name, query=enhanced_prompt, top_k=30)
        await logger.info(f"Time of searching - {loop.time() - start}")

        start = loop.time()
        if relevant_chunks:
            ranks = await self.reranker.rank(query=enhanced_prompt, chunks=relevant_chunks)
            relevant_chunks = [relevant_chunks[rank["corpus_id"]] for rank in ranks]
        else:
            relevant_chunks = []

        # Keep only the 10 best-ranked chunks; format each with its citation.
        sources = ""
        for chunk in relevant_chunks[:10]:
            citation = (
                f"[Source: {chunk.filename}, "
                f"Page: {chunk.page_number}, "
                f"Lines: {chunk.start_line}-{chunk.end_line}, "
                f"Start: {chunk.start_index}]\n\n"
            )
            raw_text = await chunk.get_raw_text()
            sources += f"Original text:\n{raw_text}\nCitation: {citation}"
        await logger.info(f"Time of reranking - {loop.time() - start}")

        start = loop.time()
        async with aiofiles.open(
            os.path.join(BASE_DIR, "app", "prompt_templates", "test2.txt")
        ) as prompt_file:
            prompt = await prompt_file.read()
        prompt += (
            "**QUESTION**: "
            f"{enhanced_prompt}\n"
            "**CONTEXT DOCUMENTS**:\n"
            f"{sources}\n"
        )
        await logger.info(f"Time of preparing prompt - {loop.time() - start}")
        return prompt

    async def enhance_prompt(self, original_prompt: str) -> str:
        """Rewrite the user's prompt via the wrapper template and Wrapper model."""
        path_to_wrapping_prompt = os.path.join(BASE_DIR, "app", "prompt_templates", "wrapper.txt")
        async with aiofiles.open(path_to_wrapping_prompt, "r") as f:
            # The template carries a [USERS_PROMPT] placeholder that is
            # substituted with the raw user prompt before wrapping.
            enhanced_prompt = (await f.read()).replace("[USERS_PROMPT]", original_prompt)
        return await self.wrapper.wrap(enhanced_prompt)

    async def upload_documents(self, collection_name: str, documents: list[str], split_by: int = 3) -> None:
        """Ingest documents in groups of `split_by`: load, chunk, and store."""
        loop = asyncio.get_running_loop()
        for i in range(0, len(documents), split_by):
            if settings.debug:
                await logger.info("New document group is taken into processing")
            docs = documents[i : i + split_by]

            if settings.debug:
                await logger.info("Start loading the documents")
            start = loop.time()
            await self.processor.load_documents(documents=docs)
            loading_time = loop.time() - start

            if settings.debug:
                await logger.info("Start chunk generation")
            start = loop.time()
            await self.processor.generate_chunks()
            chunk_generating_time = loop.time() - start

            if settings.debug:
                await logger.info("Start saving to db")
            start = loop.time()
            chunks = await self.processor.get_and_save_unsaved_chunks()
            await self.db.store(collection_name, chunks)
            db_saving_time = loop.time() - start

            if settings.debug:
                await logger.info(
                    f"loading time = {loading_time}, "
                    f"chunk generation time = {chunk_generating_time}, "
                    f"saving time = {db_saving_time}\n"
                )
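
    # The attribute chain below (candidates[0].content.parts[0].text) matches
    # the google-generativeai response shape, assuming GeminiLLM yields raw
    # responses; a safety-blocked or empty candidate breaks any link in that
    # chain, hence the broad guard.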
    async def extract_text(self, response) -> str:
        """Pull the plain text out of a model response, or "" if absent."""
        try:
            return response.candidates[0].content.parts[0].text
        except Exception as e:
            await logger.error(f"Failed to extract text from response: {e}")
            return ""

    async def generate_response(self, collection_name: str, user_prompt: str) -> str:
        """Answer a prompt in one shot (non-streaming)."""
        general_prompt = await self.get_general_prompt(
            user_prompt=user_prompt, collection_name=collection_name
        )
        return self.llm.get_response(prompt=general_prompt)

    async def generate_response_stream(self, collection_name: str, user_prompt: str) -> AsyncGenerator[str, None]:
        """Stream the answer chunk by chunk, yielding plain text pieces."""
        loop = asyncio.get_running_loop()
        start = loop.time()
        general_prompt = await self.get_general_prompt(
            user_prompt=user_prompt, collection_name=collection_name
        )
        await logger.info(f"Time of getting prompt message - {loop.time() - start}")
        async for chunk in self.llm.get_streaming_response(prompt=general_prompt):
            yield await self.extract_text(chunk)

    async def get_relevant_chunks(self, collection_name: str, query: str):
        """Search a collection and return its chunks reordered by reranker score."""
        relevant_chunks = await self.db.search(collection_name, query=query, top_k=15)
        if not relevant_chunks:
            return []
        ranks = await self.reranker.rank(query=query, chunks=relevant_chunks)
        return [relevant_chunks[rank["corpus_id"]] for rank in ranks]

    async def create_new_collection(self, collection_name: str) -> None:
        await self.db.create_collection(collection_name)

    async def get_collections_names(self) -> list[str]:
        return await self.db.get_collections()
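

# A minimal usage sketch (hypothetical: the collection name, document path,
# and configured API credentials below are assumptions, not part of this module).
if __name__ == "__main__":

    async def _demo() -> None:
        rag = RagSystem()
        await rag.create_new_collection("demo")
        await rag.upload_documents("demo", documents=["docs/report.pdf"])
        async for piece in rag.generate_response_stream(
            collection_name="demo", user_prompt="Summarize the report."
        ):
            print(piece, end="", flush=True)

    asyncio.run(_demo())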