import os
import time
import random
import logging
from concurrent.futures import ProcessPoolExecutor, as_completed

import chromadb
import google.generativeai as genai
from google.api_core.exceptions import ResourceExhausted
from langchain.text_splitter import RecursiveCharacterTextSplitter
from sentence_transformers import SentenceTransformer
from tqdm import tqdm

from Llm.llm_endpoints import get_llm_response
from utils.get_link import get_source_link
from Prompts.huberman_prompt import huberman_prompt

# Configuration
API_KEY = os.getenv("GOOGLE_API_KEY")
if API_KEY:
    genai.configure(api_key=API_KEY)

chromadb_path = "app/Rag/chromadb.db"
embedding_model = SentenceTransformer('all-MiniLM-L6-v2')

# Logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')


def get_gemini_embedding(text, max_retries=5):
    """Embed text with the Gemini embedding model, retrying on rate-limit errors."""
    model = "models/embedding-001"
    retry_count = 0
    while retry_count < max_retries:
        try:
            result = genai.embed_content(model=model, content=text)
            return result["embedding"]
        except ResourceExhausted:
            retry_count += 1
            # Exponential backoff with jitter, capped at 60 seconds
            wait_time = min(2 ** retry_count + random.random(), 60)
            logging.warning(
                f"Rate limit hit. Waiting {wait_time:.2f} seconds before retry {retry_count}/{max_retries}"
            )
            time.sleep(wait_time)
        except Exception as e:
            logging.error(f"Error while embedding: {e}")
            raise

    # If we've exhausted all retries
    raise Exception("Maximum retries reached when attempting to get embeddings")


# Helper Functions
def split_text_to_chunks(docs, chunk_size=1000, chunk_overlap=200):
    """Split text into manageable chunks."""
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    chunks = text_splitter.split_text(docs)
    return chunks


def get_new_files(transcripts_folder_path, collection):
    """Find new transcript files that haven't been processed yet."""
    all_files = [f for f in os.listdir(transcripts_folder_path) if f.endswith(".txt")]
    existing_files = {meta["source"] for meta in collection.get()['metadatas']}
    return [f for f in all_files if f not in existing_files]


def process_single_file(file_path):
    """Process a single file and return its chunks along with the file name."""
    with open(file_path, 'r') as f:
        content = f.read()
    chunks = split_text_to_chunks(content)
    return chunks, os.path.basename(file_path)


def batch_embed_chunks(chunks, batch_size=10):  # Small batch size to avoid hitting rate limits
    """Embed chunks in batches using Gemini with rate limiting."""
    embeddings = []
    for i in tqdm(range(0, len(chunks), batch_size), desc="Embedding chunks"):
        batch = chunks[i:i + batch_size]
        batch_embeddings = []
        # Process each chunk individually with Gemini
        for chunk in tqdm(batch, desc="Processing batch", leave=False):
            # Small delay between requests to avoid hitting rate limits
            time.sleep(0.5)
            embedding = get_gemini_embedding(chunk)
            batch_embeddings.append(embedding)
        embeddings.extend(batch_embeddings)
    return embeddings


def process_and_add_new_files(transcripts_folder_path, collection):
    """Process and add new transcript files to the vector database."""
    new_files = get_new_files(transcripts_folder_path, collection)
    if not new_files:
        logging.info("No new files to process")
        return False

    # Use a reasonable number of workers, capped at 8
    n_workers = min(8, len(new_files))
    logging.info(f"Using {n_workers} workers for processing")

    all_chunks = []
    all_metadata = []
    all_ids = []

    # Process files in parallel
    with ProcessPoolExecutor(max_workers=n_workers) as executor:
        futures = {
            executor.submit(process_single_file, os.path.join(transcripts_folder_path, file)): file
            for file in new_files
        }
        for future in as_completed(futures):
            file = futures[future]
            try:
                chunks, filename = future.result()
                file_metadata = [{"source": filename} for _ in range(len(chunks))]
                file_ids = [f"{filename}_chunk_{i}" for i in range(len(chunks))]
                all_chunks.extend(chunks)
                all_metadata.extend(file_metadata)
                all_ids.extend(file_ids)
                logging.info(f"Processed {filename}")
            except Exception as e:
                logging.error(f"Error processing {file}: {str(e)}")
                continue

    # Process embeddings in batches
    logging.info(f"Generating embeddings for {len(all_chunks)} chunks")
    embeddings = batch_embed_chunks(all_chunks)

    # Add to database in batches
    batch_size = 500
    for i in range(0, len(all_chunks), batch_size):
        end_idx = min(i + batch_size, len(all_chunks))
        collection.upsert(
            documents=all_chunks[i:end_idx],
            embeddings=embeddings[i:end_idx],
            metadatas=all_metadata[i:end_idx],
            ids=all_ids[i:end_idx]
        )
        logging.info(f"Added batch {i // batch_size + 1} to database")

    logging.info(f"Successfully processed {len(new_files)} files")
    return True


def query_database(collection, query_text, n_results=3):
    """Retrieve the most relevant chunks for the query."""
    # Embed the query with the same Gemini model used at indexing time so the
    # query vector's dimensionality matches the stored embeddings.
    query_embeddings = get_gemini_embedding(query_text)
    results = collection.query(query_embeddings=query_embeddings, n_results=n_results)
    retrieved_docs = results['documents'][0]
    metadatas = results['metadatas'][0]
    return retrieved_docs, metadatas


def enhance_query_with_history(query_text, conversation_history):
    """Combine the current query with prior turns so retrieval sees the full context."""
    history_str = "\n".join(f"User: {turn['user']}\nBot: {turn['bot']}" for turn in conversation_history)
    return f"{query_text}\n\n{history_str}"


def update_conversation_history(history, user_query, bot_response):
    """Append the latest user/bot exchange to the conversation history."""
    history.append({"user": user_query, "bot": bot_response})
    return history


def generate_response(conversation_history, query_text, retrieved_docs, source_links):
    """Generate a response using retrieved documents and the generative AI model."""
    context = " ".join(retrieved_docs)
    history_str = "\n".join([f"User: {turn['user']}\nBot: {turn['bot']}" for turn in conversation_history])
    sources_str = "\n".join(source_links)

    prompt = huberman_prompt.format(
        context=context,
        sources=sources_str,
        history=history_str,
        question=query_text
    )

    response = get_llm_response(prompt)
    full_response = f"{response}\n\nSources:\n{sources_str}"
    return full_response


def main_workflow(transcripts_folder_path, collection):
    """Run the full RAG workflow."""
    new_files_added = process_and_add_new_files(transcripts_folder_path, collection)
    if new_files_added:
        logging.info("New transcripts added to the database.")
    else:
        logging.info("No new files found. Using existing database.")

    conversation_history = []
    while True:
        query_text = input("\nEnter your query (or type 'exit' to end): ").strip()
        if query_text.lower() == "exit":
            print("Ending the conversation. Goodbye!")
            break

        query_text_with_conversation_history = enhance_query_with_history(query_text, conversation_history)
        retrieved_docs, metadatas = query_database(collection, query_text_with_conversation_history)

        print("-" * 50)
        source_link = get_source_link(metadatas)
        print(source_link)
        print("-" * 50)

        if not retrieved_docs:
            print("No relevant documents found.")
            continue

        response = generate_response(conversation_history, query_text, retrieved_docs, source_link)
        conversation_history = update_conversation_history(conversation_history, query_text, response)

        print("\nGenerated Response:")
        print(response)
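

# Minimal entry-point sketch: the module configures chromadb_path but never shows
# how the Chroma client and collection are created, so the collection name
# "huberman_transcripts" and the transcripts folder "app/transcripts" below are
# illustrative assumptions, not names taken from the original project.
if __name__ == "__main__":
    # Persistent Chroma client backed by the path configured above
    client = chromadb.PersistentClient(path=chromadb_path)
    # get_or_create_collection is idempotent, so reruns reuse the same collection
    collection = client.get_or_create_collection(name="huberman_transcripts")  # hypothetical name
    transcripts_folder = "app/transcripts"  # hypothetical folder of .txt transcripts
    main_workflow(transcripts_folder, collection)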