# NOTE(review): removed extraction artifacts ("Spaces:" / "Running" / "Running")
# that preceded the module — they were not valid Python.
import asyncio
import os
from typing import Any, AsyncGenerator

import aiofiles

from app.core.database import VectorDatabase
from app.core.models import Reranker, GeminiLLM, GeminiEmbed, Wrapper
from app.core.processor import DocumentProcessor
from app.settings import settings, BASE_DIR, logger
class RagSystem:
    """Retrieval-augmented generation pipeline.

    Ties together embedding, vector storage, reranking, prompt templating and
    LLM generation for a document question-answering workflow.
    """

    def __init__(self):
        # The embedder is shared with the vector database so that queries and
        # stored chunks live in the same embedding space.
        self.embedder = GeminiEmbed()
        self.reranker = Reranker(model=settings.models.reranker_model)
        self.db = VectorDatabase(embedder=self.embedder)
        self.llm = GeminiLLM()
        self.wrapper = Wrapper()
        self.processor = DocumentProcessor()

    async def get_general_prompt(self, user_prompt: str, collection_name: str) -> str:
        """Build the full LLM prompt for *user_prompt*.

        The prompt is enhanced, *collection_name* is searched, hits are
        reranked, and the top 10 chunks (with source citations) are appended
        to the template read from test2.txt.
        """
        # get_running_loop() is the non-deprecated form inside a coroutine;
        # loop.time() provides a monotonic clock for the stage timings below.
        loop = asyncio.get_running_loop()

        start = loop.time()
        enhanced_prompt = await self.enhance_prompt(user_prompt.strip())
        await logger.info(f"Time of enhancing - {loop.time() - start}")

        start = loop.time()
        relevant_chunks = await self.db.search(collection_name, query=enhanced_prompt, top_k=30)
        await logger.info(f"Time of searching - {loop.time() - start}")

        start = loop.time()
        if relevant_chunks:
            # Reorder search hits by reranker score; "corpus_id" indexes back
            # into the candidate list handed to the reranker.
            ranks = await self.reranker.rank(query=enhanced_prompt, chunks=relevant_chunks)
            relevant_chunks = [relevant_chunks[rank["corpus_id"]] for rank in ranks]
        else:
            relevant_chunks = []

        source_parts = []
        for chunk in relevant_chunks[:10]:  # slicing already clamps to len()
            citation = (
                f"[Source: {chunk.filename}, "
                f"Page: {chunk.page_number}, "
                f"Lines: {chunk.start_line}-{chunk.end_line}, "
                f"Start: {chunk.start_index}]\n\n"
            )
            source_parts.append(
                f"Original text:\n{await chunk.get_raw_text()}\nCitation:{citation}"
            )
        sources = "".join(source_parts)
        await logger.info(f"Time of reranking - {loop.time() - start}")

        start = loop.time()
        async with aiofiles.open(
            os.path.join(BASE_DIR, "app", "prompt_templates", "test2.txt")
        ) as prompt_file:
            prompt = await prompt_file.read()
        prompt += (
            "**QUESTION**: "
            f"{enhanced_prompt}\n"
            "**CONTEXT DOCUMENTS**:\n"
            f"{sources}\n"
        )
        await logger.info(f"Time of preparing prompt - {loop.time() - start}")
        return prompt

    async def enhance_prompt(self, original_prompt: str) -> str:
        """Substitute *original_prompt* into the wrapper template and run the
        result through the prompt-wrapping model."""
        path_to_wrapping_prompt = os.path.join(BASE_DIR, "app", "prompt_templates", "wrapper.txt")
        async with aiofiles.open(path_to_wrapping_prompt, "r") as f:
            enhanced_prompt = (await f.read()).replace("[USERS_PROMPT]", original_prompt)
        return await self.wrapper.wrap(enhanced_prompt)

    async def upload_documents(self, collection_name: str, documents: list[str], split_by: int = 3) -> None:
        """Load, chunk and store *documents* into *collection_name*.

        Documents are processed in groups of *split_by* to bound memory use;
        per-stage timings are logged when settings.debug is enabled.
        """
        loop = asyncio.get_running_loop()
        for i in range(0, len(documents), split_by):
            if settings.debug:
                await logger.info("New document group is taken into processing")
            docs = documents[i : i + split_by]

            if settings.debug:
                await logger.info("Start loading the documents")
            start = loop.time()
            await self.processor.load_documents(documents=docs)
            loading_time = loop.time() - start

            if settings.debug:
                await logger.info("Start loading chunk generation")
            start = loop.time()
            await self.processor.generate_chunks()
            chunk_generating_time = loop.time() - start

            if settings.debug:
                await logger.info("Start saving to db")
            start = loop.time()
            chunks = await self.processor.get_and_save_unsaved_chunks()
            await self.db.store(collection_name, chunks)
            db_saving_time = loop.time() - start

            if settings.debug:
                await logger.info(
                    f"loading time = {loading_time}, chunk generation time = {chunk_generating_time}, saving time = {db_saving_time}\n"
                )

    async def extract_text(self, response) -> str:
        """Best-effort extraction of the first text part of an LLM response.

        Returns "" when the response has no usable candidate/content/part
        (best-effort by design: a malformed chunk must not break a stream).
        """
        try:
            return response.candidates[0].content.parts[0].text
        except (AttributeError, IndexError) as e:
            # Narrowed from a bare `except Exception: print(e)` — only the
            # attribute/index access above can fail, and failures go to the
            # application logger instead of stdout.
            await logger.info(f"Failed to extract text from LLM response: {e}")
            return ""

    async def generate_response(self, collection_name: str, user_prompt: str, stream: bool = True) -> str:
        """Return the LLM response for *user_prompt* over *collection_name*.

        *stream* is accepted for interface compatibility but unused here; use
        generate_response_stream for streaming.
        """
        general_prompt = await self.get_general_prompt(
            user_prompt=user_prompt, collection_name=collection_name
        )
        # NOTE(review): get_response is not awaited — presumably synchronous;
        # confirm against GeminiLLM before relying on the `-> str` annotation.
        return self.llm.get_response(prompt=general_prompt)

    async def generate_response_stream(self, collection_name: str, user_prompt: str, stream: bool = True) -> AsyncGenerator[Any, Any]:
        """Yield extracted text pieces of a streaming LLM response."""
        loop = asyncio.get_running_loop()
        start = loop.time()
        general_prompt = await self.get_general_prompt(
            user_prompt=user_prompt, collection_name=collection_name
        )
        # Fix: this logger call previously lacked `await`, so the log
        # coroutine was created and silently dropped.
        await logger.info(f"Time of getting prompt message - {loop.time() - start}")
        async for chunk in self.llm.get_streaming_response(prompt=general_prompt):
            yield await self.extract_text(chunk)

    async def get_relevant_chunks(self, collection_name: str, query):
        """Search *collection_name* and return chunks ordered by reranker score."""
        relevant_chunks = await self.db.search(collection_name, query=query, top_k=15)
        if not relevant_chunks:
            # Mirror get_general_prompt's guard: an empty or None search
            # result must not reach the reranker.
            return []
        ranks = await self.reranker.rank(query=query, chunks=relevant_chunks)
        return [relevant_chunks[rank["corpus_id"]] for rank in ranks]

    async def create_new_collection(self, collection_name: str) -> None:
        """Create an empty collection named *collection_name*."""
        await self.db.create_collection(collection_name)

    async def get_collections_names(self) -> list[str]:
        """Return the names of all existing collections."""
        return await self.db.get_collections()