|
import sys |
|
import os |
|
import logging |
|
from typing import List, Dict, Any |
|
from langchain.prompts import PromptTemplate |
|
|
|
|
|
# Configure root logging once at import time; all module loggers emit INFO+.
logging.basicConfig(level=logging.INFO)

# Module-level logger, named after this module per stdlib convention.
logger = logging.getLogger(__name__)

# Make the project root importable when this file is run directly: walk three
# directories up from this file's absolute path and append to sys.path.
# NOTE(review): this mutates interpreter state for the whole process and must
# run BEFORE the `app.core` imports below — prefer `python -m` or an installed
# package layout if this ever causes import shadowing.
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))

from app.core.memory import MemoryManager

from app.core.llm import get_llm
|
|
|
class AssistantAgent:
    """Orchestrates the assistant's functionality, managing RAG and tools.

    Wraps the conversational RAG chain built by :class:`MemoryManager` and
    exposes a dict-based ``query`` API plus ``add_conversation_to_memory``
    for persisting exchanges back into the vector store.
    """

    # Maximum characters of a retrieved document shown in a source summary.
    _SNIPPET_LEN = 100

    def __init__(self):
        # Retrieval layer: the memory manager owns the vector store and
        # builds the conversational RAG chain used by query().
        self.memory_manager = MemoryManager()
        self.rag_chain = self.memory_manager.create_rag_chain()
        self.llm = get_llm()

        # Prompt fed to the RAG chain; the placeholders must match the
        # input_variables passed to PromptTemplate below.
        self.system_template = """You are a personal AI assistant that helps the user with their tasks and questions.
You have access to the user's documents and notes through a retrieval system.
When answering questions, leverage this knowledge base to provide specific, factual information.
If the answer is not in the provided context, acknowledge that and give the best general answer you can.

Context from the user's documents:
{context}

Chat History:
{chat_history}

User: {question}
Assistant:"""

        self.rag_prompt = PromptTemplate(
            input_variables=["context", "chat_history", "question"],
            template=self.system_template,
        )

        logger.info("AssistantAgent initialized successfully")

    def query(self, question: str) -> Dict[str, Any]:
        """Process a user query and return a response.

        Args:
            question: The user's natural-language question.

        Returns:
            Dict with ``"answer"`` (str) and ``"sources"`` (list of dicts
            with ``content``/``source``/``file_name``/``page`` keys). On
            failure the answer describes the error and ``sources`` is empty.
        """
        try:
            # Lazy %-formatting: the message is only built if INFO is emitted.
            logger.info("Processing query: %s...", question[:50])

            # NOTE(review): the chain is expected to return "answer" and
            # "sources" keys — confirm against MemoryManager.create_rag_chain;
            # LangChain retrieval chains often use "source_documents" instead.
            response = self.rag_chain({"question": question})
            logger.info("Raw response keys: %s", list(response.keys()))

            answer = response.get("answer", "I couldn't generate a proper response.")
            sources = self._format_sources(response.get("sources", []))

            logger.info("Query processed successfully with %d sources", len(sources))
            return {
                "answer": answer,
                "sources": sources,
            }
        except Exception as e:
            # logger.exception records the traceback in addition to the message.
            logger.exception("Error in query method: %s", e)
            # Degrade gracefully: the caller always gets the same dict shape.
            return {
                "answer": f"I encountered an error while processing your question. Error details: {str(e)}",
                "sources": [],
            }

    def _format_sources(self, source_docs: List[Any]) -> List[Dict[str, Any]]:
        """Convert retrieved documents into JSON-friendly source summaries."""
        sources = []
        for doc in source_docs:
            # Tolerate both Document-like objects and arbitrary values.
            metadata = getattr(doc, "metadata", {})
            page_content = getattr(doc, "page_content", str(doc)[:self._SNIPPET_LEN])

            if len(page_content) > self._SNIPPET_LEN:
                snippet = page_content[:self._SNIPPET_LEN] + "..."
            else:
                snippet = page_content

            sources.append({
                "content": snippet,
                "source": metadata.get("source", "Unknown"),
                "file_name": metadata.get("file_name", "Unknown"),
                # .get already yields None when "page" is absent; the previous
                # `if "page" in metadata` guard made its "N/A" default dead code.
                "page": metadata.get("page"),
            })
        return sources

    def add_conversation_to_memory(self, question: str, answer: str) -> None:
        """Add a conversation exchange to the memory for future context.

        The answer text is stored with metadata linking it back to the
        originating question so later retrievals can surface past exchanges.
        """
        try:
            metadata = {
                "type": "conversation",
                "question": question,
            }

            logger.info("Adding conversation to memory")
            self.memory_manager.add_texts([answer], [metadata])
        except Exception as e:
            # Best-effort persistence: never propagate failures to the caller.
            logger.exception("Error adding conversation to memory: %s", e)
|
|