#!/usr/bin/env python3
"""
RAG (Retrieval Augmented Generation) System
-------------------------------------------
This module implements a RAG system that processes PDF documents,
uses ChromaDB as a vector database, sentence-transformers for embeddings,
and Google's Gemini as the main LLM. The system follows a conversational pattern.
"""
import os
import logging
from typing import List, Dict, Any, Optional
# Document processing
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
# Embeddings
from sentence_transformers import SentenceTransformer
# Vector database
import chromadb
from chromadb.utils import embedding_functions
# For Gemini LLM integration
from gemini_wrapper import GoogleGeminiWrapper
from gtts import gTTS
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
class RAGSystem:
"""
A Retrieval Augmented Generation system that processes PDF documents,
stores their embeddings in a vector database, and generates responses
using the Google Gemini model.
"""
def __init__(
self,
pdf_dir: str,
gemini_api_key: str,
embedding_model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
chunk_size: int = 1000,
chunk_overlap: int = 200,
db_directory: str = "./chroma_db"
):
"""
Initialize the RAG system.
Args:
pdf_dir: Directory containing PDF documents
gemini_api_key: API key for Google Gemini
embedding_model_name: Name of the sentence-transformers model
chunk_size: Size of text chunks for splitting documents
chunk_overlap: Overlap between consecutive chunks
db_directory: Directory to store the ChromaDB database
"""
self.pdf_dir = pdf_dir
self.chunk_size = chunk_size
self.chunk_overlap = chunk_overlap
self.db_directory = db_directory
        # Initialize the embedding model (kept for direct use; note that the
        # ChromaDB embedding function below loads its own copy of this model)
        logger.info(f"Loading embedding model: {embedding_model_name}")
        self.embedding_model = SentenceTransformer(embedding_model_name)
# Initialize the text splitter
self.text_splitter = RecursiveCharacterTextSplitter(
chunk_size=self.chunk_size,
chunk_overlap=self.chunk_overlap,
)
# Initialize ChromaDB
logger.info(f"Initializing ChromaDB at {db_directory}")
self.client = chromadb.PersistentClient(path=db_directory)
# Create a custom embedding function that uses sentence-transformers
self.sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(
model_name=embedding_model_name
)
# Create or get the collection
self.collection = self.client.get_or_create_collection(
name="pdf_documents",
embedding_function=self.sentence_transformer_ef
)
# Initialize the Gemini LLM
logger.info("Initializing Google Gemini")
self.llm = GoogleGeminiWrapper(api_key=gemini_api_key)
# Load conversation history
self.conversation_history = []
def process_documents(self) -> None:
"""
Process all PDF documents in the specified directory,
split them into chunks, generate embeddings, and store in ChromaDB.
"""
logger.info(f"Processing documents from: {self.pdf_dir}")
        # Skip reprocessing if the collection is already populated
        existing_count = self.collection.count()
        if existing_count > 0:
            logger.info(f"Found {existing_count} existing document chunks in the database")
            return
        # Collect PDF files in the directory (case-insensitive extension match)
        pdf_files = [f for f in os.listdir(self.pdf_dir) if f.lower().endswith('.pdf')]
        if not pdf_files:
            logger.warning(f"No PDF files found in {self.pdf_dir}")
            return
logger.info(f"Found {len(pdf_files)} PDF files")
doc_chunks = []
metadatas = []
ids = []
chunk_idx = 0
for pdf_file in pdf_files:
pdf_path = os.path.join(self.pdf_dir, pdf_file)
logger.info(f"Processing: {pdf_path}")
            # Load the PDF; skip files that cannot be parsed rather than
            # aborting the whole run
            try:
                loader = PyPDFLoader(pdf_path)
                documents = loader.load()
            except Exception as e:
                logger.error(f"Failed to load {pdf_file}: {e}")
                continue
            # Split documents into chunks
            chunks = self.text_splitter.split_documents(documents)
            logger.info(f"Split {pdf_file} into {len(chunks)} chunks")
# Prepare data for ChromaDB
for chunk in chunks:
doc_chunks.append(chunk.page_content)
metadatas.append({
"source": pdf_file,
"page": chunk.metadata.get("page", 0),
})
ids.append(f"chunk_{chunk_idx}")
chunk_idx += 1
# Add documents to ChromaDB
if doc_chunks:
logger.info(f"Adding {len(doc_chunks)} chunks to ChromaDB")
self.collection.add(
documents=doc_chunks,
metadatas=metadatas,
ids=ids
)
logger.info("Documents successfully processed and stored")
else:
logger.warning("No document chunks were generated")
def retrieve_relevant_chunks(self, query: str, k: int = 3) -> List[Dict[str, Any]]:
"""
Retrieve the k most relevant document chunks for a given query.
Args:
query: The query text
k: Number of relevant chunks to retrieve
Returns:
List of relevant document chunks with their metadata
"""
logger.info(f"Retrieving {k} relevant chunks for query: {query}")
results = self.collection.query(
query_texts=[query],
n_results=k
)
relevant_chunks = []
if results and results["documents"] and results["documents"][0]:
for i, doc in enumerate(results["documents"][0]):
relevant_chunks.append({
"content": doc,
"metadata": results["metadatas"][0][i] if results["metadatas"] and results["metadatas"][0] else {},
"id": results["ids"][0][i] if results["ids"] and results["ids"][0] else f"unknown_{i}"
})
return relevant_chunks
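    # For reference, each entry returned by retrieve_relevant_chunks has the
    # shape below (values are illustrative):
    #   {"content": "<chunk text>",
    #    "metadata": {"source": "file.pdf", "page": 3},
    #    "id": "chunk_42"}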
def generate_response(self, query: str, k: int = 3) -> str:
"""
Generate a response for a user query using RAG.
Args:
query: User query
k: Number of relevant chunks to retrieve
Returns:
Generated response from the LLM
"""
# Retrieve relevant document chunks
relevant_chunks = self.retrieve_relevant_chunks(query, k=k)
if not relevant_chunks:
logger.warning("No relevant chunks found for the query")
return "I couldn't find relevant information to answer your question."
# Format context from retrieved chunks
context = "\n\n".join([f"Document {i+1} (from {chunk['metadata'].get('source', 'unknown')}, page {chunk['metadata'].get('page', 'unknown')}):\n{chunk['content']}"
for i, chunk in enumerate(relevant_chunks)])
# Create prompt for the LLM
prompt = f"""
You are a helpful assistant that answers questions based on the provided context.
CONTEXT:
{context}
QUESTION:
{query}
Please provide a comprehensive and accurate answer based only on the information in the provided context.
If the context doesn't contain enough information to answer the question, please say so.
"""
# Generate response using Gemini
response = self.llm.ask(prompt, max_tokens=500, temperature=0.3)
return response
    def chat(self, user_input: Optional[str] = None) -> Optional[str]:
"""
Conduct a conversation with the user using the RAG system.
Args:
user_input: User's input. If None, starts a new conversation.
Returns:
System's response or None to exit
"""
if user_input is None:
# Initialize conversation
print("RAG System Initialized. Type 'exit' or 'quit' to end the conversation.")
user_input = input("You: ")
if user_input.lower() in ['exit', 'quit']:
print("Ending conversation. Goodbye!")
return None
# Generate response using RAG
response = self.generate_response(user_input)
# Update conversation history
self.conversation_history.append({"user": user_input, "system": response})
return response
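    # chat() can also be driven programmatically (e.g. from a web UI) by
    # passing the user's text directly; a hypothetical example:
    #   reply = rag.chat("What is covered in the introduction?")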
def interactive_session(self) -> None:
"""
Start an interactive chat session with the RAG system.
"""
print("Welcome to the RAG System!")
print("Type 'exit' or 'quit' to end the conversation.")
while True:
user_input = input("\nYou: ")
if user_input.lower() in ['exit', 'quit']:
print("Ending conversation. Goodbye!")
break
response = self.generate_response(user_input)
print(f"\nRAG System: {response}")
# Helper: convert a text response to speech with gTTS (requires network access)
def text_to_speech(response: str) -> str:
    """
    Convert the given text to an MP3 file using Google Text-to-Speech.
    Args:
        response: Text to synthesize
    Returns:
        Path to the saved audio file
    """
    tts = gTTS(response)
    audio_path = "response_audio.mp3"
    tts.save(audio_path)
    return audio_path
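# A sketch of pairing text_to_speech with the RAG system (the question is
# illustrative; gTTS needs network access to synthesize speech):
#   answer = rag.generate_response("Summarize the main findings.")
#   audio_file = text_to_speech(answer)  # writes response_audio.mp3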
def main():
"""
Main function to demonstrate the RAG system.
"""
    # Prefer the Gemini API key from the environment variable
    gemini_api_key = os.getenv("GEMINI_API_KEY")
    if not gemini_api_key:
        # Fall back to a key hardcoded here. Never commit a real key to
        # source control; replace the placeholder below for local testing only.
        hardcoded_api_key = "YOUR_GEMINI_API_KEY"  # placeholder, not a real key
        # Distinguish an unset variable from one set to an empty string
        # before printing the INFO message.
        if os.getenv("GEMINI_API_KEY") is None:
            print("INFO: GEMINI_API_KEY environment variable not found. Using hardcoded API key from rag.py.")
        gemini_api_key = hardcoded_api_key
    # Final check: bail out if no usable key is available
    if not gemini_api_key or gemini_api_key == "YOUR_GEMINI_API_KEY":
        print("Error: Gemini API key is not set.")
        print("Please set the GEMINI_API_KEY environment variable, or replace the placeholder in rag.py.")
        print("To set it as an environment variable:")
        print("  export GEMINI_API_KEY='your_api_key'   # Linux/macOS")
        print("  set GEMINI_API_KEY=your_api_key        # Windows CMD")
        print("  $env:GEMINI_API_KEY='your_api_key'     # Windows PowerShell")
        return
# Set paths
current_dir = os.path.dirname(os.path.abspath(__file__))
pdf_dir = os.path.join(current_dir, "material")
db_dir = os.path.join(current_dir, "chroma_db")
# Initialize the RAG system
rag = RAGSystem(
pdf_dir=pdf_dir,
gemini_api_key=gemini_api_key,
db_directory=db_dir
)
# Process documents
rag.process_documents()
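    # Optional one-shot query before the interactive loop (the question is
    # illustrative; uncomment to try it):
    # print(rag.generate_response("What topics do these documents cover?"))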
# Start interactive session
rag.interactive_session()
if __name__ == "__main__":
main()