File size: 20,643 Bytes
162ee47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
"""
Enhanced Memory System for GAIA-Ready AI Agent

This module provides an advanced memory system for the AI agent,
including short-term, long-term, and working memory components,
as well as semantic retrieval capabilities.
"""

import os
import json
from typing import List, Dict, Any, Optional, Union
from datetime import datetime
import re
import numpy as np
from collections import defaultdict

try:
    from sentence_transformers import SentenceTransformer
except ImportError:
    # Best-effort auto-install at import time.
    # NOTE(review): installing packages as an import side effect is fragile;
    # consider declaring the dependency in project metadata instead.
    import subprocess
    import sys

    # Use the running interpreter's pip ("python -m pip") so the package is
    # installed into THIS environment — a bare "pip" on PATH may belong to a
    # different Python installation.
    subprocess.check_call([sys.executable, "-m", "pip", "install", "sentence-transformers"])
    from sentence_transformers import SentenceTransformer


class EnhancedMemoryManager:
    """
    Advanced memory manager for the agent that maintains short-term, long-term,
    and working memory with semantic retrieval capabilities.

    Memory layout:
      - short_term_memory: list of dicts for the current conversation context
        (FIFO, bounded by ``max_short_term_items``).
      - long_term_memory: list of dicts holding key facts/results, kept sorted
        by descending ``importance`` and bounded by ``max_long_term_items``.
      - working_memory: free-form key/value scratch space, never persisted.

    Each memory item is a dict with at least a ``content`` field; ``timestamp``
    (ISO format) and ``type`` are filled in with defaults when missing.

    When semantic search is enabled, ``memory_embeddings`` holds tuples of
    ``(embedding, index, "short_term" | "long_term")`` where ``index`` points
    into the corresponding memory list. All mutating operations keep these
    indices in sync with the lists.
    """

    def __init__(self, use_semantic_search: bool = True):
        self.short_term_memory: List[Dict[str, Any]] = []  # current conversation context
        self.long_term_memory: List[Dict[str, Any]] = []   # key facts and results
        self.working_memory: Dict[str, Any] = {}           # temporary storage for complex tasks
        self.max_short_term_items = 15
        self.max_long_term_items = 100
        self.use_semantic_search = use_semantic_search

        # Initialize semantic search if enabled; fall back to keyword search
        # when the model cannot be loaded (e.g. offline environment).
        if self.use_semantic_search:
            try:
                self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
                self.memory_embeddings = []
            except Exception as e:
                print(f"Warning: Could not initialize semantic search: {str(e)}")
                self.use_semantic_search = False

        # Memory persistence (short/long-term only; working memory is transient).
        self.memory_file = "agent_memory.json"
        self.load_memories()

    # ------------------------------------------------------------------ #
    # Internal helpers
    # ------------------------------------------------------------------ #
    def _normalize_item(self, item: Dict[str, Any]) -> None:
        """Validate a memory item and fill in default timestamp/type in place.

        Raises:
            ValueError: if the item has no 'content' field.
        """
        if "content" not in item:
            raise ValueError("Memory item must have 'content' field")
        if "timestamp" not in item:
            item["timestamp"] = datetime.now().isoformat()
        if "type" not in item:
            item["type"] = "general"

    def _append_embedding(self, content: str, idx: int, mem_type: str) -> None:
        """Encode *content* and register (embedding, idx, mem_type); no-op when
        semantic search is disabled. Encoding failures are logged, not raised."""
        if not self.use_semantic_search:
            return
        try:
            embedding = self.embedding_model.encode(content)
            self.memory_embeddings.append((embedding, idx, mem_type))
        except Exception as e:
            print(f"Warning: Could not create embedding for memory item: {str(e)}")

    def _rebuild_embeddings(self) -> None:
        """Recompute every embedding from the current memory lists."""
        if not self.use_semantic_search:
            return
        self.memory_embeddings = []
        for i, memory in enumerate(self.short_term_memory):
            self._append_embedding(memory.get("content", ""), i, "short_term")
        for i, memory in enumerate(self.long_term_memory):
            self._append_embedding(memory.get("content", ""), i, "long_term")

    # ------------------------------------------------------------------ #
    # Memory mutation
    # ------------------------------------------------------------------ #
    def add_to_short_term(self, item: Dict[str, Any]) -> None:
        """Add an item to short-term memory, maintaining the size limit.

        Args:
            item: memory dict; must contain 'content'. 'timestamp' and 'type'
                are defaulted when absent.

        Raises:
            ValueError: if 'content' is missing.
        """
        self._normalize_item(item)
        self.short_term_memory.append(item)
        self._append_embedding(item.get("content", ""), len(self.short_term_memory) - 1, "short_term")

        # Evict the oldest entry (FIFO) once over the limit, and keep the
        # stored embedding indices in sync: drop the evicted index 0, then
        # shift every remaining short-term index down by one.
        if len(self.short_term_memory) > self.max_short_term_items:
            self.short_term_memory.pop(0)
            if self.use_semantic_search:
                self.memory_embeddings = [
                    (emb, idx - 1 if mem_type == "short_term" else idx, mem_type)
                    for emb, idx, mem_type in self.memory_embeddings
                    if not (mem_type == "short_term" and idx == 0)
                ]

        # Persist after every mutation (cheap JSON dump, best-effort).
        self.save_memories()

    def add_to_long_term(self, item: Dict[str, Any]) -> None:
        """Add an important item to long-term memory, maintaining the size limit.

        Computes a default 'importance' score (content length scaled by a
        per-type weight, capped at 1.0) when none is given, keeps the list
        sorted by descending importance, and evicts the least important entry
        when over capacity.

        Raises:
            ValueError: if 'content' is missing.
        """
        self._normalize_item(item)

        if "importance" not in item:
            # Importance grows with content length, weighted by memory type.
            # NOTE(review): short contents get near-zero importance regardless
            # of type — confirm this is the intended formula.
            content_length = len(item.get("content", ""))
            type_importance = {
                "final_answer": 0.9,
                "key_fact": 0.8,
                "reasoning": 0.7,
                "general": 0.5
            }
            item["importance"] = min(1.0, (content_length / 1000) * type_importance.get(item["type"], 0.5))

        self.long_term_memory.append(item)
        self._append_embedding(item.get("content", ""), len(self.long_term_memory) - 1, "long_term")

        # Sort by importance (descending) via an explicit permutation so the
        # stored embedding indices can be remapped to the new positions.
        # BUGFIX: the previous version sorted the list in place and left every
        # long-term embedding index stale, corrupting semantic retrieval.
        order = sorted(
            range(len(self.long_term_memory)),
            key=lambda i: self.long_term_memory[i].get("importance", 0),
            reverse=True,
        )
        self.long_term_memory = [self.long_term_memory[i] for i in order]
        if self.use_semantic_search:
            new_pos = {old: new for new, old in enumerate(order)}
            self.memory_embeddings = [
                (emb, new_pos.get(idx, idx) if mem_type == "long_term" else idx, mem_type)
                for emb, idx, mem_type in self.memory_embeddings
            ]

        # Evict the least important memory (last after the sort) and drop its
        # embedding; remaining indices are already correct.
        if len(self.long_term_memory) > self.max_long_term_items:
            self.long_term_memory.pop()
            if self.use_semantic_search:
                evicted_idx = len(self.long_term_memory)
                self.memory_embeddings = [
                    (emb, idx, mem_type)
                    for emb, idx, mem_type in self.memory_embeddings
                    if not (mem_type == "long_term" and idx == evicted_idx)
                ]

        self.save_memories()

    # ------------------------------------------------------------------ #
    # Working memory (transient, never persisted)
    # ------------------------------------------------------------------ #
    def store_in_working_memory(self, key: str, value: Any) -> None:
        """Store a value in working memory under the specified key."""
        self.working_memory[key] = value

    def get_from_working_memory(self, key: str) -> Optional[Any]:
        """Retrieve a value from working memory by key (None when absent)."""
        return self.working_memory.get(key)

    def clear_working_memory(self) -> None:
        """Clear the working memory."""
        self.working_memory = {}

    # ------------------------------------------------------------------ #
    # Retrieval
    # ------------------------------------------------------------------ #
    def get_relevant_memories(self, query: str, max_results: int = 10) -> List[Dict[str, Any]]:
        """
        Retrieve memories relevant to the current query.

        Uses cosine similarity over stored embeddings when semantic search is
        enabled, falling back to keyword matching on any failure.

        Args:
            query: The query to find relevant memories for
            max_results: Maximum number of results to return

        Returns:
            List of memory dicts, each augmented with a 'relevance_score'.
        """
        if not self.use_semantic_search:
            return self._keyword_search(query, max_results)

        try:
            query_embedding = self.embedding_model.encode(query)
            query_norm = np.linalg.norm(query_embedding)

            # Score every stored embedding; skip degenerate zero vectors
            # (division by zero) and any index that no longer points into
            # its memory list (defensive against stale state).
            similarities = []
            for embedding, idx, mem_type in self.memory_embeddings:
                source = self.short_term_memory if mem_type == "short_term" else self.long_term_memory
                if not 0 <= idx < len(source):
                    continue
                denom = query_norm * np.linalg.norm(embedding)
                if denom == 0:
                    continue
                similarity = float(np.dot(query_embedding, embedding) / denom)
                similarities.append((similarity, idx, mem_type))

            # Most similar first; key on the score only so ties never compare
            # the index/type fields.
            similarities.sort(key=lambda entry: entry[0], reverse=True)

            relevant_memories = []
            for similarity, idx, mem_type in similarities[:max_results]:
                source = self.short_term_memory if mem_type == "short_term" else self.long_term_memory
                memory_with_score = source[idx].copy()
                memory_with_score["relevance_score"] = similarity
                relevant_memories.append(memory_with_score)
            return relevant_memories
        except Exception as e:
            print(f"Warning: Semantic search failed: {str(e)}. Falling back to keyword search.")
            return self._keyword_search(query, max_results)

    def _keyword_search(self, query: str, max_results: int = 10) -> List[Dict[str, Any]]:
        """
        Fallback keyword-based search for relevant memories.

        Scores each memory by keyword overlap with the query, boosted by
        memory type and decayed by age (over 24 hours).

        Args:
            query: The query to find relevant memories for
            max_results: Maximum number of results to return

        Returns:
            List of memory dicts with a positive 'relevance_score', best first.
        """
        query_keywords = set(re.findall(r'\b\w+\b', query.lower()))

        type_boost = {
            "final_answer": 2.0,
            "key_fact": 1.5,
            "reasoning": 1.2,
            "general": 1.0
        }

        def score_memory(memory: Dict[str, Any]) -> float:
            content_words = set(re.findall(r'\b\w+\b', memory.get("content", "").lower()))
            matches = len(query_keywords.intersection(content_words))

            # Recency factor: linear decay over 24 hours, floored at 0.5.
            try:
                timestamp = datetime.fromisoformat(memory.get("timestamp", "2000-01-01T00:00:00"))
                hours_ago = (datetime.now() - timestamp).total_seconds() / 3600
                recency_factor = max(0.5, 1.0 - (hours_ago / 24))
            except (ValueError, TypeError):
                # Unparseable timestamp: treat as old but still searchable.
                recency_factor = 0.5

            return matches * type_boost.get(memory.get("type", "general"), 1.0) * recency_factor

        # Long-term memories are scored first (more important), then short-term.
        scored_memories = []
        for memory in self.long_term_memory + self.short_term_memory:
            score = score_memory(memory)
            if score > 0:
                memory_with_score = memory.copy()
                memory_with_score["relevance_score"] = score
                scored_memories.append((score, memory_with_score))

        scored_memories.sort(reverse=True, key=lambda entry: entry[0])
        return [memory for _, memory in scored_memories[:max_results]]

    # ------------------------------------------------------------------ #
    # Introspection
    # ------------------------------------------------------------------ #
    def get_memory_summary(self) -> str:
        """Get a human-readable summary of the current memory state."""
        # Five most recent short-term memories, truncated for display.
        recent_short_term = self.short_term_memory[-5:] if self.short_term_memory else []
        short_term_summary = "\n".join([f"- [{m.get('type', 'general')}] {m.get('content', '')[:100]}..."
                                      for m in recent_short_term])

        # Five most important long-term memories.
        important_long_term = sorted(self.long_term_memory,
                                    key=lambda x: x.get("importance", 0),
                                    reverse=True)[:5] if self.long_term_memory else []
        long_term_summary = "\n".join([f"- [{m.get('type', 'general')}] {m.get('content', '')[:100]}..."
                                     for m in important_long_term])

        # Working memory entries, long string values truncated to 50 chars.
        working_memory_summary = "\n".join([f"- {k}: {str(v)[:50]}..." if isinstance(v, str) and len(str(v)) > 50
                                          else f"- {k}: {v}" for k, v in self.working_memory.items()])

        return f"""
MEMORY SUMMARY:
--------------
Recent Short-Term Memory:
{short_term_summary if short_term_summary else "No recent short-term memories."}

Important Long-Term Memory:
{long_term_summary if long_term_summary else "No important long-term memories."}

Working Memory:
{working_memory_summary if working_memory_summary else "Working memory is empty."}
"""

    # ------------------------------------------------------------------ #
    # Persistence
    # ------------------------------------------------------------------ #
    def save_memories(self) -> None:
        """Save short- and long-term memories to disk (best-effort)."""
        try:
            # Working memory is deliberately excluded from persistence.
            memories = {
                "short_term": self.short_term_memory,
                "long_term": self.long_term_memory,
                "last_updated": datetime.now().isoformat()
            }
            with open(self.memory_file, 'w') as f:
                json.dump(memories, f, indent=2)
        except Exception as e:
            print(f"Warning: Could not save memories: {str(e)}")

    def load_memories(self) -> None:
        """Load memories from disk if available and rebuild embeddings."""
        try:
            if os.path.exists(self.memory_file):
                with open(self.memory_file, 'r') as f:
                    memories = json.load(f)

                self.short_term_memory = memories.get("short_term", [])
                self.long_term_memory = memories.get("long_term", [])
                self._rebuild_embeddings()

                print(f"Loaded {len(self.short_term_memory)} short-term and {len(self.long_term_memory)} long-term memories.")
        except Exception as e:
            print(f"Warning: Could not load memories: {str(e)}")

    # ------------------------------------------------------------------ #
    # Maintenance
    # ------------------------------------------------------------------ #
    def forget_old_memories(self, days_threshold: int = 30) -> None:
        """
        Remove memories older than the specified threshold.

        Long-term memories get an importance-scaled threshold (up to 2x), so
        more important facts survive longer. Items with unparseable timestamps
        are kept.

        Args:
            days_threshold: Age threshold in days
        """
        try:
            now = datetime.now()
            threshold = days_threshold * 24 * 60 * 60  # seconds

            def age_seconds(memory: Dict[str, Any]) -> Optional[float]:
                """Return the memory's age in seconds, or None when unknown."""
                try:
                    timestamp = datetime.fromisoformat(memory.get("timestamp", "2000-01-01T00:00:00"))
                    return (now - timestamp).total_seconds()
                except (ValueError, TypeError):
                    return None

            new_short_term = []
            for memory in self.short_term_memory:
                age = age_seconds(memory)
                # Keep items that are young enough or have invalid timestamps.
                if age is None or age < threshold:
                    new_short_term.append(memory)

            new_long_term = []
            for memory in self.long_term_memory:
                age = age_seconds(memory)
                importance = memory.get("importance", 0.5)
                # More important memories earn a proportionally higher threshold.
                if age is None or age < threshold * (1 + importance):
                    new_long_term.append(memory)

            removed_short_term = len(self.short_term_memory) - len(new_short_term)
            removed_long_term = len(self.long_term_memory) - len(new_long_term)

            self.short_term_memory = new_short_term
            self.long_term_memory = new_long_term
            self._rebuild_embeddings()
            self.save_memories()

            print(f"Forgot {removed_short_term} short-term and {removed_long_term} long-term memories older than {days_threshold} days.")
        except Exception as e:
            print(f"Warning: Could not forget old memories: {str(e)}")


# Example usage / smoke test: exercise each memory tier and the retrieval path.
if __name__ == "__main__":
    manager = EnhancedMemoryManager(use_semantic_search=True)

    # Seed one entry per persistent tier.
    sample_query = {
        "type": "query",
        "content": "What is the capital of France?",
        "timestamp": datetime.now().isoformat(),
    }
    manager.add_to_short_term(sample_query)

    sample_fact = {
        "type": "key_fact",
        "content": "Paris is the capital of France with a population of about 2.2 million people.",
        "timestamp": datetime.now().isoformat(),
    }
    manager.add_to_long_term(sample_fact)

    manager.store_in_working_memory("current_task", "Finding information about France")

    # Retrieve and display memories relevant to a follow-up question.
    print("\nRelevant memories for 'What is the population of Paris?':")
    for hit in manager.get_relevant_memories("What is the population of Paris?"):
        score = hit.get('relevance_score', 0)
        print(f"- Score: {score:.2f}, Content: {hit.get('content', '')}")

    # Show the aggregate memory state.
    print("\nMemory Summary:")
    print(manager.get_memory_summary())