# Source: final/memory_system.py (Hugging Face upload by yoshizen, "Upload 4 files", commit 162ee47)
"""
Enhanced Memory System for GAIA-Ready AI Agent
This module provides an advanced memory system for the AI agent,
including short-term, long-term, and working memory components,
as well as semantic retrieval capabilities.
"""
import os
import json
from typing import List, Dict, Any, Optional, Union
from datetime import datetime
import re
import numpy as np
from collections import defaultdict
# Optional dependency: semantic search needs sentence-transformers.
try:
    from sentence_transformers import SentenceTransformer
except ImportError:
    # Best-effort auto-install. Run pip through the *current* interpreter
    # (`python -m pip`) so the package lands in the environment executing this
    # module, not whichever `pip` happens to be first on PATH.
    # NOTE(review): installing packages at import time is a side effect callers
    # may not expect; consider declaring the dependency instead.
    import importlib
    import subprocess
    import sys
    subprocess.check_call([sys.executable, "-m", "pip", "install", "sentence-transformers"])
    # The import system caches directory listings; a package installed after
    # interpreter start-up may be invisible until the caches are invalidated.
    importlib.invalidate_caches()
    from sentence_transformers import SentenceTransformer
class EnhancedMemoryManager:
    """
    Advanced memory manager for the agent that maintains short-term, long-term,
    and working memory with semantic retrieval capabilities.

    Storage layout:
      * ``short_term_memory``: FIFO list (oldest first) capped at
        ``max_short_term_items``; the oldest item is evicted on overflow.
      * ``long_term_memory``: re-sorted by ``importance`` (descending) on every
        addition and capped at ``max_long_term_items``; the least important
        item is evicted on overflow.
      * ``working_memory``: free-form dict that is never persisted.
      * ``memory_embeddings``: ``(embedding, index, "short_term"|"long_term")``
        triples used for semantic search, where ``index`` is the item's current
        position in the corresponding list.

    Short- and long-term memories are written to ``memory_file`` as JSON after
    every addition or forget pass, and are reloaded on construction.
    """

    def __init__(self, use_semantic_search=True):
        """
        Create the manager and load any previously persisted memories.

        Args:
            use_semantic_search: If True, try to load the 'all-MiniLM-L6-v2'
                SentenceTransformer model. If that fails, a warning is printed
                and the manager falls back to keyword search.
        """
        self.short_term_memory = []  # Current conversation context
        self.long_term_memory = []  # Key facts and results
        self.working_memory = {}  # Temporary storage for complex tasks
        self.max_short_term_items = 15
        self.max_long_term_items = 100
        self.use_semantic_search = use_semantic_search
        # Always defined, even when semantic search is unavailable.
        self.memory_embeddings = []
        # Initialize semantic search if enabled
        if self.use_semantic_search:
            try:
                self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
            except Exception as e:
                print(f"Warning: Could not initialize semantic search: {str(e)}")
                self.use_semantic_search = False
        # Memory persistence
        self.memory_file = "agent_memory.json"
        self.load_memories()

    @staticmethod
    def _validate_and_fill_defaults(item: Dict[str, Any]) -> None:
        """Raise if ``item`` has no ``content``; fill a missing ``timestamp``/``type`` in place."""
        if "content" not in item:
            raise ValueError("Memory item must have 'content' field")
        if "timestamp" not in item:
            item["timestamp"] = datetime.now().isoformat()
        if "type" not in item:
            item["type"] = "general"

    def _append_embedding(self, item: Dict[str, Any], index: int, mem_type: str) -> None:
        """Encode ``item['content']`` and record it at ``index``; print a warning on failure."""
        try:
            embedding = self.embedding_model.encode(item.get("content", ""))
            self.memory_embeddings.append((embedding, index, mem_type))
        except Exception as e:
            print(f"Warning: Could not create embedding for memory item: {str(e)}")

    def _rebuild_embeddings(self) -> None:
        """Re-encode every short- and long-term item so each embedding index
        matches the item's current list position. Items that fail to encode are
        skipped."""
        self.memory_embeddings = []
        for mem_type, memories in (("short_term", self.short_term_memory),
                                   ("long_term", self.long_term_memory)):
            for index, memory in enumerate(memories):
                self._append_embedding(memory, index, mem_type)

    def add_to_short_term(self, item: Dict[str, Any]) -> None:
        """
        Append an item to short-term memory and persist all memories.

        When the size limit is exceeded, the oldest item is evicted. Missing
        ``timestamp`` and ``type`` fields are filled in on ``item`` itself, so
        the caller's dict is mutated.

        Args:
            item: Memory dict; must contain a ``content`` key.

        Raises:
            ValueError: If ``item`` has no ``content`` key.
        """
        self._validate_and_fill_defaults(item)
        self.short_term_memory.append(item)
        if self.use_semantic_search:
            self._append_embedding(item, len(self.short_term_memory) - 1, "short_term")
        # Maintain size limit (FIFO eviction of the oldest item)
        if len(self.short_term_memory) > self.max_short_term_items:
            self.short_term_memory.pop(0)
            if self.use_semantic_search:
                # Drop the evicted item's embedding and shift the remaining
                # short-term indices down by one to match the shortened list.
                self.memory_embeddings = [
                    (emb, idx - 1 if mem_type == "short_term" else idx, mem_type)
                    for emb, idx, mem_type in self.memory_embeddings
                    if not (mem_type == "short_term" and idx == 0)
                ]
        # Persist after every addition
        self.save_memories()

    def add_to_long_term(self, item: Dict[str, Any]) -> None:
        """
        Add an important item to long-term memory and persist all memories.

        If ``importance`` is absent, it is computed as
        ``min(1.0, len(content) / 1000 * weight)``. The weight is 0.9 for
        ``final_answer``, 0.8 for ``key_fact``, 0.7 for ``reasoning``, and 0.5
        otherwise. The list is then re-sorted by importance (descending), and
        the least important item is evicted if the size limit is exceeded.
        Missing ``timestamp``/``type``/``importance`` fields are written into
        ``item`` itself.

        Args:
            item: Memory dict; must contain a ``content`` key.

        Raises:
            ValueError: If ``item`` has no ``content`` key.
        """
        self._validate_and_fill_defaults(item)
        # Add importance score if not present
        if "importance" not in item:
            content_length = len(item.get("content", ""))
            type_importance = {
                "final_answer": 0.9,
                "key_fact": 0.8,
                "reasoning": 0.7,
                "general": 0.5
            }
            item["importance"] = min(1.0, (content_length / 1000) * type_importance.get(item["type"], 0.5))
        self.long_term_memory.append(item)

        embedding_by_item = {}
        if self.use_semantic_search:
            self._append_embedding(item, len(self.long_term_memory) - 1, "long_term")
            # Sorting and eviction reorder the list, which would leave the stored
            # indices pointing at the wrong items. Remember each long-term
            # embedding by the identity of its item so the indices can be
            # remapped afterwards without re-encoding anything.
            embedding_by_item = {
                id(self.long_term_memory[idx]): emb
                for emb, idx, mem_type in self.memory_embeddings
                if mem_type == "long_term" and 0 <= idx < len(self.long_term_memory)
            }

        # Sort long-term memory by importance (descending)
        self.long_term_memory.sort(key=lambda x: x.get("importance", 0), reverse=True)
        # Maintain size limit: evict the least important memory
        if len(self.long_term_memory) > self.max_long_term_items:
            self.long_term_memory.pop()

        if self.use_semantic_search:
            # Keep short-term embeddings as-is; re-index long-term ones to the
            # post-sort positions (the evicted item simply drops out).
            self.memory_embeddings = [
                entry for entry in self.memory_embeddings if entry[2] != "long_term"
            ] + [
                (embedding_by_item[id(memory)], position, "long_term")
                for position, memory in enumerate(self.long_term_memory)
                if id(memory) in embedding_by_item
            ]
        # Persist after every addition
        self.save_memories()

    def store_in_working_memory(self, key: str, value: Any) -> None:
        """Store a value in working memory under the specified key"""
        self.working_memory[key] = value
        # Working memory is not persisted between sessions

    def get_from_working_memory(self, key: str) -> Optional[Any]:
        """Retrieve a value from working memory by key (None if absent)"""
        return self.working_memory.get(key)

    def clear_working_memory(self) -> None:
        """Clear the working memory"""
        self.working_memory = {}

    def get_relevant_memories(self, query: str, max_results: int = 10) -> List[Dict[str, Any]]:
        """
        Retrieve memories relevant to the current query.

        With semantic search enabled, memories are ranked by the cosine
        similarity between their embedding and the query embedding. If that
        raises, a warning is printed and keyword search is used instead.

        Args:
            query: The query to find relevant memories for
            max_results: Maximum number of results to return

        Returns:
            Copies of the matching memory items, each with an added
            ``relevance_score`` key, best match first.
        """
        if not self.use_semantic_search:
            return self._keyword_search(query, max_results)
        try:
            query_embedding = self.embedding_model.encode(query)
            query_norm = np.linalg.norm(query_embedding)
            similarities = []
            for embedding, idx, mem_type in self.memory_embeddings:
                denominator = query_norm * np.linalg.norm(embedding)
                # A zero vector has no direction; score it 0 instead of
                # producing NaN, which would make the sort order undefined.
                if denominator > 0:
                    similarity = float(np.dot(query_embedding, embedding) / denominator)
                else:
                    similarity = 0.0
                similarities.append((similarity, idx, mem_type))
            # Sort by similarity (descending)
            similarities.sort(reverse=True)
            relevant_memories = []
            for similarity, idx, mem_type in similarities[:max_results]:
                source = self.short_term_memory if mem_type == "short_term" else self.long_term_memory
                memory_with_score = source[idx].copy()
                memory_with_score["relevance_score"] = similarity
                relevant_memories.append(memory_with_score)
            return relevant_memories
        except Exception as e:
            print(f"Warning: Semantic search failed: {str(e)}. Falling back to keyword search.")
            return self._keyword_search(query, max_results)

    def _keyword_search(self, query: str, max_results: int = 10) -> List[Dict[str, Any]]:
        """
        Fallback keyword-based search for relevant memories.

        Score = (number of shared words) * type boost * recency factor. The
        type boost is 2.0 for final_answer, 1.5 for key_fact, 1.2 for
        reasoning, and 1.0 otherwise. The recency factor decays linearly from
        1.0 to a floor of 0.5 over 24 hours; it is 0.5 when the timestamp
        cannot be parsed. Only items with a positive score are returned.

        Args:
            query: The query to find relevant memories for
            max_results: Maximum number of results to return

        Returns:
            Copies of the matching memory items with a ``relevance_score`` key,
            highest score first.
        """
        query_keywords = set(re.findall(r'\b\w+\b', query.lower()))
        type_boost = {
            "final_answer": 2.0,
            "key_fact": 1.5,
            "reasoning": 1.2,
            "general": 1.0
        }
        now = datetime.now()

        def score_memory(memory):
            content = memory.get("content", "").lower()
            content_words = set(re.findall(r'\b\w+\b', content))
            matches = len(query_keywords & content_words)
            # Calculate recency (assuming ISO format timestamps)
            try:
                timestamp = datetime.fromisoformat(memory.get("timestamp", "2000-01-01T00:00:00"))
                hours_ago = (now - timestamp).total_seconds() / 3600
                recency_factor = max(0.5, 1.0 - (hours_ago / 24))  # Decay over 24 hours
            except (TypeError, ValueError):
                recency_factor = 0.5
            return matches * type_boost.get(memory.get("type", "general"), 1.0) * recency_factor

        scored_memories = []
        # Long-term memories come first; the stable sort below lets them win ties.
        for memory in self.long_term_memory + self.short_term_memory:
            score = score_memory(memory)
            if score > 0:
                memory_with_score = memory.copy()
                memory_with_score["relevance_score"] = score
                scored_memories.append((score, memory_with_score))
        # Sort by score (descending) and take top results
        scored_memories.sort(reverse=True, key=lambda x: x[0])
        return [memory for _, memory in scored_memories[:max_results]]

    def get_memory_summary(self) -> str:
        """
        Get a human-readable summary of the current memory state.

        The summary lists the 5 most recent short-term items, the 5 most
        important long-term items (content truncated to 100 characters), and
        every working-memory entry (string values over 50 characters are
        truncated).
        """
        # Get most recent short-term memories
        recent_short_term = self.short_term_memory[-5:]
        short_term_summary = "\n".join([f"- [{m.get('type', 'general')}] {m.get('content', '')[:100]}..."
                                        for m in recent_short_term])
        # Get most important long-term memories
        important_long_term = sorted(self.long_term_memory,
                                     key=lambda x: x.get("importance", 0),
                                     reverse=True)[:5]
        long_term_summary = "\n".join([f"- [{m.get('type', 'general')}] {m.get('content', '')[:100]}..."
                                       for m in important_long_term])
        # Summarize working memory
        working_memory_summary = "\n".join([f"- {k}: {str(v)[:50]}..." if isinstance(v, str) and len(str(v)) > 50
                                            else f"- {k}: {v}" for k, v in self.working_memory.items()])
        short_term_text = short_term_summary or "No recent short-term memories."
        long_term_text = long_term_summary or "No important long-term memories."
        working_text = working_memory_summary or "Working memory is empty."
        return f"""
MEMORY SUMMARY:
--------------
Recent Short-Term Memory:
{short_term_text}
Important Long-Term Memory:
{long_term_text}
Working Memory:
{working_text}
"""

    def save_memories(self) -> None:
        """
        Write short- and long-term memories (not working memory) to
        ``memory_file`` as JSON. Failures are printed as a warning, not raised.
        """
        try:
            memories = {
                "short_term": self.short_term_memory,
                "long_term": self.long_term_memory,
                "last_updated": datetime.now().isoformat()
            }
            with open(self.memory_file, 'w', encoding='utf-8') as f:
                json.dump(memories, f, indent=2)
        except Exception as e:
            print(f"Warning: Could not save memories: {str(e)}")

    def load_memories(self) -> None:
        """
        Load memories from ``memory_file`` if it exists.

        When semantic search is enabled, all embeddings are rebuilt. Failures
        are printed as a warning, not raised.
        """
        try:
            if os.path.exists(self.memory_file):
                with open(self.memory_file, 'r', encoding='utf-8') as f:
                    memories = json.load(f)
                self.short_term_memory = memories.get("short_term", [])
                self.long_term_memory = memories.get("long_term", [])
                if self.use_semantic_search:
                    self._rebuild_embeddings()
                print(f"Loaded {len(self.short_term_memory)} short-term and {len(self.long_term_memory)} long-term memories.")
        except Exception as e:
            print(f"Warning: Could not load memories: {str(e)}")

    def forget_old_memories(self, days_threshold: int = 30) -> None:
        """
        Remove memories older than the specified threshold.

        For long-term items the threshold is scaled by ``(1 + importance)``, so
        more important items are kept longer. Items whose timestamp (or
        importance) cannot be used are kept. A missing timestamp is treated as
        2000-01-01. Memories are persisted afterwards; failures are printed as
        a warning.

        Args:
            days_threshold: Age threshold in days
        """
        try:
            now = datetime.now()
            threshold = days_threshold * 24 * 60 * 60  # Convert to seconds
            # Filter short-term memories
            new_short_term = []
            for memory in self.short_term_memory:
                try:
                    timestamp = datetime.fromisoformat(memory.get("timestamp", "2000-01-01T00:00:00"))
                    keep = (now - timestamp).total_seconds() < threshold
                except (TypeError, ValueError):
                    keep = True  # Keep memories with invalid timestamps
                if keep:
                    new_short_term.append(memory)
            # Filter long-term memories
            new_long_term = []
            for memory in self.long_term_memory:
                try:
                    timestamp = datetime.fromisoformat(memory.get("timestamp", "2000-01-01T00:00:00"))
                    age = (now - timestamp).total_seconds()
                    # More important memories have a higher threshold
                    adjusted_threshold = threshold * (1 + memory.get("importance", 0.5))
                    keep = age < adjusted_threshold
                except (TypeError, ValueError):
                    keep = True  # Keep memories with invalid timestamps/importance
                if keep:
                    new_long_term.append(memory)
            # Update memories
            removed_short_term = len(self.short_term_memory) - len(new_short_term)
            removed_long_term = len(self.long_term_memory) - len(new_long_term)
            self.short_term_memory = new_short_term
            self.long_term_memory = new_long_term
            # Positions changed, so embeddings must be rebuilt
            if self.use_semantic_search:
                self._rebuild_embeddings()
            self.save_memories()
            print(f"Forgot {removed_short_term} short-term and {removed_long_term} long-term memories older than {days_threshold} days.")
        except Exception as e:
            print(f"Warning: Could not forget old memories: {str(e)}")
# Example usage
if __name__ == "__main__":
    # Demo: seed the manager with one short-term and one long-term item, put
    # a task in working memory, then show retrieval and the memory summary.
    manager = EnhancedMemoryManager(use_semantic_search=True)

    manager.add_to_short_term({
        "type": "query",
        "content": "What is the capital of France?",
        "timestamp": datetime.now().isoformat(),
    })
    manager.add_to_long_term({
        "type": "key_fact",
        "content": "Paris is the capital of France with a population of about 2.2 million people.",
        "timestamp": datetime.now().isoformat(),
    })
    manager.store_in_working_memory("current_task", "Finding information about France")

    demo_query = "What is the population of Paris?"
    hits = manager.get_relevant_memories(demo_query)
    print(f"\nRelevant memories for '{demo_query}':")
    for hit in hits:
        score = hit.get('relevance_score', 0)
        print(f"- Score: {score:.2f}, Content: {hit.get('content', '')}")

    print("\nMemory Summary:")
    print(manager.get_memory_summary())