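"""RAG pipeline for the Huberman transcript chatbot.

Chunks transcript files, embeds the chunks with the Gemini embedding API,
stores them in a ChromaDB collection, and answers user queries by retrieving
relevant chunks and prompting an LLM with them.
"""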
import chromadb
from langchain.text_splitter import RecursiveCharacterTextSplitter
from sentence_transformers import SentenceTransformer
import google.generativeai as genai
import os
import logging
from concurrent.futures import ProcessPoolExecutor, as_completed
from Llm.llm_endpoints import get_llm_response
from utils.get_link import get_source_link
from Prompts.huberman_prompt import huberman_prompt
from tqdm import tqdm
import time
import random
from google.api_core.exceptions import ResourceExhausted
# Configuration
API_KEY = os.getenv("GOOGLE_API_KEY")
if API_KEY:
    genai.configure(api_key=API_KEY)
chromadb_path = "app/Rag/chromadb.db"
embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
# Logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')
def get_gemini_embedding(text, max_retries=5):
    """Embed a single piece of text with the Gemini embedding API, retrying on rate limits."""
    model = "models/embedding-001"
    retry_count = 0
    while retry_count < max_retries:
        try:
            result = genai.embed_content(model=model, content=text)
            return result["embedding"]
        except ResourceExhausted:
            retry_count += 1
            wait_time = min(2 ** retry_count + random.random(), 60)  # Exponential backoff with jitter
            print(f"Rate limit hit. Waiting for {wait_time:.2f} seconds before retry {retry_count}/{max_retries}")
            time.sleep(wait_time)
        except Exception as e:
            print(f"Error: {e}")
            raise
    # If we've exhausted all retries
    raise Exception("Maximum retries reached when attempting to get embeddings")
# Helper Functions
def split_text_to_chunks(docs, chunk_size=1000, chunk_overlap=200):
    """Split text into manageable chunks."""
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    chunks = text_splitter.split_text(docs)
    return chunks
def get_new_files(transcripts_folder_path, collection):
    """Find new transcript files that haven't been processed yet."""
    all_files = [f for f in os.listdir(transcripts_folder_path) if f.endswith(".txt")]
    # Use a set for fast membership checks against already-ingested sources
    existing_files = {meta["source"] for meta in collection.get()['metadatas']}
    return [f for f in all_files if f not in existing_files]
def process_single_file(file_path):
    """Process a single file and return its chunks."""
    # Read transcripts as UTF-8 to avoid platform-dependent default encodings
    with open(file_path, 'r', encoding='utf-8') as f:
        content = f.read()
    chunks = split_text_to_chunks(content)
    return chunks, os.path.basename(file_path)
def batch_embed_chunks(chunks, batch_size=10):  # Reduced batch size to avoid hitting limits
    """Embed chunks in batches using Gemini with rate limiting."""
    embeddings = []
    for i in tqdm(range(0, len(chunks), batch_size), desc="Embedding chunks"):
        batch = chunks[i:i + batch_size]
        batch_embeddings = []
        # Process each chunk individually with Gemini
        for chunk in tqdm(batch, desc="Processing batch", leave=False):
            # Add a small delay between requests to avoid hitting rate limits
            time.sleep(0.5)  # Wait half a second between requests
            embedding = get_gemini_embedding(chunk)
            batch_embeddings.append(embedding)
        embeddings.extend(batch_embeddings)
    return embeddings
def process_and_add_new_files(transcripts_folder_path, collection):
    """Process and add new transcript files to the vector database."""
    new_files = get_new_files(transcripts_folder_path, collection)
    if not new_files:
        logging.info("No new files to process")
        return False
    # Use a reasonable number of workers, capped at 8
    n_workers = min(8, len(new_files))
    logging.info(f"Using {n_workers} workers for processing")
    all_chunks = []
    all_metadata = []
    all_ids = []
    # Process files in parallel
    with ProcessPoolExecutor(max_workers=n_workers) as executor:
        futures = {
            executor.submit(process_single_file, os.path.join(transcripts_folder_path, file)): file
            for file in new_files
        }
        for future in as_completed(futures):
            file = futures[future]
            try:
                chunks, filename = future.result()
                file_metadata = [{"source": filename} for _ in range(len(chunks))]
                file_ids = [f"{filename}_chunk_{i}" for i in range(len(chunks))]
                all_chunks.extend(chunks)
                all_metadata.extend(file_metadata)
                all_ids.extend(file_ids)
                logging.info(f"Processed {filename}")
            except Exception as e:
                logging.error(f"Error processing {file}: {str(e)}")
                continue
    if not all_chunks:
        logging.info("No chunks were produced from the new files")
        return False
    # Process embeddings in batches
    logging.info(f"Generating embeddings for {len(all_chunks)} chunks")
    embeddings = batch_embed_chunks(all_chunks)
    # Add to database in batches
    batch_size = 500
    for i in range(0, len(all_chunks), batch_size):
        end_idx = min(i + batch_size, len(all_chunks))
        collection.upsert(
            documents=all_chunks[i:end_idx],
            embeddings=embeddings[i:end_idx],
            metadatas=all_metadata[i:end_idx],
            ids=all_ids[i:end_idx]
        )
        logging.info(f"Added batch {i // batch_size + 1} to database")
    logging.info(f"Successfully processed {len(new_files)} files")
    return True
def query_database(collection, query_text, n_results=3):
    """Retrieve the most relevant chunks for the query."""
    # Embed the query with the same Gemini model used at ingestion so the query
    # vector matches the dimensionality of the stored document vectors.
    query_embeddings = get_gemini_embedding(query_text)
    results = collection.query(query_embeddings=query_embeddings, n_results=n_results)
    retrieved_docs = results['documents'][0]
    metadatas = results['metadatas'][0]
    return retrieved_docs, metadatas
def enhance_query_with_history(query_text, summarized_history):
    """Combine the current query with the summarized conversation history for retrieval."""
    enhanced_query = f"{query_text}\n\n{summarized_history}"
    return enhanced_query
def update_conversation_history(history, user_query, bot_response):
    """Update and keep track of the conversation history between the user and the bot."""
    history.append({"user": user_query, "bot": bot_response})
    return history
def generate_response(conversation_history, query_text, retrieved_docs, source_links):
    """Generate a response using retrieved documents and the generative AI model."""
    context = " ".join(retrieved_docs)
    history_str = "\n".join([f"User: {turn['user']}\nBot: {turn['bot']}" for turn in conversation_history])
    sources_str = "\n".join(source_links)
    prompt = huberman_prompt.format(
        context=context,
        sources=sources_str,
        history=history_str,
        question=query_text
    )
    response = get_llm_response(prompt)
    full_response = f"{response}\n\nSources:\n{sources_str}"
    return full_response
def main_workflow(transcripts_folder_path, collection):
    """Run the full RAG workflow."""
    new_files_added = process_and_add_new_files(transcripts_folder_path, collection)
    if new_files_added:
        logging.info("New transcripts added to the database.")
    else:
        logging.info("No new files found. Using existing database.")
    conversation_history = []
    while True:
        query_text = input("\nEnter your query (or type 'exit' to end): ").strip()
        if query_text.lower() == "exit":
            print("Ending the conversation. Goodbye.")
            break
        query_text_with_conversation_history = enhance_query_with_history(query_text, conversation_history)
        retrieved_docs, metadatas = query_database(collection, query_text_with_conversation_history)
        print("-" * 50)
        source_link = get_source_link(metadatas)
        print(source_link)
        print("-" * 50)
        if not retrieved_docs:
            print("No relevant documents found.")
            continue
        response = generate_response(conversation_history, query_text, retrieved_docs, source_link)
        conversation_history = update_conversation_history(conversation_history, query_text, response)
        print("\nGenerated Response:")
        print(response)
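# Minimal usage sketch (assumptions, not part of the original pipeline): the
# collection name and transcripts folder below are illustrative placeholders;
# only `chromadb_path` is defined elsewhere in this module.
if __name__ == "__main__":
    client = chromadb.PersistentClient(path=chromadb_path)
    collection = client.get_or_create_collection(name="huberman_transcripts")  # assumed collection name
    transcripts_folder = "app/Rag/transcripts"  # assumed location of .txt transcripts
    main_workflow(transcripts_folder, collection)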