import os
import sys
import time
import random
import logging

from langchain_community.vectorstores import Qdrant
from langchain.memory import ConversationBufferMemory
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Make the project root importable so the `app.*` packages resolve when this
# module is loaded directly from its package directory.
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))

from app.config import VECTOR_DB_PATH, COLLECTION_NAME
from app.core.llm import get_llm, get_embeddings, get_chat_model


class CustomRAGChain:
    """Custom RAG chain that always returns a standardized output format."""

    def __init__(self, base_chain):
        self.base_chain = base_chain
        logger.info("CustomRAGChain initialized")

    def __call__(self, inputs):
        """Process inputs and return a standardized {"answer", "sources"} dict."""
        try:
            logger.info("CustomRAGChain processing query")

            result = self.base_chain(inputs)
            logger.info(f"Base chain returned keys: {list(result.keys())}")

            # Accept either key so the wrapper works both with LangChain chains
            # (which emit "source_documents") and with this module's
            # simple_chain (which emits "sources").
            return {
                "answer": result.get("answer", "I couldn't generate an answer."),
                "sources": result.get("source_documents", result.get("sources", [])),
            }
        except Exception as e:
            logger.error(f"Error in CustomRAGChain: {e}")
            return {
                "answer": f"Error processing query: {str(e)}",
                "sources": [],
            }

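# Hypothetical usage sketch (not part of this module's runtime path): any
# callable taking an inputs dict can be wrapped so downstream code can rely on
# the {"answer", "sources"} shape; `base_chain` below is a stand-in name.
#
#   rag = CustomRAGChain(base_chain)
#   result = rag({"question": "What do the indexed documents say about X?"})
#   print(result["answer"], len(result["sources"]))
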
class MemoryManager:
    """Manages the RAG memory system using a vector database."""

    def __init__(self):
        self.embeddings = get_embeddings()
        self.llm = get_llm()
        self.chat_model = get_chat_model()
        self.client = self._init_qdrant_client()
        self.vectorstore = self._init_vector_store()
        self.memory = ConversationBufferMemory(
            memory_key="chat_history",
            return_messages=True,
        )

    def _init_qdrant_client(self):
        """Initialize the Qdrant client with retry logic for concurrent access issues."""
        os.makedirs(VECTOR_DB_PATH, exist_ok=True)

        # Brief random jitter reduces the chance of two processes racing to
        # open the same on-disk store at the same instant.
        time.sleep(random.uniform(0.1, 0.5))

        # Each process writes to its own subdirectory so concurrent instances
        # do not contend for Qdrant's single-process file lock.
        instance_id = str(random.randint(10000, 99999))
        unique_path = os.path.join(VECTOR_DB_PATH, f"instance_{instance_id}")

        max_retries = 3
        retry_count = 0

        while retry_count < max_retries:
            try:
                logger.info(f"Attempting to initialize Qdrant client (attempt {retry_count + 1}/{max_retries})")

                try:
                    os.makedirs(unique_path, exist_ok=True)
                    return QdrantClient(path=unique_path)
                except Exception as e:
                    logger.warning(f"Could not use unique path {unique_path}: {e}")

                # Fall back to the shared path; this may raise a RuntimeError
                # if another instance already holds the lock.
                return QdrantClient(path=VECTOR_DB_PATH)
            except RuntimeError as e:
                if "already accessed by another instance" in str(e):
                    retry_count += 1
                    # Linear backoff with jitter between attempts.
                    wait_time = random.uniform(0.5, 2.0) * retry_count
                    logger.warning(f"Qdrant concurrent access detected. Retrying in {wait_time:.2f} seconds...")
                    time.sleep(wait_time)
                else:
                    raise

        logger.warning("All Qdrant client initialization attempts failed. Using in-memory mode.")
        return QdrantClient(":memory:")

    def _init_vector_store(self):
        """Initialize the vector store, creating the collection if needed."""
        try:
            collections = self.client.get_collections().collections
            collection_names = [collection.name for collection in collections]

            # Probe the embedding model once to discover the vector dimension.
            vector_size = len(self.embeddings.embed_query("test"))

            if COLLECTION_NAME not in collection_names:
                self.client.create_collection(
                    collection_name=COLLECTION_NAME,
                    vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE),
                )
                logger.info(f"Created new collection: {COLLECTION_NAME}")

            return Qdrant(
                client=self.client,
                collection_name=COLLECTION_NAME,
                embeddings=self.embeddings,
            )
        except Exception as e:
            logger.error(f"Error initializing vector store: {e}")
            logger.warning("Using in-memory vector store as fallback.")
            return Qdrant.from_texts(
                ["Hello, I am your AI assistant."],
                self.embeddings,
                location=":memory:",
                collection_name=COLLECTION_NAME,
            )

    def get_retriever(self):
        """Get the retriever for RAG."""
        return self.vectorstore.as_retriever(
            search_type="similarity",
            search_kwargs={"k": 5},
        )

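    # Usage sketch (hypothetical variable names, assumes a populated
    # collection): the retriever returns the five most similar chunks.
    #
    #   docs = manager.get_retriever().get_relevant_documents("qdrant")
    #   print([d.metadata.get("file_name") for d in docs])
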
    def create_rag_chain(self):
        """Create a RAG chain for question answering."""
        try:
            logger.info("Creating simple RAG chain")

            def simple_chain(query_dict):
                try:
                    question = query_dict.get("question", "")
                    if not question.strip():
                        return {
                            "answer": "No question provided.",
                            "sources": [],
                        }

                    # Retrieve the most relevant documents for the question.
                    retriever = self.get_retriever()
                    relevant_docs = retriever.get_relevant_documents(question)

                    # Build a numbered context block, labeling each chunk with
                    # its source file where available.
                    context_parts = []
                    for i, doc in enumerate(relevant_docs):
                        source_name = doc.metadata.get("file_name", "Unknown Source")
                        context_parts.append(f"Document {i + 1} [{source_name}]:\n{doc.page_content}\n")

                    context = "\n".join(context_parts) if context_parts else "No relevant documents found."

                    # Flatten the buffered conversation into plain text.
                    chat_history = self.memory.chat_memory.messages
                    chat_history_str = "\n".join(f"{msg.type}: {msg.content}" for msg in chat_history)

                    prompt = f"""You are a helpful, accurate, and precise AI assistant. Answer the following question based on the provided context.

Follow these guidelines when responding:
1. If the context contains relevant information, use it to provide a direct and specific answer.
2. Format your answer in clear, concise paragraphs with appropriate spacing.
3. If the answer is not in the context, acknowledge this and provide a general response based on your knowledge.
4. Do not mention "context" or "documents" in your answer - integrate the information naturally.
5. Keep answers factual, helpful, and to the point.
6. Never make up information that isn't supported by the context.

Context:
{context}

Chat History:
{chat_history_str}

Question: {question}
Answer:"""

                    # Initialize before the call so the except branch below can
                    # safely test whether an answer was produced.
                    answer = ""
                    try:
                        answer = self.llm(prompt)

                        # Simple quality gate: retry once if the answer looks
                        # too short or is an explicit refusal.
                        if len(answer.strip()) < 20 or "I don't have enough information" in answer:
                            logger.info("Answer quality check failed, retrying with modified prompt")

                            enhanced_prompt = prompt + "\n\nPlease be as helpful as possible with the information available."
                            second_attempt = self.llm(enhanced_prompt)

                            # Keep whichever attempt produced more content.
                            if len(second_attempt.strip()) > len(answer.strip()):
                                answer = second_attempt
                    except Exception as llm_error:
                        logger.error(f"Error getting answer from LLM: {llm_error}")
                        if not answer:
                            answer = "I'm having trouble generating a response right now. Please try again in a moment."

                    answer = answer.strip()

                    # Strip any role prefix the model may have echoed back.
                    prefixes_to_remove = ["Answer:", "AI:", "Assistant:"]
                    for prefix in prefixes_to_remove:
                        if answer.startswith(prefix):
                            answer = answer[len(prefix):].strip()

                    return {
                        "answer": answer,
                        "sources": relevant_docs,
                    }
                except Exception as e:
                    logger.error(f"Error in simple_chain: {e}")
                    return {
                        "answer": "I encountered an error while processing your question. Please try again with a different query.",
                        "sources": [],
                    }

            return simple_chain
        except Exception as e:
            logger.error(f"Error creating RAG chain: {e}")
            logger.warning("Using fallback mock chain")

            def mock_chain(inputs):
                logger.info(f"Mock chain received query: {inputs.get('question', '')}")
                return {
                    "answer": "I'm having trouble accessing the knowledge base. I can only answer general questions right now.",
                    "sources": [],
                }

            return mock_chain

    def add_texts(self, texts, metadatas=None):
        """Add texts to the vector store."""
        try:
            return self.vectorstore.add_texts(texts=texts, metadatas=metadatas)
        except Exception as e:
            logger.error(f"Error adding texts to vector store: {e}")
            # Return a sentinel ID so callers still receive a list of IDs.
            return ["error-id-" + str(random.randint(10000, 99999))]

    def similarity_search(self, query, k=5):
        """Perform a similarity search."""
        try:
            return self.vectorstore.similarity_search(query, k=k)
        except Exception as e:
            logger.error(f"Error during similarity search: {e}")
            return []
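

# Minimal smoke-test sketch, assuming the backends configured in app.core.llm
# and the app.config paths are reachable in this environment; not part of the
# module's original public surface.
if __name__ == "__main__":
    manager = MemoryManager()
    manager.add_texts(
        ["Qdrant is a vector database used for similarity search."],
        metadatas=[{"file_name": "example_note.txt"}],
    )
    chain = manager.create_rag_chain()
    result = chain({"question": "What is Qdrant used for?"})
    print(result["answer"])
    for doc in result["sources"]:
        print("-", doc.metadata.get("file_name", "Unknown Source"))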