File size: 20,643 Bytes
162ee47 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 |
"""
Enhanced Memory System for GAIA-Ready AI Agent
This module provides an advanced memory system for the AI agent,
including short-term, long-term, and working memory components,
as well as semantic retrieval capabilities.
"""
import os
import json
from typing import List, Dict, Any, Optional, Union
from datetime import datetime
import re
import numpy as np
from collections import defaultdict
# Import the embedding backend, installing it on first use when it is missing.
try:
    from sentence_transformers import SentenceTransformer
except ImportError:
    import subprocess
    import sys

    # Run pip through the interpreter that is executing this module so the
    # package lands in the same environment the retry import below reads
    # from; a bare "pip" on PATH may belong to a different Python install.
    subprocess.check_call(
        [sys.executable, "-m", "pip", "install", "sentence-transformers"]
    )
    from sentence_transformers import SentenceTransformer
class EnhancedMemoryManager:
    """
    Advanced memory manager for the agent that maintains short-term, long-term,
    and working memory with semantic retrieval capabilities.

    Attributes:
        short_term_memory: Conversation items, oldest first, capped at
            ``max_short_term_items``.
        long_term_memory: Key facts and results, re-sorted by descending
            "importance" on every add and capped at ``max_long_term_items``.
        working_memory: Scratch key/value store; never written to disk.
        memory_embeddings: ``(embedding, index, memory_type)`` tuples, where
            ``index`` is the item's position in the list named by
            ``memory_type`` ("short_term" or "long_term").
        memory_file: Path of the JSON file used for persistence.
    """

    # Weight applied to content length when an item has no "importance" yet.
    _TYPE_IMPORTANCE = {
        "final_answer": 0.9,
        "key_fact": 0.8,
        "reasoning": 0.7,
        "general": 0.5,
    }
    # Multiplier applied to keyword-match counts, per memory type.
    _TYPE_BOOST = {
        "final_answer": 2.0,
        "key_fact": 1.5,
        "reasoning": 1.2,
        "general": 1.0,
    }
    # Timestamp assumed for items that carry none.
    _FALLBACK_TIMESTAMP = "2000-01-01T00:00:00"
    _WORD_PATTERN = re.compile(r'\b\w+\b')

    def __init__(self, use_semantic_search=True, memory_file="agent_memory.json"):
        """
        Create the manager and load any memories persisted in ``memory_file``.

        Args:
            use_semantic_search: Try to load the sentence-transformers model;
                when that fails a warning is printed and keyword search is
                used instead.
            memory_file: Path of the JSON persistence file.
        """
        self.short_term_memory = []  # Current conversation context
        self.long_term_memory = []   # Key facts and results
        self.working_memory = {}     # Temporary storage for complex tasks
        self.max_short_term_items = 15
        self.max_long_term_items = 100
        self.use_semantic_search = use_semantic_search
        # Always defined so later code never hits a missing attribute.
        self.embedding_model = None
        self.memory_embeddings = []
        # Initialize semantic search if enabled
        if self.use_semantic_search:
            try:
                self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
            except Exception as e:
                print(f"Warning: Could not initialize semantic search: {str(e)}")
                self.use_semantic_search = False
        # Memory persistence
        self.memory_file = memory_file
        self.load_memories()

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _prepare_item(item: Dict[str, Any]) -> None:
        """Validate ``item`` and fill in a default timestamp and type in place."""
        if "content" not in item:
            raise ValueError("Memory item must have 'content' field")
        if "timestamp" not in item:
            item["timestamp"] = datetime.now().isoformat()
        if "type" not in item:
            item["type"] = "general"

    @classmethod
    def _age_seconds(cls, memory: Dict[str, Any]) -> float:
        """
        Return the age of ``memory`` in seconds.

        "Now" is taken in the timestamp's own timezone (naive local time when
        the timestamp is naive), so aware and naive timestamps both work.
        Raises TypeError/ValueError when the timestamp cannot be parsed.
        """
        timestamp = datetime.fromisoformat(
            memory.get("timestamp", cls._FALLBACK_TIMESTAMP)
        )
        return (datetime.now(timestamp.tzinfo) - timestamp).total_seconds()

    def _encode(self, memory: Dict[str, Any]):
        """Return the embedding of ``memory``'s content, or None on failure."""
        try:
            return self.embedding_model.encode(memory.get("content", ""))
        except Exception as e:
            print(f"Warning: Could not create embedding for memory item: {str(e)}")
            return None

    def _snapshot_embeddings(self) -> Dict[int, Any]:
        """
        Map ``id(memory item)`` to its current embedding.

        Must be called before the memory lists are reordered or trimmed,
        while the stored indices still point at the right items.
        """
        stores = {
            "short_term": self.short_term_memory,
            "long_term": self.long_term_memory,
        }
        known = {}
        for embedding, idx, mem_type in self.memory_embeddings:
            store = stores.get(mem_type)
            if store is not None and 0 <= idx < len(store):
                known[id(store[idx])] = embedding
        return known

    def _rebuild_embeddings(self, known: Optional[Dict[int, Any]] = None) -> None:
        """
        Recreate ``memory_embeddings`` so indices match the current lists.

        Args:
            known: Mapping from ``_snapshot_embeddings``; vectors are reused
                and items without one stay unindexed. When None, every item
                is encoded from scratch.
        """
        rebuilt = []
        stores = (
            ("short_term", self.short_term_memory),
            ("long_term", self.long_term_memory),
        )
        for mem_type, store in stores:
            for idx, memory in enumerate(store):
                if known is None:
                    embedding = self._encode(memory)
                else:
                    embedding = known.get(id(memory))
                if embedding is not None:
                    rebuilt.append((embedding, idx, mem_type))
        self.memory_embeddings = rebuilt

    def _index_new_item(self, item: Dict[str, Any], known: Dict[int, Any]) -> None:
        """Encode ``item`` once and re-index every memory against ``known``."""
        embedding = self._encode(item)
        if embedding is not None:
            known[id(item)] = embedding
        self._rebuild_embeddings(known)

    # ------------------------------------------------------------------
    # Adding memories
    # ------------------------------------------------------------------
    def add_to_short_term(self, item: Dict[str, Any]) -> None:
        """
        Add an item to short-term memory, maintaining size limit.

        Missing "timestamp" and "type" fields are filled in on ``item``
        itself. The oldest items are dropped once the limit is exceeded.

        Raises:
            ValueError: If ``item`` has no "content" field.
        """
        self._prepare_item(item)
        known = self._snapshot_embeddings() if self.use_semantic_search else None
        self.short_term_memory.append(item)
        # Maintain size limit (drops from the front, i.e. oldest first)
        overflow = len(self.short_term_memory) - self.max_short_term_items
        if overflow > 0:
            del self.short_term_memory[:overflow]
        if self.use_semantic_search:
            self._index_new_item(item, known)
        # Persist after every change
        self.save_memories()

    def add_to_long_term(self, item: Dict[str, Any]) -> None:
        """
        Add an important item to long-term memory, maintaining size limit.

        Missing "timestamp", "type" and "importance" fields are filled in on
        ``item`` itself. The list is kept sorted by descending importance and
        the least important items are dropped once the limit is exceeded.

        Raises:
            ValueError: If ``item`` has no "content" field.
        """
        self._prepare_item(item)
        # Add importance score if not present
        if "importance" not in item:
            # Derived from content length, scaled by the type's weight
            content_length = len(item.get("content", ""))
            weight = self._TYPE_IMPORTANCE.get(item["type"], 0.5)
            item["importance"] = min(1.0, (content_length / 1000) * weight)
        # Snapshot vectors by item identity BEFORE sorting: the sort moves
        # items, so position-based indices recorded earlier would go stale.
        known = self._snapshot_embeddings() if self.use_semantic_search else None
        self.long_term_memory.append(item)
        # Sort long-term memory by importance (descending)
        self.long_term_memory.sort(key=lambda x: x.get("importance", 0), reverse=True)
        # Maintain size limit by dropping the least important tail
        overflow = len(self.long_term_memory) - self.max_long_term_items
        if overflow > 0:
            del self.long_term_memory[-overflow:]
        if self.use_semantic_search:
            self._index_new_item(item, known)
        # Persist after every change
        self.save_memories()

    # ------------------------------------------------------------------
    # Working memory
    # ------------------------------------------------------------------
    def store_in_working_memory(self, key: str, value: Any) -> None:
        """Store a value in working memory under the specified key"""
        # Working memory is not persisted between sessions
        self.working_memory[key] = value

    def get_from_working_memory(self, key: str) -> Optional[Any]:
        """Retrieve a value from working memory by key (None when absent)"""
        return self.working_memory.get(key)

    def clear_working_memory(self) -> None:
        """Clear the working memory"""
        self.working_memory = {}

    # ------------------------------------------------------------------
    # Retrieval
    # ------------------------------------------------------------------
    def get_relevant_memories(self, query: str, max_results: int = 10) -> List[Dict[str, Any]]:
        """
        Retrieve memories relevant to the current query

        Uses cosine similarity over the stored embeddings when semantic
        search is enabled and falls back to keyword search if it fails.

        Args:
            query: The query to find relevant memories for
            max_results: Maximum number of results to return
        Returns:
            Copies of the matching memory items, best first, each with an
            added "relevance_score" field
        """
        if not self.use_semantic_search:
            return self._keyword_search(query, max_results)
        try:
            query_embedding = self.embedding_model.encode(query)
            query_norm = np.linalg.norm(query_embedding)
            similarities = []
            for embedding, idx, mem_type in self.memory_embeddings:
                denominator = query_norm * np.linalg.norm(embedding)
                # A zero-length vector has no direction; score it 0 instead
                # of dividing by zero and producing NaN.
                if denominator:
                    similarity = float(np.dot(query_embedding, embedding) / denominator)
                else:
                    similarity = 0.0
                similarities.append((similarity, idx, mem_type))
            # Sort by similarity (descending)
            similarities.sort(key=lambda entry: entry[0], reverse=True)
            relevant_memories = []
            for similarity, idx, mem_type in similarities[:max(0, max_results)]:
                if mem_type == "short_term":
                    memory = self.short_term_memory[idx]
                else:  # long_term
                    memory = self.long_term_memory[idx]
                # Annotate a copy so the stored item stays unchanged
                memory_with_score = memory.copy()
                memory_with_score["relevance_score"] = similarity
                relevant_memories.append(memory_with_score)
            return relevant_memories
        except Exception as e:
            print(f"Warning: Semantic search failed: {str(e)}. Falling back to keyword search.")
            return self._keyword_search(query, max_results)

    def _keyword_score(self, memory: Dict[str, Any], query_keywords: set) -> float:
        """Score ``memory`` by keyword overlap, memory type and recency."""
        content_words = set(self._WORD_PATTERN.findall(memory.get("content", "").lower()))
        matches = len(query_keywords.intersection(content_words))
        # Linear decay over 24 hours, floored at 0.5
        try:
            hours_ago = self._age_seconds(memory) / 3600
            recency_factor = max(0.5, 1.0 - (hours_ago / 24))
        except (TypeError, ValueError, OverflowError):
            # Unparseable timestamp: treat the memory as old
            recency_factor = 0.5
        type_boost = self._TYPE_BOOST.get(memory.get("type", "general"), 1.0)
        return matches * type_boost * recency_factor

    def _keyword_search(self, query: str, max_results: int = 10) -> List[Dict[str, Any]]:
        """
        Fallback keyword-based search for relevant memories

        Args:
            query: The query to find relevant memories for
            max_results: Maximum number of results to return
        Returns:
            Copies of memory items sharing at least one word with the query,
            best first, each with an added "relevance_score" field
        """
        query_keywords = set(self._WORD_PATTERN.findall(query.lower()))
        scored_memories = []
        # Long-term first so it wins ties under the stable sort below
        for store in (self.long_term_memory, self.short_term_memory):
            for memory in store:
                score = self._keyword_score(memory, query_keywords)
                if score > 0:
                    memory_with_score = memory.copy()
                    memory_with_score["relevance_score"] = score
                    scored_memories.append((score, memory_with_score))
        # Sort by score (descending) and take top results
        scored_memories.sort(reverse=True, key=lambda x: x[0])
        return [memory for _, memory in scored_memories[:max(0, max_results)]]

    def get_memory_summary(self) -> str:
        """
        Get a summary of the current memory state for the agent

        Lists the five most recent short-term items, the five most important
        long-term items (content cut to 100 characters) and every working
        memory entry (string values cut to 50 characters).
        """
        def describe(memory):
            return f"- [{memory.get('type', 'general')}] {memory.get('content', '')[:100]}..."

        # Get most recent short-term memories
        short_term_summary = "\n".join(describe(m) for m in self.short_term_memory[-5:])
        # Get most important long-term memories
        important_long_term = sorted(
            self.long_term_memory,
            key=lambda x: x.get("importance", 0),
            reverse=True,
        )[:5]
        long_term_summary = "\n".join(describe(m) for m in important_long_term)
        # Summarize working memory
        working_lines = []
        for k, v in self.working_memory.items():
            if isinstance(v, str) and len(str(v)) > 50:
                working_lines.append(f"- {k}: {str(v)[:50]}...")
            else:
                working_lines.append(f"- {k}: {v}")
        working_memory_summary = "\n".join(working_lines)
        return "\n".join([
            "",
            "MEMORY SUMMARY:",
            "--------------",
            "Recent Short-Term Memory:",
            short_term_summary if short_term_summary else "No recent short-term memories.",
            "Important Long-Term Memory:",
            long_term_summary if long_term_summary else "No important long-term memories.",
            "Working Memory:",
            working_memory_summary if working_memory_summary else "Working memory is empty.",
            "",
        ])

    # ------------------------------------------------------------------
    # Persistence
    # ------------------------------------------------------------------
    def save_memories(self) -> None:
        """
        Save memories to disk for persistence

        Only short-term and long-term memories are written (not working
        memory). Failures are reported as a printed warning and leave the
        previously saved file untouched.
        """
        try:
            # Serialize fully BEFORE touching the file: a non-serializable
            # item must not leave a half-written, unreadable memory file.
            payload = json.dumps(
                {
                    "short_term": self.short_term_memory,
                    "long_term": self.long_term_memory,
                    "last_updated": datetime.now().isoformat(),
                },
                indent=2,
            )
            # Write beside the target, then swap it in with one rename so a
            # crash mid-write cannot truncate the existing file.
            temp_path = f"{self.memory_file}.tmp"
            with open(temp_path, 'w', encoding="utf-8") as f:
                f.write(payload)
            os.replace(temp_path, self.memory_file)
        except Exception as e:
            print(f"Warning: Could not save memories: {str(e)}")

    def load_memories(self) -> None:
        """
        Load memories from disk if available

        Does nothing when ``memory_file`` does not exist. Failures are
        reported as a printed warning.
        """
        try:
            if not os.path.exists(self.memory_file):
                return
            with open(self.memory_file, 'r', encoding="utf-8") as f:
                memories = json.load(f)
            short_term = memories.get("short_term", [])
            long_term = memories.get("long_term", [])
            # Anything other than a list (e.g. JSON null) is treated as empty
            self.short_term_memory = short_term if isinstance(short_term, list) else []
            self.long_term_memory = long_term if isinstance(long_term, list) else []
            # Embeddings are not persisted, so encode everything again
            if self.use_semantic_search:
                self._rebuild_embeddings()
            print(f"Loaded {len(self.short_term_memory)} short-term and {len(self.long_term_memory)} long-term memories.")
        except Exception as e:
            print(f"Warning: Could not load memories: {str(e)}")

    def _is_expired(self, memory: Dict[str, Any], threshold: float, weigh_importance: bool) -> bool:
        """
        Tell whether ``memory`` is at least ``threshold`` seconds old.

        With ``weigh_importance`` the threshold is stretched by
        ``1 + importance`` so important memories live longer. Items whose
        age cannot be computed are never treated as expired.
        """
        try:
            limit = threshold
            if weigh_importance:
                limit = threshold * (1 + memory.get("importance", 0.5))
            return self._age_seconds(memory) >= limit
        except (AttributeError, TypeError, ValueError, OverflowError):
            # Keep memories with invalid timestamps
            return False

    def forget_old_memories(self, days_threshold: int = 30) -> None:
        """
        Remove memories older than the specified threshold

        Long-term memories get a threshold stretched by ``1 + importance``.
        Failures are reported as a printed warning.

        Args:
            days_threshold: Age threshold in days
        """
        try:
            threshold = days_threshold * 24 * 60 * 60  # Convert to seconds
            # Snapshot vectors before filtering shifts the list positions
            known = self._snapshot_embeddings() if self.use_semantic_search else None
            new_short_term = [
                memory for memory in self.short_term_memory
                if not self._is_expired(memory, threshold, weigh_importance=False)
            ]
            new_long_term = [
                memory for memory in self.long_term_memory
                if not self._is_expired(memory, threshold, weigh_importance=True)
            ]
            removed_short_term = len(self.short_term_memory) - len(new_short_term)
            removed_long_term = len(self.long_term_memory) - len(new_long_term)
            self.short_term_memory = new_short_term
            self.long_term_memory = new_long_term
            # Re-index the survivors, reusing their existing vectors
            if self.use_semantic_search:
                self._rebuild_embeddings(known)
            # Save updated memories
            self.save_memories()
            print(f"Forgot {removed_short_term} short-term and {removed_long_term} long-term memories older than {days_threshold} days.")
        except Exception as e:
            print(f"Warning: Could not forget old memories: {str(e)}")
# Demo: exercise each memory tier once and show what retrieval returns
if __name__ == "__main__":
    # Manager with semantic retrieval requested
    manager = EnhancedMemoryManager(use_semantic_search=True)

    # One conversational item for short-term memory
    question_item = {
        "type": "query",
        "content": "What is the capital of France?",
        "timestamp": datetime.now().isoformat(),
    }
    manager.add_to_short_term(question_item)

    # One fact for long-term memory
    fact_item = {
        "type": "key_fact",
        "content": "Paris is the capital of France with a population of about 2.2 million people.",
        "timestamp": datetime.now().isoformat(),
    }
    manager.add_to_long_term(fact_item)

    # Scratch note in working memory
    manager.store_in_working_memory("current_task", "Finding information about France")

    # Retrieve against a follow-up question and list the scored hits
    hits = manager.get_relevant_memories("What is the population of Paris?")
    print("\nRelevant memories for 'What is the population of Paris?':")
    for hit in hits:
        print(f"- Score: {hit.get('relevance_score', 0):.2f}, Content: {hit.get('content', '')}")

    # Finish with the formatted overview of all three tiers
    print("\nMemory Summary:")
    summary_text = manager.get_memory_summary()
    print(summary_text)
|