"""
Memory Manager Component

This module provides memory capabilities for the GAIA agent,
handling cache operations and integrating with Supabase or local storage.
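
Example (illustrative; the keys shown are the ones read by MemoryManager):

    manager = MemoryManager({
        "use_supabase": False,
        "cache_enabled": True,
        "cache_ttl": 3600,
        "memory_dir": "agent_memory",
    })
    manager.cache_question_answer("example question", "example answer")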
"""

import logging
import os
import json
import time
import hashlib
from typing import Dict, Any, List, Optional, Union
import traceback

logger = logging.getLogger("gaia_agent.components.memory_manager")


class MemoryManager:
    """
    Manages memory operations for the GAIA agent.
    Provides caching, persistence, and retrieval functionality.
    """

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        """
        Initialize the memory manager with configuration.

        Args:
            config: Configuration dictionary for memory operations
        """
        self.config = config or {}
        self.supabase_client = None
        self.use_supabase = self.config.get("use_supabase", False)
        self.cache_enabled = self.config.get("cache_enabled", True)
        self.cache_ttl = self.config.get("cache_ttl", 3600)
        self.local_cache = {}

        # Directory for the on-disk cache; created if it does not exist
        self.memory_dir = self.config.get("memory_dir", "agent_memory")
        os.makedirs(self.memory_dir, exist_ok=True)

        # Connect to Supabase only when explicitly enabled in the config
        if self.use_supabase:
            self._initialize_supabase()

        logger.info(f"MemoryManager initialized (Supabase: {self.use_supabase}, Cache: {self.cache_enabled})")

    def _initialize_supabase(self):
        """Initialize the Supabase connection if credentials are available."""
        try:
            from supabase import create_client

            supabase_url = os.getenv("SUPABASE_URL", "")
            supabase_key = os.getenv("SUPABASE_KEY", "")

            if not supabase_url or not supabase_key:
                logger.warning("Supabase credentials not found in environment variables")
                return

            self.supabase_client = create_client(supabase_url, supabase_key)

            # Run a lightweight query to verify that the connection works
            self.supabase_client.table("interactions").select("*").limit(1).execute()
            logger.info("Successfully connected to Supabase")

        except ImportError as e:
            logger.warning(f"Supabase library not available: {str(e)}")
            self.use_supabase = False
        except Exception as e:
            logger.error(f"Failed to initialize Supabase: {str(e)}")
            logger.debug(traceback.format_exc())
            self.use_supabase = False

    def _generate_key(self, data: Union[str, Dict[str, Any]]) -> str:
        """
        Generate a unique key for the provided data.

        Args:
            data: Data to generate a key for

        Returns:
            Unique string key
        """
        if isinstance(data, dict):
            # Sort keys so that equivalent dicts always produce the same key
            data_str = json.dumps(data, sort_keys=True)
        else:
            data_str = str(data)

        # MD5 is used only for cache-key generation, not for security
        return hashlib.md5(data_str.encode()).hexdigest()

    def store_local(self, key: str, data: Any, ttl: Optional[int] = None) -> bool:
        """
        Store data in the local cache with an optional TTL.

        Args:
            key: Cache key
            data: Data to store (must be JSON-serializable for disk persistence)
            ttl: Time to live in seconds (optional)

        Returns:
            bool: True if the operation succeeded
        """
        try:
            ttl = ttl or self.cache_ttl

            cache_entry = {
                "data": data,
                "timestamp": time.time(),
                "expires": time.time() + ttl
            }

            # Store in the in-memory cache
            self.local_cache[key] = cache_entry

            # Persist to disk so the entry survives restarts
            cache_file = os.path.join(self.memory_dir, f"{key}.json")
            with open(cache_file, 'w') as f:
                json.dump(cache_entry, f)

            logger.debug(f"Stored data with key '{key}' locally")
            return True

        except Exception as e:
            logger.error(f"Error storing data locally: {str(e)}")
            return False

    def retrieve_local(self, key: str) -> Optional[Any]:
        """
        Retrieve data from the local cache if available and not expired.

        Args:
            key: Cache key

        Returns:
            Cached data, or None if not found or expired
        """
        try:
            # Check the in-memory cache first
            if key in self.local_cache:
                entry = self.local_cache[key]

                if entry["expires"] > time.time():
                    logger.debug(f"Retrieved data for key '{key}' from memory cache")
                    return entry["data"]
                else:
                    # Entry has expired; evict it from memory
                    del self.local_cache[key]

            # Fall back to the on-disk cache
            cache_file = os.path.join(self.memory_dir, f"{key}.json")
            if os.path.exists(cache_file):
                with open(cache_file, 'r') as f:
                    entry = json.load(f)

                if entry["expires"] > time.time():
                    # Repopulate the in-memory cache from disk
                    self.local_cache[key] = entry
                    logger.debug(f"Retrieved data for key '{key}' from disk cache")
                    return entry["data"]
                else:
                    # Expired on disk; remove the stale file
                    os.remove(cache_file)

            logger.debug(f"No valid cache entry found for key '{key}'")
            return None

        except Exception as e:
            logger.error(f"Error retrieving data locally: {str(e)}")
            return None

    def store_in_supabase(self, interaction_data: Dict[str, Any]) -> bool:
        """
        Store interaction data in Supabase.

        Args:
            interaction_data: Interaction data to store

        Returns:
            bool: True if the operation succeeded
        """
        if not self.use_supabase or not self.supabase_client:
            logger.warning("Supabase storage requested but not available")
            return False

        try:
            # Fill in required fields with placeholders if they are missing
            required_fields = ["user_id", "question"]
            for field in required_fields:
                if field not in interaction_data:
                    interaction_data[field] = f"unknown_{field}_{time.time()}"

            if "timestamp" not in interaction_data:
                interaction_data["timestamp"] = time.time()

            result = self.supabase_client.table("interactions").insert(interaction_data).execute()

            # Dict-style error check (older supabase-py responses); newer clients
            # raise on failure instead, which the except block below handles
            if "error" in result:
                logger.error(f"Error storing in Supabase: {result.get('error')}")
                return False

            logger.info("Stored interaction data in Supabase")
            return True

        except Exception as e:
            logger.error(f"Error storing data in Supabase: {str(e)}")
            logger.debug(traceback.format_exc())
            return False

    def retrieve_from_supabase(self, filters: Dict[str, Any], limit: int = 10) -> List[Dict[str, Any]]:
        """
        Retrieve interaction data from Supabase using filters.

        Args:
            filters: Field-value pairs to filter by
            limit: Maximum number of records to return

        Returns:
            List of matching interaction records
        """
        if not self.use_supabase or not self.supabase_client:
            logger.warning("Supabase retrieval requested but not available")
            return []

        try:
            query = self.supabase_client.table("interactions").select("*")

            # Apply each filter as an equality condition
            for field, value in filters.items():
                query = query.eq(field, value)

            # Return the most recent records first
            query = query.order("timestamp", desc=True).limit(limit)

            result = query.execute()

            # Dict-style response handling, matching store_in_supabase above
            if "error" in result:
                logger.error(f"Error retrieving from Supabase: {result.get('error')}")
                return []

            data = result.get("data", [])
            logger.info(f"Retrieved {len(data)} records from Supabase")
            return data

        except Exception as e:
            logger.error(f"Error retrieving data from Supabase: {str(e)}")
            logger.debug(traceback.format_exc())
            return []

    def cache_question_answer(self, question: str, answer: str, metadata: Optional[Dict[str, Any]] = None) -> bool:
        """
        Cache a question-answer pair both locally and in Supabase if enabled.

        Args:
            question: The question text
            answer: The answer text
            metadata: Additional metadata about the interaction

        Returns:
            bool: True if the operation succeeded locally or in Supabase
        """
        if not self.cache_enabled:
            return False

        try:
            metadata = metadata or {}

            # Derive a stable cache key from the question text
            cache_key = self._generate_key(question)

            qa_data = {
                "question": question,
                "answer": answer,
                "timestamp": time.time(),
                "metadata": metadata
            }

            # Always attempt the local store
            local_success = self.store_local(cache_key, qa_data)

            # Mirror the interaction to Supabase when enabled
            supabase_success = False
            if self.use_supabase:
                supabase_data = {
                    "user_id": metadata.get("user_id", "anonymous"),
                    "question": question,
                    "answer": answer,
                    "timestamp": time.time(),
                    "metadata": json.dumps(metadata)
                }
                supabase_success = self.store_in_supabase(supabase_data)

            return local_success or supabase_success

        except Exception as e:
            logger.error(f"Error caching question-answer: {str(e)}")
            return False

    def get_cached_answer(self, question: str) -> Optional[str]:
        """
        Retrieve the cached answer for a question if available.

        Args:
            question: The question text

        Returns:
            Cached answer, or None if not found
        """
        if not self.cache_enabled:
            return None

        try:
            cache_key = self._generate_key(question)

            # Try the local cache (memory, then disk) first
            cached_data = self.retrieve_local(cache_key)

            if cached_data and "answer" in cached_data:
                logger.info("Retrieved cached answer for question")
                return cached_data["answer"]

            # Fall back to Supabase for an exact question match
            if self.use_supabase:
                filters = {"question": question}
                supabase_results = self.retrieve_from_supabase(filters, limit=1)

                if supabase_results:
                    result = supabase_results[0]
                    answer = result.get("answer")

                    if answer:
                        # Write the answer back to the local cache for next time
                        qa_data = {
                            "question": question,
                            "answer": answer,
                            "timestamp": time.time(),
                            "metadata": {"source": "supabase_retrieval"}
                        }
                        self.store_local(cache_key, qa_data)

                        logger.info("Retrieved answer from Supabase")
                        return answer

            return None

        except Exception as e:
            logger.error(f"Error retrieving cached answer: {str(e)}")
            return None

    def get_similar_questions(self, question: str, limit: int = 5) -> List[Dict[str, Any]]:
        """
        Retrieve questions similar to the input question.

        Uses a simple word-overlap score rather than semantic embeddings.

        Args:
            question: The question text to find similarities for
            limit: Maximum number of similar questions to return

        Returns:
            List of similar question-answer pairs with metadata
        """
        if not self.use_supabase or not self.supabase_client:
            logger.warning("Similar questions retrieval requested but Supabase not available")
            return []

        try:
            # Remove common stopwords so the comparison focuses on content words
            stopwords = {"the", "is", "are", "a", "an", "in", "on", "at", "by", "for", "with", "about"}
            question_words = set(question.lower().split()) - stopwords

            if not question_words:
                return []

            similar_questions = []

            # Fetch a batch of recent interactions to compare against
            all_questions = self.retrieve_from_supabase({}, limit=100)

            for entry in all_questions:
                entry_question = entry.get("question", "").lower()
                entry_words = set(entry_question.split()) - stopwords

                common_words = question_words.intersection(entry_words)
                if not common_words:
                    continue

                # Overlap ratio relative to the larger of the two word sets
                similarity = len(common_words) / max(len(question_words), len(entry_words))

                if similarity > 0.3:
                    similar_questions.append({
                        "question": entry.get("question"),
                        "answer": entry.get("answer"),
                        "timestamp": entry.get("timestamp"),
                        "similarity": similarity
                    })

            # Return the best matches first
            similar_questions.sort(key=lambda x: x.get("similarity", 0), reverse=True)
            return similar_questions[:limit]

        except Exception as e:
            logger.error(f"Error retrieving similar questions: {str(e)}")
            return []

    def clear_cache(self) -> bool:
        """
        Clear all local cache data.

        Returns:
            bool: True if the operation succeeded
        """
        try:
            # Reset the in-memory cache
            self.local_cache = {}

            # Remove all persisted cache files from disk
            for filename in os.listdir(self.memory_dir):
                if filename.endswith(".json"):
                    file_path = os.path.join(self.memory_dir, filename)
                    os.remove(file_path)

            logger.info("Cache cleared successfully")
            return True

        except Exception as e:
            logger.error(f"Error clearing cache: {str(e)}")
            return False
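

# Minimal usage sketch (illustrative only): exercises the local cache round
# trip with Supabase disabled. The question/answer strings are placeholders,
# not part of the module's API.
if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)

    manager = MemoryManager({"use_supabase": False, "cache_ttl": 60})

    manager.cache_question_answer(
        "example question",
        "example answer",
        metadata={"user_id": "demo"},
    )

    # Should hit the in-memory cache and print "example answer"
    print(manager.get_cached_answer("example question"))

    manager.clear_cache()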