# app/core/agent.py
import sys
import os
import logging
from typing import Dict, Any

from langchain.prompts import PromptTemplate

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Add the project root to sys.path so that app.* imports resolve
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))

from app.core.memory import MemoryManager
from app.core.llm import get_llm

class AssistantAgent:
    """Orchestrates the assistant's functionality, managing RAG and tools."""

    def __init__(self):
        self.memory_manager = MemoryManager()
        self.rag_chain = self.memory_manager.create_rag_chain()
        self.llm = get_llm()

        # Define a system prompt template
        self.system_template = """You are a personal AI assistant that helps the user with their tasks and questions.
You have access to the user's documents and notes through a retrieval system.
When answering questions, leverage this knowledge base to provide specific, factual information.
If the answer is not in the provided context, acknowledge that and give the best general answer you can.

Context from the user's documents:
{context}

Chat History:
{chat_history}

User: {question}
Assistant:"""

        self.rag_prompt = PromptTemplate(
            input_variables=["context", "chat_history", "question"],
            template=self.system_template,
        )

        logger.info("AssistantAgent initialized successfully")

    def query(self, question: str) -> Dict[str, Any]:
        """Process a user query and return a response."""
        try:
            logger.info(f"Processing query: {question[:50]}...")

            # Use the RAG chain to get an answer
            response = self.rag_chain({"question": question})
            logger.info(f"Raw response keys: {list(response.keys())}")

            # Extract the answer (should now be normalized by our wrapper)
            answer = response.get("answer", "I couldn't generate a proper response.")

            # Extract sources (should now be normalized by our wrapper)
            source_docs = response.get("sources", [])

            # Format source documents for display, truncating long content
            sources = []
            for doc in source_docs:
                metadata = getattr(doc, "metadata", {})
                page_content = getattr(doc, "page_content", str(doc)[:100])
                sources.append({
                    "content": page_content[:100] + "..." if len(page_content) > 100 else page_content,
                    "source": metadata.get("source", "Unknown"),
                    "file_name": metadata.get("file_name", "Unknown"),
                    "page": metadata.get("page"),
                })

            logger.info(f"Query processed successfully with {len(sources)} sources")
            return {
                "answer": answer,
                "sources": sources,
            }
        except Exception as e:
            logger.error(f"Error in query method: {str(e)}")
            # Return a graceful fallback response
            return {
                "answer": f"I encountered an error while processing your question. Error details: {str(e)}",
                "sources": [],
            }

    def add_conversation_to_memory(self, question: str, answer: str):
        """Add a conversation exchange to the memory for future context."""
        try:
            # Create metadata for the conversation
            metadata = {
                "type": "conversation",
                "question": question,
            }
            # Add the exchange to the vector store
            logger.info("Adding conversation to memory")
            self.memory_manager.add_texts([answer], [metadata])
        except Exception as e:
            logger.error(f"Error adding conversation to memory: {str(e)}")
            # Fail silently; this is not critical for the user experience
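
# Minimal usage sketch (hypothetical, not part of the original module): run this
# file directly to smoke-test the agent. It assumes MemoryManager is backed by an
# already-indexed vector store and that get_llm() can reach a configured model;
# the question below is a placeholder.
if __name__ == "__main__":
    agent = AssistantAgent()

    question = "What do my notes say about project deadlines?"
    result = agent.query(question)

    print(result["answer"])
    for src in result["sources"]:
        # Each source dict carries the truncated content plus its provenance
        print(f"- {src['file_name']} ({src['source']})")

    # Persist the exchange so future queries can retrieve it as context
    agent.add_conversation_to_memory(question, result["answer"])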