import os
import sys
import time
import random
import logging
from langchain_community.vectorstores import Qdrant
from langchain.memory import ConversationBufferMemory
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Add project root to path for imports
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))

from app.config import VECTOR_DB_PATH, COLLECTION_NAME
from app.core.llm import get_llm, get_embeddings, get_chat_model


class CustomRAGChain:
    """Custom RAG chain that always returns a standardized output format."""

    def __init__(self, base_chain):
        self.base_chain = base_chain
        logger.info("CustomRAGChain initialized")

    def __call__(self, inputs):
        """Run the wrapped chain and normalize its output."""
        try:
            logger.info("CustomRAGChain processing query")
            # Execute the underlying chain
            result = self.base_chain(inputs)
            logger.info(f"Base chain returned keys: {list(result.keys())}")
            # Normalize to the {"answer", "sources"} shape callers expect
            return {
                "answer": result.get("answer", "I couldn't generate an answer."),
                "sources": result.get("source_documents", []),
            }
        except Exception as e:
            logger.error(f"Error in CustomRAGChain: {e}")
            return {
                "answer": f"Error processing query: {e}",
                "sources": [],
            }
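

# Example usage (sketch): `some_retrieval_chain` below is hypothetical -- any
# callable that accepts a dict like {"question", "chat_history"} and returns a
# dict with "answer" and "source_documents" keys fits this wrapper.
#
#   chain = CustomRAGChain(some_retrieval_chain)
#   result = chain({"question": "What is RAG?", "chat_history": []})
#   print(result["answer"], len(result["sources"]))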


class MemoryManager:
    """Manages the RAG memory system using a vector database."""

    def __init__(self):
        self.embeddings = get_embeddings()
        self.llm = get_llm()
        self.chat_model = get_chat_model()
        self.client = self._init_qdrant_client()
        self.vectorstore = self._init_vector_store()
        self.memory = ConversationBufferMemory(
            memory_key="chat_history",
            return_messages=True,
        )

    def _init_qdrant_client(self):
        """Initialize the Qdrant client with retry logic for concurrent access issues."""
        # Create the storage directory if it doesn't exist
        os.makedirs(VECTOR_DB_PATH, exist_ok=True)
        # Add a small random delay to reduce the chance of concurrent access
        time.sleep(random.uniform(0.1, 0.5))
        # Generate a unique path for this instance to avoid lock collisions
        instance_id = str(random.randint(10000, 99999))
        unique_path = os.path.join(VECTOR_DB_PATH, f"instance_{instance_id}")

        max_retries = 3
        retry_count = 0
        while retry_count < max_retries:
            try:
                logger.info(f"Attempting to initialize Qdrant client (attempt {retry_count + 1}/{max_retries})")
                # Try the per-instance path first
                try:
                    os.makedirs(unique_path, exist_ok=True)
                    return QdrantClient(path=unique_path)
                except Exception as e:
                    logger.warning(f"Could not use unique path {unique_path}: {e}")
                    # Fall back to the shared path
                    return QdrantClient(path=VECTOR_DB_PATH)
            except RuntimeError as e:
                if "already accessed by another instance" in str(e):
                    retry_count += 1
                    wait_time = random.uniform(0.5, 2.0) * retry_count
                    logger.warning(f"Qdrant concurrent access detected. Retrying in {wait_time:.2f} seconds...")
                    time.sleep(wait_time)
                else:
                    # Different error; don't retry
                    raise

        # If all retries failed, fall back to in-memory storage as a last resort
        logger.warning("All Qdrant client initialization attempts failed. Using in-memory mode.")
        return QdrantClient(":memory:")

    def _init_vector_store(self):
        """Initialize the vector store, creating the collection if needed."""
        try:
            collections = self.client.get_collections().collections
            collection_names = [collection.name for collection in collections]
            # Infer the vector dimension from the embedding model
            vector_size = len(self.embeddings.embed_query("test"))
            if COLLECTION_NAME not in collection_names:
                # Create the collection with cosine distance, sized to the embeddings
                self.client.create_collection(
                    collection_name=COLLECTION_NAME,
                    vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE),
                )
                logger.info(f"Created new collection: {COLLECTION_NAME}")
            return Qdrant(
                client=self.client,
                collection_name=COLLECTION_NAME,
                embeddings=self.embeddings,
            )
        except Exception as e:
            logger.error(f"Error initializing vector store: {e}")
            # Fall back to a minimal in-memory store
            logger.warning("Using in-memory vector store as fallback.")
            return Qdrant.from_texts(
                ["Hello, I am your AI assistant."],
                self.embeddings,
                location=":memory:",
                collection_name=COLLECTION_NAME,
            )

    def get_retriever(self):
        """Get the retriever for RAG."""
        return self.vectorstore.as_retriever(
            search_type="similarity",
            search_kwargs={"k": 5},
        )
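
    # Example (sketch): the retriever is a standard LangChain retriever, so it
    # can be queried directly; the query string below is illustrative only.
    #
    #   docs = MemoryManager().get_retriever().get_relevant_documents("hello")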

    def create_rag_chain(self):
        """Create a RAG chain for question answering."""
        try:
            # Rather than a ConversationalRetrievalChain, build the chain as a
            # plain function: retrieve -> build prompt -> generate -> clean up.
            logger.info("Creating RAG chain")

            def simple_chain(query_dict):
                try:
                    # Extract the question
                    question = query_dict.get("question", "")
                    if not question.strip():
                        return {
                            "answer": "No question provided.",
                            "sources": [],
                        }

                    # Get relevant documents from the retriever
                    retriever = self.get_retriever()
                    relevant_docs = retriever.get_relevant_documents(question)

                    # Format the context from the relevant documents
                    context_parts = []
                    for i, doc in enumerate(relevant_docs):
                        source_name = doc.metadata.get("file_name", "Unknown Source")
                        context_parts.append(f"Document {i + 1} [{source_name}]:\n{doc.page_content}\n")
                    context = "\n".join(context_parts) if context_parts else "No relevant documents found."

                    # Get the chat history from memory
                    chat_history = self.memory.chat_memory.messages
                    chat_history_str = "\n".join([f"{msg.type}: {msg.content}" for msg in chat_history])

                    # Build the prompt (kept flush-left so the literal has no stray indentation)
                    prompt = f"""You are a helpful, accurate, and precise AI assistant. Answer the following question based on the provided context.

Follow these guidelines when responding:
1. If the context contains relevant information, use it to provide a direct and specific answer.
2. Format your answer in clear, concise paragraphs with appropriate spacing.
3. If the answer is not in the context, acknowledge this and provide a general response based on your knowledge.
4. Do not mention "context" or "documents" in your answer - integrate the information naturally.
5. Keep answers factual, helpful, and to the point.
6. Never make up information that isn't supported by the context.

Context:
{context}

Chat History:
{chat_history_str}

Question: {question}

Answer:"""

                    # Get the answer from the LLM, retrying once on low-quality output.
                    # Initialize `answer` up front so the except branch below never
                    # hits an UnboundLocalError if the first call raises.
                    answer = ""
                    try:
                        answer = self.llm(prompt)
                        # Simple quality check: if too short or generic, try again
                        if len(answer.strip()) < 20 or "I don't have enough information" in answer:
                            logger.info("Answer quality check failed, retrying with modified prompt")
                            # Add a more specific instruction to the prompt
                            enhanced_prompt = prompt + "\n\nPlease be as helpful as possible with the information available."
                            second_attempt = self.llm(enhanced_prompt)
                            # Keep the longer of the two responses
                            if len(second_attempt.strip()) > len(answer.strip()):
                                answer = second_attempt
                    except Exception as llm_error:
                        logger.error(f"Error getting answer from LLM: {llm_error}")
                        if not answer:  # The first attempt raised before assigning
                            answer = "I'm having trouble generating a response right now. Please try again in a moment."

                    # Perform basic formatting cleanup
                    answer = answer.strip()
                    # Remove common prefixes that models sometimes add
                    prefixes_to_remove = ["Answer:", "AI:", "Assistant:"]
                    for prefix in prefixes_to_remove:
                        if answer.startswith(prefix):
                            answer = answer[len(prefix):].strip()

                    return {
                        "answer": answer,
                        "sources": relevant_docs,
                    }
                except Exception as e:
                    logger.error(f"Error in simple_chain: {e}")
                    return {
                        "answer": "I encountered an error while processing your question. Please try again with a different query.",
                        "sources": [],
                    }

            return simple_chain
        except Exception as e:
            logger.error(f"Error creating RAG chain: {e}")
            # Fall back to a mock chain that mimics the real chain's interface
            logger.warning("Using fallback mock chain")

            def mock_chain(inputs):
                logger.info(f"Mock chain received query: {inputs.get('question', '')}")
                return {
                    "answer": "I'm having trouble accessing the knowledge base. I can only answer general questions right now.",
                    "sources": [],
                }

            return mock_chain

    def add_texts(self, texts, metadatas=None):
        """Add texts to the vector store."""
        try:
            return self.vectorstore.add_texts(texts=texts, metadatas=metadatas)
        except Exception as e:
            logger.error(f"Error adding texts to vector store: {e}")
            return ["error-id-" + str(random.randint(10000, 99999))]

    def similarity_search(self, query, k=5):
        """Perform a similarity search."""
        try:
            return self.vectorstore.similarity_search(query, k=k)
        except Exception as e:
            logger.error(f"Error during similarity search: {e}")
            return []
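

if __name__ == "__main__":
    # Minimal smoke test (sketch): this assumes `app.config` and `app.core.llm`
    # resolve in the current environment and that the configured LLM and
    # embedding model are reachable. The sample text and question are
    # illustrative only.
    manager = MemoryManager()
    manager.add_texts(
        ["Qdrant is the vector database used for RAG memory in this app."],
        metadatas=[{"file_name": "example.txt"}],
    )
    rag_chain = manager.create_rag_chain()
    result = rag_chain({"question": "Which vector database does this app use?"})
    print(result["answer"])
    for doc in result["sources"]:
        print("-", doc.metadata.get("file_name", "Unknown Source"))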