"""
Enhanced Memory System for GAIA-Ready AI Agent

This module provides an advanced memory system for the AI agent,
including short-term, long-term, and working memory components,
as well as semantic retrieval capabilities.
"""

import os
import json
import re
import sys
from typing import List, Dict, Any, Optional
from datetime import datetime

import numpy as np

try:
    from sentence_transformers import SentenceTransformer
except ImportError:
    # Auto-install as a convenience for interactive use; pinning
    # sentence-transformers in the project requirements is preferable.
    import subprocess
    subprocess.check_call([sys.executable, "-m", "pip", "install", "sentence-transformers"])
    from sentence_transformers import SentenceTransformer


class EnhancedMemoryManager:
    """
    Advanced memory manager for the agent that maintains short-term, long-term,
    and working memory with semantic retrieval capabilities.
    """

    def __init__(self, use_semantic_search=True):
        self.short_term_memory = []
        self.long_term_memory = []
        self.working_memory = {}
        self.max_short_term_items = 15
        self.max_long_term_items = 100
        self.use_semantic_search = use_semantic_search

        if self.use_semantic_search:
            try:
                self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
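                # Each entry is (embedding vector, index into its memory list,
                # "short_term" or "long_term"); the indices must stay in sync
                # whenever either list is reordered or truncated.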
                self.memory_embeddings = []
            except Exception as e:
                print(f"Warning: Could not initialize semantic search: {e}")
                self.use_semantic_search = False

        self.memory_file = "agent_memory.json"
        self.load_memories()

    def add_to_short_term(self, item: Dict[str, Any]) -> None:
        """Add an item to short-term memory, maintaining the size limit."""
        if "content" not in item:
            raise ValueError("Memory item must have a 'content' field")
        if "timestamp" not in item:
            item["timestamp"] = datetime.now().isoformat()
        if "type" not in item:
            item["type"] = "general"

        self.short_term_memory.append(item)

        if self.use_semantic_search:
            try:
                embedding = self.embedding_model.encode(item.get("content", ""))
                self.memory_embeddings.append((embedding, len(self.short_term_memory) - 1, "short_term"))
            except Exception as e:
                print(f"Warning: Could not create embedding for memory item: {e}")
        if len(self.short_term_memory) > self.max_short_term_items:
            self.short_term_memory.pop(0)
            if self.use_semantic_search:
                self.memory_embeddings = [
                    (emb, idx - 1 if mem_type == "short_term" else idx, mem_type)
                    for emb, idx, mem_type in self.memory_embeddings
                    if not (mem_type == "short_term" and idx == 0)
                ]

        self.save_memories()

    def add_to_long_term(self, item: Dict[str, Any]) -> None:
        """Add an important item to long-term memory, maintaining the size limit."""
        if "content" not in item:
            raise ValueError("Memory item must have a 'content' field")
        if "timestamp" not in item:
            item["timestamp"] = datetime.now().isoformat()
        if "type" not in item:
            item["type"] = "general"
        if "importance" not in item:
            content_length = len(item.get("content", ""))
            type_importance = {
                "final_answer": 0.9,
                "key_fact": 0.8,
                "reasoning": 0.7,
                "general": 0.5,
            }
            item["importance"] = min(1.0, (content_length / 1000) * type_importance.get(item["type"], 0.5))

        self.long_term_memory.append(item)

        # Keep long-term memory sorted by importance, evicting the least
        # important item when over capacity.
        self.long_term_memory.sort(key=lambda x: x.get("importance", 0), reverse=True)
        if len(self.long_term_memory) > self.max_long_term_items:
            self.long_term_memory.pop()

        # Sorting (and any eviction) reorders the list, which invalidates the
        # long-term indices stored with the embeddings, so rebuild all
        # long-term embeddings rather than appending a soon-stale entry.
        if self.use_semantic_search:
            long_term_embeddings = []
            for i, memory in enumerate(self.long_term_memory):
                try:
                    embedding = self.embedding_model.encode(memory.get("content", ""))
                    long_term_embeddings.append((embedding, i, "long_term"))
                except Exception as e:
                    print(f"Warning: Could not create embedding for memory item: {e}")
            self.memory_embeddings = [
                (emb, idx, mem_type)
                for emb, idx, mem_type in self.memory_embeddings
                if mem_type == "short_term"
            ] + long_term_embeddings

        self.save_memories()

    def store_in_working_memory(self, key: str, value: Any) -> None:
        """Store a value in working memory under the specified key."""
        self.working_memory[key] = value

    def get_from_working_memory(self, key: str) -> Optional[Any]:
        """Retrieve a value from working memory by key."""
        return self.working_memory.get(key)

    def clear_working_memory(self) -> None:
        """Clear the working memory."""
        self.working_memory = {}

    def get_relevant_memories(self, query: str, max_results: int = 10) -> List[Dict[str, Any]]:
        """
        Retrieve memories relevant to the current query.

        Args:
            query: The query to find relevant memories for
            max_results: Maximum number of results to return

        Returns:
            List of relevant memory items
        """
        if not self.use_semantic_search:
            return self._keyword_search(query, max_results)

        try:
            query_embedding = self.embedding_model.encode(query)
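
            # Rank every stored embedding by cosine similarity to the query,
            # guarding against zero-length vectors.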
            similarities = []
            for embedding, idx, mem_type in self.memory_embeddings:
                denom = np.linalg.norm(query_embedding) * np.linalg.norm(embedding)
                similarity = np.dot(query_embedding, embedding) / denom if denom else 0.0
                similarities.append((similarity, idx, mem_type))
            similarities.sort(key=lambda x: x[0], reverse=True)

            relevant_memories = []
            for similarity, idx, mem_type in similarities[:max_results]:
                if mem_type == "short_term":
                    memory = self.short_term_memory[idx]
                else:
                    memory = self.long_term_memory[idx]

                # Copy so the stored memory itself is not mutated.
                memory_with_score = memory.copy()
                memory_with_score["relevance_score"] = float(similarity)
                relevant_memories.append(memory_with_score)

            return relevant_memories
        except Exception as e:
            print(f"Warning: Semantic search failed: {e}. Falling back to keyword search.")
            return self._keyword_search(query, max_results)

    def _keyword_search(self, query: str, max_results: int = 10) -> List[Dict[str, Any]]:
        """
        Fallback keyword-based search for relevant memories.

        Args:
            query: The query to find relevant memories for
            max_results: Maximum number of results to return

        Returns:
            List of relevant memory items
        """
        query_keywords = set(re.findall(r'\b\w+\b', query.lower()))
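
        # Score each memory by keyword overlap, boosted by memory type
        # and by how recently it was stored.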
        def score_memory(memory):
            content = memory.get("content", "").lower()
            content_words = set(re.findall(r'\b\w+\b', content))
            matches = len(query_keywords.intersection(content_words))

            type_boost = {
                "final_answer": 2.0,
                "key_fact": 1.5,
                "reasoning": 1.2,
                "general": 1.0,
            }

            # Memories from the last 24 hours get a recency boost that
            # decays to a floor of 0.5.
            try:
                timestamp = datetime.fromisoformat(memory.get("timestamp", "2000-01-01T00:00:00"))
                hours_ago = (datetime.now() - timestamp).total_seconds() / 3600
                recency_factor = max(0.5, 1.0 - (hours_ago / 24))
            except (ValueError, TypeError):
                recency_factor = 0.5

            return matches * type_boost.get(memory.get("type", "general"), 1.0) * recency_factor

        scored_memories = []
        for memory in self.long_term_memory + self.short_term_memory:
            score = score_memory(memory)
            if score > 0:
                memory_with_score = memory.copy()
                memory_with_score["relevance_score"] = score
                scored_memories.append((score, memory_with_score))

        scored_memories.sort(reverse=True, key=lambda x: x[0])
        return [memory for _, memory in scored_memories[:max_results]]

    def get_memory_summary(self) -> str:
        """Get a summary of the current memory state for the agent."""
        recent_short_term = self.short_term_memory[-5:]
        short_term_summary = "\n".join(
            f"- [{m.get('type', 'general')}] {m.get('content', '')[:100]}..."
            for m in recent_short_term
        )

        important_long_term = sorted(
            self.long_term_memory, key=lambda x: x.get("importance", 0), reverse=True
        )[:5]
        long_term_summary = "\n".join(
            f"- [{m.get('type', 'general')}] {m.get('content', '')[:100]}..."
            for m in important_long_term
        )

        working_memory_summary = "\n".join(
            f"- {k}: {v[:50]}..." if isinstance(v, str) and len(v) > 50 else f"- {k}: {v}"
            for k, v in self.working_memory.items()
        )

        return f"""
MEMORY SUMMARY:
--------------
Recent Short-Term Memory:
{short_term_summary if short_term_summary else "No recent short-term memories."}

Important Long-Term Memory:
{long_term_summary if long_term_summary else "No important long-term memories."}

Working Memory:
{working_memory_summary if working_memory_summary else "Working memory is empty."}
"""

    def save_memories(self) -> None:
        """Save memories to disk for persistence."""
        try:
            memories = {
                "short_term": self.short_term_memory,
                "long_term": self.long_term_memory,
                "last_updated": datetime.now().isoformat(),
            }
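
            # Embeddings are not persisted; they are recomputed when
            # memories are loaded.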
            with open(self.memory_file, 'w') as f:
                json.dump(memories, f, indent=2)
        except Exception as e:
            print(f"Warning: Could not save memories: {e}")

    def _rebuild_embeddings(self) -> None:
        """Recompute every stored embedding from the current memory lists."""
        self.memory_embeddings = []
        for mem_type, memories in (("short_term", self.short_term_memory),
                                   ("long_term", self.long_term_memory)):
            for i, memory in enumerate(memories):
                try:
                    embedding = self.embedding_model.encode(memory.get("content", ""))
                    self.memory_embeddings.append((embedding, i, mem_type))
                except Exception as e:
                    print(f"Warning: Could not create embedding for memory item: {e}")

    def load_memories(self) -> None:
        """Load memories from disk if available."""
        try:
            if os.path.exists(self.memory_file):
                with open(self.memory_file, 'r') as f:
                    memories = json.load(f)

                self.short_term_memory = memories.get("short_term", [])
                self.long_term_memory = memories.get("long_term", [])

                if self.use_semantic_search:
                    self._rebuild_embeddings()

                print(f"Loaded {len(self.short_term_memory)} short-term and "
                      f"{len(self.long_term_memory)} long-term memories.")
        except Exception as e:
            print(f"Warning: Could not load memories: {e}")

    def forget_old_memories(self, days_threshold: int = 30) -> None:
        """
        Remove memories older than the specified threshold.

        Args:
            days_threshold: Age threshold in days
        """
        try:
            now = datetime.now()
            threshold = days_threshold * 24 * 60 * 60  # base threshold in seconds

            new_short_term = []
            for memory in self.short_term_memory:
                try:
                    timestamp = datetime.fromisoformat(memory.get("timestamp", "2000-01-01T00:00:00"))
                    if (now - timestamp).total_seconds() < threshold:
                        new_short_term.append(memory)
                except (ValueError, TypeError):
                    # Keep items whose timestamps cannot be parsed.
                    new_short_term.append(memory)
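
            # Long-term memories earn a retention window that grows with
            # their importance (up to 2x the base threshold).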
            new_long_term = []
            for memory in self.long_term_memory:
                try:
                    timestamp = datetime.fromisoformat(memory.get("timestamp", "2000-01-01T00:00:00"))
                    age = (now - timestamp).total_seconds()
                    importance = memory.get("importance", 0.5)
                    adjusted_threshold = threshold * (1 + importance)
                    if age < adjusted_threshold:
                        new_long_term.append(memory)
                except (ValueError, TypeError):
                    new_long_term.append(memory)

            removed_short_term = len(self.short_term_memory) - len(new_short_term)
            removed_long_term = len(self.long_term_memory) - len(new_long_term)
            self.short_term_memory = new_short_term
            self.long_term_memory = new_long_term

            # Indices shift after removals, so recompute all embeddings.
            if self.use_semantic_search:
                self._rebuild_embeddings()

            self.save_memories()
            print(f"Forgot {removed_short_term} short-term and {removed_long_term} "
                  f"long-term memories older than {days_threshold} days.")
        except Exception as e:
            print(f"Warning: Could not forget old memories: {e}")


if __name__ == "__main__":
    memory_manager = EnhancedMemoryManager(use_semantic_search=True)

    # Exercise each memory tier.
    memory_manager.add_to_short_term({
        "type": "query",
        "content": "What is the capital of France?",
        "timestamp": datetime.now().isoformat()
    })

    memory_manager.add_to_long_term({
        "type": "key_fact",
        "content": "Paris is the capital of France with a population of about 2.2 million people.",
        "timestamp": datetime.now().isoformat()
    })

    memory_manager.store_in_working_memory("current_task", "Finding information about France")

    relevant_memories = memory_manager.get_relevant_memories("What is the population of Paris?")
    print("\nRelevant memories for 'What is the population of Paris?':")
    for memory in relevant_memories:
        print(f"- Score: {memory.get('relevance_score', 0):.2f}, Content: {memory.get('content', '')}")

    print("\nMemory Summary:")
    print(memory_manager.get_memory_summary())
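
    # Optional housekeeping: prune stale memories. Nothing should be removed
    # here, since every item above was just created.
    memory_manager.forget_old_memories(days_threshold=30)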