""" Enhanced Memory System for GAIA-Ready AI Agent This module provides an advanced memory system for the AI agent, including short-term, long-term, and working memory components, as well as semantic retrieval capabilities. """ import os import json from typing import List, Dict, Any, Optional, Union from datetime import datetime import re import numpy as np from collections import defaultdict try: from sentence_transformers import SentenceTransformer except ImportError: import subprocess subprocess.check_call(["pip", "install", "sentence-transformers"]) from sentence_transformers import SentenceTransformer class EnhancedMemoryManager: """ Advanced memory manager for the agent that maintains short-term, long-term, and working memory with semantic retrieval capabilities. """ def __init__(self, use_semantic_search=True): self.short_term_memory = [] # Current conversation context self.long_term_memory = [] # Key facts and results self.working_memory = {} # Temporary storage for complex tasks self.max_short_term_items = 15 self.max_long_term_items = 100 self.use_semantic_search = use_semantic_search # Initialize semantic search if enabled if self.use_semantic_search: try: self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2') self.memory_embeddings = [] except Exception as e: print(f"Warning: Could not initialize semantic search: {str(e)}") self.use_semantic_search = False # Memory persistence self.memory_file = "agent_memory.json" self.load_memories() def add_to_short_term(self, item: Dict[str, Any]) -> None: """Add an item to short-term memory, maintaining size limit""" # Ensure item has all required fields if "content" not in item: raise ValueError("Memory item must have 'content' field") if "timestamp" not in item: item["timestamp"] = datetime.now().isoformat() if "type" not in item: item["type"] = "general" self.short_term_memory.append(item) # Update semantic embeddings if enabled if self.use_semantic_search: try: content = item.get("content", "") embedding = self.embedding_model.encode(content) self.memory_embeddings.append((embedding, len(self.short_term_memory) - 1, "short_term")) except Exception as e: print(f"Warning: Could not create embedding for memory item: {str(e)}") # Maintain size limit if len(self.short_term_memory) > self.max_short_term_items: removed_item = self.short_term_memory.pop(0) # Remove corresponding embedding if it exists if self.use_semantic_search: self.memory_embeddings = [(emb, idx, mem_type) for emb, idx, mem_type in self.memory_embeddings if not (mem_type == "short_term" and idx == 0)] # Update indices for remaining short-term memories self.memory_embeddings = [(emb, idx-1 if mem_type == "short_term" else idx, mem_type) for emb, idx, mem_type in self.memory_embeddings] # Save memories periodically self.save_memories() def add_to_long_term(self, item: Dict[str, Any]) -> None: """Add an important item to long-term memory, maintaining size limit""" # Ensure item has all required fields if "content" not in item: raise ValueError("Memory item must have 'content' field") if "timestamp" not in item: item["timestamp"] = datetime.now().isoformat() if "type" not in item: item["type"] = "general" # Add importance score if not present if "importance" not in item: # Calculate importance based on content length and type content_length = len(item.get("content", "")) type_importance = { "final_answer": 0.9, "key_fact": 0.8, "reasoning": 0.7, "general": 0.5 } item["importance"] = min(1.0, (content_length / 1000) * type_importance.get(item["type"], 0.5)) self.long_term_memory.append(item) # Update semantic embeddings if enabled if self.use_semantic_search: try: content = item.get("content", "") embedding = self.embedding_model.encode(content) self.memory_embeddings.append((embedding, len(self.long_term_memory) - 1, "long_term")) except Exception as e: print(f"Warning: Could not create embedding for memory item: {str(e)}") # Sort long-term memory by importance (descending) self.long_term_memory.sort(key=lambda x: x.get("importance", 0), reverse=True) # Maintain size limit if len(self.long_term_memory) > self.max_long_term_items: # Remove least important memory removed_item = self.long_term_memory.pop() # Remove corresponding embedding if it exists if self.use_semantic_search: self.memory_embeddings = [(emb, idx, mem_type) for emb, idx, mem_type in self.memory_embeddings if not (mem_type == "long_term" and idx == len(self.long_term_memory))] # Update indices for remaining long-term memories # This is more complex since we sorted by importance, so we need to rebuild indices long_term_embeddings = [] for i, item in enumerate(self.long_term_memory): content = item.get("content", "") embedding = self.embedding_model.encode(content) long_term_embeddings.append((embedding, i, "long_term")) # Keep short-term embeddings and replace long-term ones self.memory_embeddings = [(emb, idx, mem_type) for emb, idx, mem_type in self.memory_embeddings if mem_type == "short_term"] + long_term_embeddings # Save memories periodically self.save_memories() def store_in_working_memory(self, key: str, value: Any) -> None: """Store a value in working memory under the specified key""" self.working_memory[key] = value # Working memory is not persisted between sessions def get_from_working_memory(self, key: str) -> Optional[Any]: """Retrieve a value from working memory by key""" return self.working_memory.get(key) def clear_working_memory(self) -> None: """Clear the working memory""" self.working_memory = {} def get_relevant_memories(self, query: str, max_results: int = 10) -> List[Dict[str, Any]]: """ Retrieve memories relevant to the current query Args: query: The query to find relevant memories for max_results: Maximum number of results to return Returns: List of relevant memory items """ if self.use_semantic_search: try: # Use semantic search to find relevant memories query_embedding = self.embedding_model.encode(query) # Calculate cosine similarity with all memory embeddings similarities = [] for embedding, idx, mem_type in self.memory_embeddings: similarity = np.dot(query_embedding, embedding) / (np.linalg.norm(query_embedding) * np.linalg.norm(embedding)) similarities.append((similarity, idx, mem_type)) # Sort by similarity (descending) similarities.sort(reverse=True) # Get top results relevant_memories = [] for similarity, idx, mem_type in similarities[:max_results]: if mem_type == "short_term": memory = self.short_term_memory[idx] else: # long_term memory = self.long_term_memory[idx] # Add similarity score to memory item memory_with_score = memory.copy() memory_with_score["relevance_score"] = float(similarity) relevant_memories.append(memory_with_score) return relevant_memories except Exception as e: print(f"Warning: Semantic search failed: {str(e)}. Falling back to keyword search.") return self._keyword_search(query, max_results) else: return self._keyword_search(query, max_results) def _keyword_search(self, query: str, max_results: int = 10) -> List[Dict[str, Any]]: """ Fallback keyword-based search for relevant memories Args: query: The query to find relevant memories for max_results: Maximum number of results to return Returns: List of relevant memory items """ relevant_memories = [] query_keywords = set(re.findall(r'\b\w+\b', query.lower())) # Score function for keyword matching def score_memory(memory): content = memory.get("content", "").lower() content_words = set(re.findall(r'\b\w+\b', content)) # Count matching keywords matches = len(query_keywords.intersection(content_words)) # Consider memory type and recency type_boost = { "final_answer": 2.0, "key_fact": 1.5, "reasoning": 1.2, "general": 1.0 } # Calculate recency (assuming ISO format timestamps) try: timestamp = datetime.fromisoformat(memory.get("timestamp", "2000-01-01T00:00:00")) now = datetime.now() hours_ago = (now - timestamp).total_seconds() / 3600 recency_factor = max(0.5, 1.0 - (hours_ago / 24)) # Decay over 24 hours except: recency_factor = 0.5 # Calculate final score score = matches * type_boost.get(memory.get("type", "general"), 1.0) * recency_factor return score # Score all memories scored_memories = [] # Check long-term memory first (more important) for memory in self.long_term_memory: score = score_memory(memory) if score > 0: memory_with_score = memory.copy() memory_with_score["relevance_score"] = score scored_memories.append((score, memory_with_score)) # Then check short-term memory for memory in self.short_term_memory: score = score_memory(memory) if score > 0: memory_with_score = memory.copy() memory_with_score["relevance_score"] = score scored_memories.append((score, memory_with_score)) # Sort by score (descending) and take top results scored_memories.sort(reverse=True, key=lambda x: x[0]) relevant_memories = [memory for _, memory in scored_memories[:max_results]] return relevant_memories def get_memory_summary(self) -> str: """Get a summary of the current memory state for the agent""" # Get most recent short-term memories recent_short_term = self.short_term_memory[-5:] if self.short_term_memory else [] short_term_summary = "\n".join([f"- [{m.get('type', 'general')}] {m.get('content', '')[:100]}..." for m in recent_short_term]) # Get most important long-term memories important_long_term = sorted(self.long_term_memory, key=lambda x: x.get("importance", 0), reverse=True)[:5] if self.long_term_memory else [] long_term_summary = "\n".join([f"- [{m.get('type', 'general')}] {m.get('content', '')[:100]}..." for m in important_long_term]) # Summarize working memory working_memory_summary = "\n".join([f"- {k}: {str(v)[:50]}..." if isinstance(v, str) and len(str(v)) > 50 else f"- {k}: {v}" for k, v in self.working_memory.items()]) return f""" MEMORY SUMMARY: -------------- Recent Short-Term Memory: {short_term_summary if short_term_summary else "No recent short-term memories."} Important Long-Term Memory: {long_term_summary if long_term_summary else "No important long-term memories."} Working Memory: {working_memory_summary if working_memory_summary else "Working memory is empty."} """ def save_memories(self) -> None: """Save memories to disk for persistence""" try: # Only save short-term and long-term memories (not working memory) memories = { "short_term": self.short_term_memory, "long_term": self.long_term_memory, "last_updated": datetime.now().isoformat() } with open(self.memory_file, 'w') as f: json.dump(memories, f, indent=2) except Exception as e: print(f"Warning: Could not save memories: {str(e)}") def load_memories(self) -> None: """Load memories from disk if available""" try: if os.path.exists(self.memory_file): with open(self.memory_file, 'r') as f: memories = json.load(f) self.short_term_memory = memories.get("short_term", []) self.long_term_memory = memories.get("long_term", []) # Rebuild embeddings if semantic search is enabled if self.use_semantic_search: self.memory_embeddings = [] # Add embeddings for short-term memories for i, memory in enumerate(self.short_term_memory): try: content = memory.get("content", "") embedding = self.embedding_model.encode(content) self.memory_embeddings.append((embedding, i, "short_term")) except Exception as e: print(f"Warning: Could not create embedding for memory item: {str(e)}") # Add embeddings for long-term memories for i, memory in enumerate(self.long_term_memory): try: content = memory.get("content", "") embedding = self.embedding_model.encode(content) self.memory_embeddings.append((embedding, i, "long_term")) except Exception as e: print(f"Warning: Could not create embedding for memory item: {str(e)}") print(f"Loaded {len(self.short_term_memory)} short-term and {len(self.long_term_memory)} long-term memories.") except Exception as e: print(f"Warning: Could not load memories: {str(e)}") def forget_old_memories(self, days_threshold: int = 30) -> None: """ Remove memories older than the specified threshold Args: days_threshold: Age threshold in days """ try: now = datetime.now() threshold = days_threshold * 24 * 60 * 60 # Convert to seconds # Filter short-term memories new_short_term = [] for i, memory in enumerate(self.short_term_memory): try: timestamp = datetime.fromisoformat(memory.get("timestamp", "2000-01-01T00:00:00")) age = (now - timestamp).total_seconds() if age < threshold: new_short_term.append(memory) except: # Keep memories with invalid timestamps new_short_term.append(memory) # Filter long-term memories new_long_term = [] for i, memory in enumerate(self.long_term_memory): try: timestamp = datetime.fromisoformat(memory.get("timestamp", "2000-01-01T00:00:00")) age = (now - timestamp).total_seconds() # For long-term, also consider importance importance = memory.get("importance", 0.5) # More important memories have a higher threshold adjusted_threshold = threshold * (1 + importance) if age < adjusted_threshold: new_long_term.append(memory) except: # Keep memories with invalid timestamps new_long_term.append(memory) # Update memories removed_short_term = len(self.short_term_memory) - len(new_short_term) removed_long_term = len(self.long_term_memory) - len(new_long_term) self.short_term_memory = new_short_term self.long_term_memory = new_long_term # Rebuild embeddings if semantic search is enabled if self.use_semantic_search: self.memory_embeddings = [] # Add embeddings for short-term memories for i, memory in enumerate(self.short_term_memory): try: content = memory.get("content", "") embedding = self.embedding_model.encode(content) self.memory_embeddings.append((embedding, i, "short_term")) except Exception as e: print(f"Warning: Could not create embedding for memory item: {str(e)}") # Add embeddings for long-term memories for i, memory in enumerate(self.long_term_memory): try: content = memory.get("content", "") embedding = self.embedding_model.encode(content) self.memory_embeddings.append((embedding, i, "long_term")) except Exception as e: print(f"Warning: Could not create embedding for memory item: {str(e)}") # Save updated memories self.save_memories() print(f"Forgot {removed_short_term} short-term and {removed_long_term} long-term memories older than {days_threshold} days.") except Exception as e: print(f"Warning: Could not forget old memories: {str(e)}") # Example usage if __name__ == "__main__": # Initialize the memory manager memory_manager = EnhancedMemoryManager(use_semantic_search=True) # Add some test memories memory_manager.add_to_short_term({ "type": "query", "content": "What is the capital of France?", "timestamp": datetime.now().isoformat() }) memory_manager.add_to_long_term({ "type": "key_fact", "content": "Paris is the capital of France with a population of about 2.2 million people.", "timestamp": datetime.now().isoformat() }) memory_manager.store_in_working_memory("current_task", "Finding information about France") # Test retrieval relevant_memories = memory_manager.get_relevant_memories("What is the population of Paris?") print("\nRelevant memories for 'What is the population of Paris?':") for memory in relevant_memories: print(f"- Score: {memory.get('relevance_score', 0):.2f}, Content: {memory.get('content', '')}") # Print memory summary print("\nMemory Summary:") print(memory_manager.get_memory_summary())