import os
import sys
import time
import random
import logging

from langchain_community.vectorstores import Qdrant
from langchain.memory import ConversationBufferMemory
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Add project root to path for imports
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))

from app.config import VECTOR_DB_PATH, COLLECTION_NAME
from app.core.llm import get_llm, get_embeddings, get_chat_model


class CustomRAGChain:
    """Custom RAG chain that always returns a standardized output format."""

    def __init__(self, base_chain):
        self.base_chain = base_chain
        logger.info("CustomRAGChain initialized")

    def __call__(self, inputs):
        """Process inputs and return standardized output."""
        try:
            logger.info("CustomRAGChain processing query")

            # Execute the underlying chain
            result = self.base_chain(inputs)
            logger.info(f"Base chain returned keys: {list(result.keys())}")

            # Create standardized output
            standardized = {
                "answer": result.get("answer", "I couldn't generate an answer."),
                "sources": result.get("source_documents", []),
            }
            return standardized
        except Exception as e:
            logger.error(f"Error in CustomRAGChain: {e}")
            return {
                "answer": f"Error processing query: {e}",
                "sources": [],
            }


class MemoryManager:
    """Manages the RAG memory system using a vector database."""

    def __init__(self):
        self.embeddings = get_embeddings()
        self.llm = get_llm()
        self.chat_model = get_chat_model()
        self.client = self._init_qdrant_client()
        self.vectorstore = self._init_vector_store()
        self.memory = ConversationBufferMemory(
            memory_key="chat_history",
            return_messages=True,
        )

    def _init_qdrant_client(self):
        """Initialize the Qdrant client with retry logic for concurrent access issues."""
        # Create the storage directory if it doesn't exist
        os.makedirs(VECTOR_DB_PATH, exist_ok=True)

        # Add a small random delay to reduce the chance of concurrent access
        time.sleep(random.uniform(0.1, 0.5))

        # Generate a unique path for this instance to avoid collisions
        instance_id = str(random.randint(10000, 99999))
        unique_path = os.path.join(VECTOR_DB_PATH, f"instance_{instance_id}")

        max_retries = 3
        retry_count = 0

        while retry_count < max_retries:
            try:
                logger.info(f"Attempting to initialize Qdrant client (attempt {retry_count + 1}/{max_retries})")
                # Try the unique per-instance path first
                try:
                    os.makedirs(unique_path, exist_ok=True)
                    return QdrantClient(path=unique_path)
                except Exception as e:
                    logger.warning(f"Could not use unique path {unique_path}: {e}")
                    # Fall back to the main path
                    return QdrantClient(path=VECTOR_DB_PATH)
            except RuntimeError as e:
                if "already accessed by another instance" in str(e):
                    retry_count += 1
                    wait_time = random.uniform(0.5, 2.0) * retry_count
                    logger.warning(f"Qdrant concurrent access detected. Retrying in {wait_time:.2f} seconds...")
                    time.sleep(wait_time)
                else:
                    # Different error, don't retry
                    raise

        # If all retries failed, fall back to in-memory storage as a last resort
        logger.warning("All Qdrant client initialization attempts failed. Using in-memory mode.")
        return QdrantClient(":memory:")

    def _init_vector_store(self):
        """Initialize the vector store."""
        try:
            collections = self.client.get_collections().collections
            collection_names = [collection.name for collection in collections]

            # Get the vector dimension from the embedding model
            vector_size = len(self.embeddings.embed_query("test"))

            if COLLECTION_NAME not in collection_names:
                # Create the collection with appropriate settings
                self.client.create_collection(
                    collection_name=COLLECTION_NAME,
                    vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE),
                )
                logger.info(f"Created new collection: {COLLECTION_NAME}")

            return Qdrant(
                client=self.client,
                collection_name=COLLECTION_NAME,
                embeddings=self.embeddings,
            )
        except Exception as e:
            logger.error(f"Error initializing vector store: {e}")
            # Fall back to a simple in-memory store
            logger.warning("Using in-memory vector store as fallback.")
            return Qdrant.from_texts(
                ["Hello, I am your AI assistant."],
                self.embeddings,
                location=":memory:",
                collection_name=COLLECTION_NAME,
            )

    def get_retriever(self):
        """Get the retriever for RAG."""
        return self.vectorstore.as_retriever(
            search_type="similarity",
            search_kwargs={"k": 5},
        )

    def create_rag_chain(self):
        """Create a RAG chain for question answering."""
        try:
            # Build a plain function instead of a ConversationalRetrievalChain so the
            # output format stays fully under our control.
            logger.info("Creating custom RAG chain")

            def simple_chain(query_dict):
                try:
                    # Extract the question
                    question = query_dict.get("question", "")
                    if not question.strip():
                        return {
                            "answer": "No question provided.",
                            "sources": [],
                        }

                    # Get relevant documents from the retriever
                    retriever = self.get_retriever()
                    relevant_docs = retriever.get_relevant_documents(question)

                    # Format the context from the relevant documents
                    context_parts = []
                    for i, doc in enumerate(relevant_docs):
                        source_name = doc.metadata.get("file_name", "Unknown Source")
                        context_parts.append(f"Document {i + 1} [{source_name}]:\n{doc.page_content}\n")
                    context = "\n".join(context_parts) if context_parts else "No relevant documents found."

                    # Get the chat history from memory
                    chat_history = self.memory.chat_memory.messages
                    chat_history_str = "\n".join(f"{msg.type}: {msg.content}" for msg in chat_history)

                    # Build the prompt with explicit answering guidelines
                    prompt = f"""You are a helpful, accurate, and precise AI assistant. Answer the following question based on the provided context.

Follow these guidelines when responding:
1. If the context contains relevant information, use it to provide a direct and specific answer.
2. Format your answer in clear, concise paragraphs with appropriate spacing.
3. If the answer is not in the context, acknowledge this and provide a general response based on your knowledge.
4. Do not mention "context" or "documents" in your answer - integrate the information naturally.
5. Keep answers factual, helpful, and to the point.
6. Never make up information that isn't supported by the context.

Context:
{context}

Chat History:
{chat_history_str}

Question: {question}

Answer:"""

                    # Get the answer from the LLM, retrying once if the first response
                    # looks too short or too generic. Initialize `answer` first so the
                    # except branch below can't hit a NameError.
                    answer = ""
                    try:
                        answer = self.llm(prompt)

                        # Simple quality check - if too short or generic, try again
                        if len(answer.strip()) < 20 or "I don't have enough information" in answer:
                            logger.info("Answer quality check failed, retrying with modified prompt")
                            # Add a more specific instruction to the prompt
                            enhanced_prompt = prompt + "\n\nPlease be as helpful as possible with the information available."
                            second_attempt = self.llm(enhanced_prompt)

                            # Use the better (longer) of the two responses
                            if len(second_attempt.strip()) > len(answer.strip()):
                                answer = second_attempt
                    except Exception as llm_error:
                        logger.error(f"Error getting answer from LLM: {llm_error}")
                        if not answer:  # the first attempt raised before `answer` was set
                            answer = "I'm having trouble generating a response right now. Please try again in a moment."

                    # Perform basic formatting cleanup
                    answer = answer.strip()

                    # Remove common prefixes that models sometimes add
                    prefixes_to_remove = ["Answer:", "AI:", "Assistant:"]
                    for prefix in prefixes_to_remove:
                        if answer.startswith(prefix):
                            answer = answer[len(prefix):].strip()

                    return {
                        "answer": answer,
                        "sources": relevant_docs,
                    }
                except Exception as e:
                    logger.error(f"Error in simple_chain: {e}")
                    return {
                        "answer": "I encountered an error while processing your question. Please try again with a different query.",
                        "sources": [],
                    }

            return simple_chain
        except Exception as e:
            logger.error(f"Error creating RAG chain: {e}")
            # Fall back to a mock chain that mimics the real chain's interface
            logger.warning("Using fallback mock chain")

            def mock_chain(inputs):
                logger.info(f"Mock chain received query: {inputs.get('question', '')}")
                return {
                    "answer": "I'm having trouble accessing the knowledge base. I can only answer general questions right now.",
                    "sources": [],
                }

            return mock_chain

    def add_texts(self, texts, metadatas=None):
        """Add texts to the vector store."""
        try:
            return self.vectorstore.add_texts(texts=texts, metadatas=metadatas)
        except Exception as e:
            logger.error(f"Error adding texts to vector store: {e}")
            return ["error-id-" + str(random.randint(10000, 99999))]

    def similarity_search(self, query, k=5):
        """Perform a similarity search."""
        try:
            return self.vectorstore.similarity_search(query, k=k)
        except Exception as e:
            logger.error(f"Error during similarity search: {e}")
            return []
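

# --- Usage sketch (illustrative, not part of the module's API) ---
# A minimal example of driving this module end to end. It assumes the
# project-specific pieces are available and working: app.config's
# VECTOR_DB_PATH and COLLECTION_NAME, and the models returned by
# app.core.llm. The seeded text and the sample question below are
# hypothetical placeholders.
if __name__ == "__main__":
    manager = MemoryManager()

    # Seed the store with one sample document so the retriever has something to find.
    manager.add_texts(
        ["Qdrant is the vector database used here to store document embeddings."],
        metadatas=[{"file_name": "example.txt"}],
    )

    chain = manager.create_rag_chain()
    result = chain({"question": "Which vector database does this project use?"})

    print(result["answer"])
    for doc in result["sources"]:
        print("Source:", doc.metadata.get("file_name", "Unknown Source"))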