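"""Memory management for the RAG system.

Wraps a Qdrant vector store, a conversation buffer memory, and a
function-based retrieval chain. Includes retry logic and in-memory
fallbacks for the concurrent-access issues seen when this module is
deployed in a Hugging Face Space."""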
import os
import sys
import time
import random
import logging
from langchain_community.vectorstores import Qdrant
from langchain.memory import ConversationBufferMemory
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Add project root to path for imports
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
from app.config import VECTOR_DB_PATH, COLLECTION_NAME
from app.core.llm import get_llm, get_embeddings, get_chat_model
class CustomRAGChain:
"""Custom RAG chain that always returns standardized output format."""
def __init__(self, base_chain):
self.base_chain = base_chain
logger.info("CustomRAGChain initialized")
def __call__(self, inputs):
"""Process inputs and return standardized output."""
try:
logger.info("CustomRAGChain processing query")
# Execute the underlying chain
result = self.base_chain(inputs)
logger.info(f"Base chain returned keys: {list(result.keys())}")
            # Create standardized output. Handle both key conventions:
            # LangChain retrieval chains return "source_documents", while the
            # function-based chain below already returns "sources".
            standardized = {
                "answer": result.get("answer", "I couldn't generate an answer."),
                "sources": result.get("sources", result.get("source_documents", []))
            }
return standardized
except Exception as e:
logger.error(f"Error in CustomRAGChain: {e}")
return {
"answer": f"Error processing query: {str(e)}",
"sources": []
}
class MemoryManager:
"""Manages the RAG memory system using a vector database."""
def __init__(self):
self.embeddings = get_embeddings()
self.llm = get_llm()
self.chat_model = get_chat_model()
self.client = self._init_qdrant_client()
self.vectorstore = self._init_vector_store()
self.memory = ConversationBufferMemory(
memory_key="chat_history",
return_messages=True
)
def _init_qdrant_client(self):
"""Initialize the Qdrant client with retry logic for concurrent access issues."""
# Create directory if it doesn't exist
os.makedirs(VECTOR_DB_PATH, exist_ok=True)
# Add a small random delay to reduce chance of concurrent access
time.sleep(random.uniform(0.1, 0.5))
# Generate a unique path for this instance to avoid collision
instance_id = str(random.randint(10000, 99999))
unique_path = os.path.join(VECTOR_DB_PATH, f"instance_{instance_id}")
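        # Note: a per-instance path sidesteps Qdrant's file lock, but it also
        # means each process starts from its own (initially empty) store;
        # availability is traded for shared data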
max_retries = 3
retry_count = 0
while retry_count < max_retries:
try:
logger.info(f"Attempting to initialize Qdrant client (attempt {retry_count+1}/{max_retries})")
# Try to use the unique path first
try:
os.makedirs(unique_path, exist_ok=True)
return QdrantClient(path=unique_path)
except Exception as e:
logger.warning(f"Could not use unique path {unique_path}: {e}")
# Try the main path as fallback
return QdrantClient(path=VECTOR_DB_PATH)
except RuntimeError as e:
if "already accessed by another instance" in str(e):
retry_count += 1
wait_time = random.uniform(0.5, 2.0) * retry_count
logger.warning(f"Qdrant concurrent access detected. Retrying in {wait_time:.2f} seconds...")
time.sleep(wait_time)
else:
# Different error, don't retry
raise
# If all retries failed, try to use in-memory storage as last resort
logger.warning("All Qdrant client initialization attempts failed. Using in-memory mode.")
return QdrantClient(":memory:")
def _init_vector_store(self):
"""Initialize the vector store."""
try:
collections = self.client.get_collections().collections
collection_names = [collection.name for collection in collections]
# Get vector dimension from the embedding model
vector_size = len(self.embeddings.embed_query("test"))
if COLLECTION_NAME not in collection_names:
# Create the collection with appropriate settings
self.client.create_collection(
collection_name=COLLECTION_NAME,
vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE),
)
logger.info(f"Created new collection: {COLLECTION_NAME}")
return Qdrant(
client=self.client,
collection_name=COLLECTION_NAME,
embeddings=self.embeddings
)
except Exception as e:
logger.error(f"Error initializing vector store: {e}")
# Create a simple in-memory fallback
logger.warning("Using in-memory vector store as fallback.")
return Qdrant.from_texts(
["Hello, I am your AI assistant."],
self.embeddings,
location=":memory:",
collection_name=COLLECTION_NAME
)
def get_retriever(self):
"""Get the retriever for RAG."""
return self.vectorstore.as_retriever(
search_type="similarity",
search_kwargs={"k": 5}
)
def create_rag_chain(self):
"""Create a RAG chain for question answering."""
        try:
            # Build a lightweight, function-based chain rather than a
            # ConversationalRetrievalChain: it gives full control over
            # retrieval, prompting, and error handling
            logger.info("Creating function-based RAG chain")
            def simple_chain(query_dict):
try:
# Extract the question
question = query_dict.get("question", "")
if not question.strip():
return {
"answer": "No question provided.",
"sources": []
}
# Get relevant documents from the retriever
retriever = self.get_retriever()
relevant_docs = retriever.get_relevant_documents(question)
# Format the context from relevant documents
context_parts = []
for i, doc in enumerate(relevant_docs):
source_name = doc.metadata.get("file_name", "Unknown Source")
context_parts.append(f"Document {i+1} [{source_name}]:\n{doc.page_content}\n")
context = "\n".join(context_parts) if context_parts else "No relevant documents found."
# Get chat history from memory
chat_history = self.memory.chat_memory.messages
chat_history_str = "\n".join([f"{msg.type}: {msg.content}" for msg in chat_history])
# Create the improved prompt with better instructions
prompt = f"""You are a helpful, accurate, and precise AI assistant. Answer the following question based on the provided context.
Follow these guidelines when responding:
1. If the context contains relevant information, use it to provide a direct and specific answer.
2. Format your answer in clear, concise paragraphs with appropriate spacing.
3. If the answer is not in the context, acknowledge this and provide a general response based on your knowledge.
4. Do not mention "context" or "documents" in your answer - integrate the information naturally.
5. Keep answers factual, helpful, and to the point.
6. Never make up information that isn't supported by the context.
Context:
{context}
Chat History:
{chat_history_str}
Question: {question}
Answer:"""
                    # Get the answer from the LLM, retrying once on low-quality output
                    answer = ""
                    try:
                        answer = self.llm(prompt)
                        # Simple quality check - if too short or generic, try again
                        if len(answer.strip()) < 20 or "I don't have enough information" in answer:
                            logger.info("Answer quality check failed, retrying with modified prompt")
                            # Add a more specific instruction to the prompt
                            enhanced_prompt = prompt + "\n\nPlease be as helpful as possible with the information available."
                            second_attempt = self.llm(enhanced_prompt)
                            # Keep the longer of the two responses
                            if len(second_attempt.strip()) > len(answer.strip()):
                                answer = second_attempt
                    except Exception as llm_error:
                        logger.error(f"Error getting answer from LLM: {llm_error}")
                        if not answer:  # First attempt failed before assigning an answer
                            answer = "I'm having trouble generating a response right now. Please try again in a moment."
# Perform basic formatting cleanup
answer = answer.strip()
# Remove common prefixes that models sometimes add
prefixes_to_remove = ["Answer:", "AI:", "Assistant:"]
for prefix in prefixes_to_remove:
if answer.startswith(prefix):
answer = answer[len(prefix):].strip()
                    # Persist this exchange so later queries see real chat history
                    # (the memory is read above but was never written back)
                    self.memory.save_context({"question": question}, {"answer": answer})
                    return {
                        "answer": answer,
                        "sources": relevant_docs
                    }
except Exception as e:
logger.error(f"Error in simple_chain: {e}")
return {
"answer": f"I encountered an error while processing your question. Please try again with a different query.",
"sources": []
}
return simple_chain
except Exception as e:
logger.error(f"Error creating RAG chain: {e}")
# Create a mock chain as fallback
logger.warning("Using fallback mock chain")
# Create a simple function that mimics the chain's interface
def mock_chain(inputs):
logger.info(f"Mock chain received query: {inputs.get('question', '')}")
return {
"answer": "I'm having trouble accessing the knowledge base. I can only answer general questions right now.",
"sources": []
}
return mock_chain
def add_texts(self, texts, metadatas=None):
"""Add texts to the vector store."""
try:
return self.vectorstore.add_texts(texts=texts, metadatas=metadatas)
except Exception as e:
logger.error(f"Error adding texts to vector store: {e}")
return ["error-id-" + str(random.randint(10000, 99999))]
def similarity_search(self, query, k=5):
"""Perform a similarity search."""
try:
return self.vectorstore.similarity_search(query, k=k)
except Exception as e:
logger.error(f"Error during similarity search: {e}")
return []
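
# --- Hypothetical usage sketch (not part of the deployed module) ---
# A minimal example of how MemoryManager and CustomRAGChain fit together,
# assuming app.config and app.core.llm resolve and an LLM backend is
# configured. The sample text and question below are illustrative only.
if __name__ == "__main__":
    manager = MemoryManager()
    # Index a document chunk so the retriever has something to find
    manager.add_texts(
        ["Qdrant is a vector database used for similarity search."],
        metadatas=[{"file_name": "notes.txt"}],
    )
    # Wrap the function-based chain so output keys are standardized
    chain = CustomRAGChain(manager.create_rag_chain())
    result = chain({"question": "What is Qdrant used for?"})
    print(result["answer"])
    print(f"{len(result['sources'])} source document(s) retrieved")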