#!/usr/bin/env python3
"""

RAG (Retrieval Augmented Generation) System

-------------------------------------------

This module implements a RAG system that processes PDF documents,

uses ChromaDB as a vector database, sentence-transformers for embeddings,

and Google's Gemini as the main LLM. The system follows a conversational pattern.

"""

import os
import logging
from typing import List, Dict, Any, Optional

# Document processing
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter

# Embeddings
from sentence_transformers import SentenceTransformer

# Vector database
import chromadb
from chromadb.utils import embedding_functions

# For Gemini LLM integration
from gemini_wrapper import GoogleGeminiWrapper

from gtts import gTTS
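
# Third-party dependencies assumed by the imports above (package names can
# vary by version): langchain, langchain-community, sentence-transformers,
# chromadb, and gTTS, e.g.
#   pip install langchain langchain-community sentence-transformers chromadb gTTS
# `gemini_wrapper` is a project-local module that provides GoogleGeminiWrapper.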

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

class RAGSystem:
    """

    A Retrieval Augmented Generation system that processes PDF documents,

    stores their embeddings in a vector database, and generates responses

    using the Google Gemini model.

    """
    
    def __init__(
        self,
        pdf_dir: str,
        gemini_api_key: str,
        embedding_model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
        chunk_size: int = 1000,
        chunk_overlap: int = 200,
        db_directory: str = "./chroma_db",
    ):
        """
        Initialize the RAG system.

        Args:
            pdf_dir: Directory containing PDF documents
            gemini_api_key: API key for Google Gemini
            embedding_model_name: Name of the sentence-transformers model
            chunk_size: Size of text chunks for splitting documents
            chunk_overlap: Overlap between consecutive chunks
            db_directory: Directory to store the ChromaDB database
        """
        self.pdf_dir = pdf_dir
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.db_directory = db_directory
        
        # Initialize the embedding model
        logger.info(f"Loading embedding model: {embedding_model_name}")
        self.embedding_model = SentenceTransformer(embedding_model_name)
        
        # Initialize the text splitter
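        # (RecursiveCharacterTextSplitter falls back through progressively
        # smaller separators, so chunks break on natural text boundaries such
        # as paragraphs and sentences where possible.)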
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.chunk_size,
            chunk_overlap=self.chunk_overlap,
        )
        
        # Initialize ChromaDB
        logger.info(f"Initializing ChromaDB at {db_directory}")
        self.client = chromadb.PersistentClient(path=db_directory)
        
        # Create a custom embedding function that uses sentence-transformers
        self.sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(
            model_name=embedding_model_name
        )
        
        # Create or get the collection
        self.collection = self.client.get_or_create_collection(
            name="pdf_documents",
            embedding_function=self.sentence_transformer_ef
        )
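        # Note: get_or_create_collection reuses an existing "pdf_documents"
        # collection on disk, so repeated runs do not re-embed documents.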
        
        # Initialize the Gemini LLM
        logger.info("Initializing Google Gemini")
        self.llm = GoogleGeminiWrapper(api_key=gemini_api_key)
        
        # Load conversation history
        self.conversation_history = []
        
    def process_documents(self) -> None:
        """

        Process all PDF documents in the specified directory,

        split them into chunks, generate embeddings, and store in ChromaDB.

        """
        logger.info(f"Processing documents from: {self.pdf_dir}")
        
        # Check if documents are already processed
        if self.collection.count() > 0:
            logger.info(f"Found {self.collection.count()} existing document chunks in the database")
            return
        
        # Process each PDF file in the directory
        pdf_files = [f for f in os.listdir(self.pdf_dir) if f.lower().endswith('.pdf')]
        if not pdf_files:
            logger.warning(f"No PDF files found in {self.pdf_dir}")
            return
            
        logger.info(f"Found {len(pdf_files)} PDF files")
        
        doc_chunks = []
        metadatas = []
        ids = []
        chunk_idx = 0
        
        for pdf_file in pdf_files:
            pdf_path = os.path.join(self.pdf_dir, pdf_file)
            logger.info(f"Processing: {pdf_path}")
            
            # Load PDF
            loader = PyPDFLoader(pdf_path)
            documents = loader.load()
            
            # Split documents into chunks
            chunks = self.text_splitter.split_documents(documents)
            logger.info(f"Split {pdf_file} into {len(chunks)} chunks")
            
            # Prepare data for ChromaDB
            for chunk in chunks:
                doc_chunks.append(chunk.page_content)
                metadatas.append({
                    "source": pdf_file,
                    "page": chunk.metadata.get("page", 0),
                })
                ids.append(f"chunk_{chunk_idx}")
                chunk_idx += 1
        
        # Add documents to ChromaDB
        if doc_chunks:
            logger.info(f"Adding {len(doc_chunks)} chunks to ChromaDB")
            self.collection.add(
                documents=doc_chunks,
                metadatas=metadatas,
                ids=ids
            )
            logger.info("Documents successfully processed and stored")
        else:
            logger.warning("No document chunks were generated")
            
    def retrieve_relevant_chunks(self, query: str, k: int = 3) -> List[Dict[str, Any]]:
        """

        Retrieve the k most relevant document chunks for a given query.

        

        Args:

            query: The query text

            k: Number of relevant chunks to retrieve

            

        Returns:

            List of relevant document chunks with their metadata

        """
        logger.info(f"Retrieving {k} relevant chunks for query: {query}")
        results = self.collection.query(
            query_texts=[query],
            n_results=k
        )
        
        relevant_chunks = []
        if results and results["documents"] and results["documents"][0]:
            for i, doc in enumerate(results["documents"][0]):
                relevant_chunks.append({
                    "content": doc,
                    "metadata": results["metadatas"][0][i] if results["metadatas"] and results["metadatas"][0] else {},
                    "id": results["ids"][0][i] if results["ids"] and results["ids"][0] else f"unknown_{i}"
                })
        
        return relevant_chunks
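
    # Illustrative shape of retrieve_relevant_chunks' return value (the values
    # here are hypothetical):
    #   [{"content": "...chunk text...",
    #     "metadata": {"source": "intro.pdf", "page": 2},
    #     "id": "chunk_17"}]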
    
    def generate_response(self, query: str, k: int = 3) -> str:
        """

        Generate a response for a user query using RAG.

        

        Args:

            query: User query

            k: Number of relevant chunks to retrieve

            

        Returns:

            Generated response from the LLM

        """
        # Retrieve relevant document chunks
        relevant_chunks = self.retrieve_relevant_chunks(query, k=k)
        
        if not relevant_chunks:
            logger.warning("No relevant chunks found for the query")
            return "I couldn't find relevant information to answer your question."
        
        # Format context from retrieved chunks
        context = "\n\n".join([f"Document {i+1} (from {chunk['metadata'].get('source', 'unknown')}, page {chunk['metadata'].get('page', 'unknown')}):\n{chunk['content']}" 
                              for i, chunk in enumerate(relevant_chunks)])
        
        # Create prompt for the LLM
        prompt = f"""

        You are a helpful assistant that answers questions based on the provided context.

        

        CONTEXT:

        {context}

        

        QUESTION:

        {query}

        

        Please provide a comprehensive and accurate answer based only on the information in the provided context.

        If the context doesn't contain enough information to answer the question, please say so.

        """
        
        # Generate response using Gemini
        response = self.llm.ask(prompt, max_tokens=500, temperature=0.3)
        return response
    
    def chat(self, user_input: Optional[str] = None) -> Optional[str]:
        """
        Conduct one turn of conversation with the user using the RAG system.

        Args:
            user_input: The user's input. If None, prompts for input to start
                a new conversation.

        Returns:
            The system's response, or None if the user chose to exit.
        """
        if user_input is None:
            # Initialize conversation
            print("RAG System Initialized. Type 'exit' or 'quit' to end the conversation.")
            user_input = input("You: ")
        
        if user_input.lower() in ['exit', 'quit']:
            print("Ending conversation. Goodbye!")
            return None
        
        # Generate response using RAG
        response = self.generate_response(user_input)
        
        # Update conversation history
        self.conversation_history.append({"user": user_input, "system": response})
        
        return response
    
    def interactive_session(self) -> None:
        """

        Start an interactive chat session with the RAG system.

        """
        print("Welcome to the RAG System!")
        print("Type 'exit' or 'quit' to end the conversation.")
        
        while True:
            user_input = input("\nYou: ")
            
            if user_input.lower() in ['exit', 'quit']:
                print("Ending conversation. Goodbye!")
                break
            
            response = self.generate_response(user_input)
            print(f"\nRAG System: {response}")

def text_to_speech(response: str) -> str:
    """Convert a text response to speech and save it as an MP3 file.

    Returns:
        Path to the saved audio file.
    """
    tts = gTTS(response)
    audio_path = "response_audio.mp3"
    tts.save(audio_path)
    return audio_path
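
# Example (hypothetical): speak a generated answer by saving it to MP3 and
# playing the resulting file with any audio player.
#   audio_file = text_to_speech(rag.generate_response("Summarize chapter 1"))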

def main():
    """

    Main function to demonstrate the RAG system.

    """
    # Attempt to get the Gemini API key from the environment
    gemini_api_key = os.getenv("GEMINI_API_KEY")

    if gemini_api_key is None:
        # Fall back to a key hardcoded below. Note: committing a real API key
        # to source control is a security risk; prefer the environment variable.
        hardcoded_api_key = ""  # Put a key here only for local testing
        if hardcoded_api_key:
            print("INFO: GEMINI_API_KEY environment variable not found. Using hardcoded API key from rag.py.")
        gemini_api_key = hardcoded_api_key

    # Final check: the key may still be missing (environment variable empty or
    # unset, and no hardcoded key provided)
    if not gemini_api_key:
        print("Error: Gemini API key is not set.")
        print("Please set the GEMINI_API_KEY environment variable, or ensure it's correctly hardcoded in rag.py.")
        print("To set as environment variable:")
        print("  export GEMINI_API_KEY='your_api_key'  # For Linux/macOS")
        print("  set GEMINI_API_KEY=your_api_key      # For Windows CMD")
        print("  $env:GEMINI_API_KEY='your_api_key'   # For Windows PowerShell")
        return
    
    # Set paths
    current_dir = os.path.dirname(os.path.abspath(__file__))
    pdf_dir = os.path.join(current_dir, "material")
    db_dir = os.path.join(current_dir, "chroma_db")
    
    # Initialize the RAG system
    rag = RAGSystem(
        pdf_dir=pdf_dir,
        gemini_api_key=gemini_api_key,
        db_directory=db_dir
    )
    
    # Process documents
    rag.process_documents()
    
    # Start interactive session
    rag.interactive_session()


if __name__ == "__main__":
    main()