"""
Core MandelMem system - main interface and orchestration.
Brings together all components into a unified memory architecture.
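
Example (a minimal usage sketch, assuming the default configuration):

    mm = MandelMem()
    result = mm.write("Quarterly revenue rose 15%", meta={'importance': 0.9})
    hits = mm.read("revenue", k=3)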
"""

import json
import time
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union

import numpy as np
import torch

from .encoders import MultiModalEncoder
from .quadtree import QuadTree, MemoryItem
from .dynamics import IterativeDynamics, AdaptiveThreshold
from .retrieval import MultiScaleRetriever, ExplainableRetriever
from .renormalization import RenormalizationEngine, MemorySummarizer
from .training import MandelMemTrainer, MandelMemLoss


@dataclass
class WriteResult:
    """Result of memory write operation."""
    persisted: bool
    tile_id: str
    stability_score: float
    band: str  # 'stable', 'plastic', 'escape'
    complex_coord: complex
    write_time: float


@dataclass
class ReadResult:
    """Result of memory read operation."""
    items: List[MemoryItem]
    similarities: List[float]
    trace: List[str]
    confidence: float
    total_time: float
    explanation: Optional[Dict[str, Any]] = None


class MandelMem:
    """Main MandelMem system interface."""
    
    def __init__(self, 
                 depth: int = 6,
                 embedding_dim: int = 768,
                 max_slots: int = 32,
                 max_buffer: int = 64,
                 base_tau: float = 2.0,
                 delta: float = 0.3):
        """Initialize MandelMem system.
        
        Args:
            depth: Quadtree depth
            embedding_dim: Vector embedding dimension
            max_slots: Maximum persistent slots per tile
            max_buffer: Maximum buffer items per tile
            base_tau: Base persistence threshold
            delta: Plasticity band width
        """
        self.depth = depth
        self.embedding_dim = embedding_dim
        # Capacity settings are recorded here; individual tiles carry their own copies.
        self.max_slots = max_slots
        self.max_buffer = max_buffer
        self.base_tau = base_tau
        self.delta = delta
        
        # Initialize core components
        self.encoder = MultiModalEncoder(embedding_dim)
        self.quadtree = QuadTree(depth, embedding_dim)
        self.dynamics = IterativeDynamics(embedding_dim)
        self.retriever = MultiScaleRetriever(self.quadtree, self.encoder)
        self.explainable_retriever = ExplainableRetriever(self.retriever)
        
        # Renormalization system
        self.summarizer = MemorySummarizer(embedding_dim)
        self.renorm_engine = RenormalizationEngine(self.quadtree, self.summarizer)
        
        # Adaptive thresholds
        self.adaptive_threshold = AdaptiveThreshold(base_tau)
        
        # Training system
        self.loss_fn = MandelMemLoss()
        self.trainer = MandelMemTrainer(
            self.quadtree, self.dynamics, self.retriever, self.loss_fn
        )
        
        # System state
        self.total_writes = 0
        self.total_reads = 0
        self.policies: Dict[str, Dict[str, Any]] = {}
        
    def write(self, content: str, meta: Optional[Dict[str, Any]] = None) -> WriteResult:
        """Write content to memory with persistence dynamics.
        
        Args:
            content: Content to store
            meta: Optional metadata (importance, source, etc.)
            
        Returns:
            WriteResult with persistence decision and details
        """
        start_time = time.time()
        
        if meta is None:
            meta = {'importance': 0.5, 'source': 'user'}
        
        # Add timestamp and recency
        meta['timestamp'] = start_time
        meta['recency_weight'] = 1.0
        
        # Encode content
        encoding = self.encoder.encode(content, meta, start_time)
        
        # Route to leaf tile
        path = self.quadtree.route_to_leaf(encoding.complex_coord)
        leaf_tile = self.quadtree.tiles[path[-1]]
        
        # Get adaptive threshold for this tile
        tile_tau = self.adaptive_threshold.get_threshold(leaf_tile.tile_id)
        
        # Perform iterative write
        iter_result = self.dynamics.iterate_write(
            leaf_tile, encoding.vector, encoding.complex_coord, 
            meta, tau=tile_tau, delta=self.delta
        )
        
        # Create memory item
        memory_item = MemoryItem(
            vector=encoding.vector,
            content=content,
            metadata=meta,
            timestamp=start_time,
            stability_score=1.0 - iter_result.max_potential / 3.0
        )
        
        # Store based on persistence decision
        if iter_result.persist:
            if iter_result.band == 'stable':
                leaf_tile.add_to_slots(memory_item)
            elif iter_result.band == 'plastic':
                leaf_tile.add_to_plastic(memory_item)
        else:
            leaf_tile.add_to_buffer(memory_item)
        
        # Update statistics
        self.total_writes += 1
        leaf_tile.stats.write_count += 1
        
        # Check if renormalization is needed
        if self._should_renormalize(leaf_tile):
            self.renorm_engine.renormalize_tile(leaf_tile)
        
        write_time = time.time() - start_time
        
        return WriteResult(
            persisted=iter_result.persist,
            tile_id=leaf_tile.tile_id,
            stability_score=memory_item.stability_score,
            band=iter_result.band,
            complex_coord=encoding.complex_coord,
            write_time=write_time
        )
    
    def read(self, query: str, k: int = 5, with_trace: bool = False,
             with_explanation: bool = False) -> ReadResult:
        """Read memories matching query.
        
        Args:
            query: Search query
            k: Number of results to return
            with_trace: Include routing trace
            with_explanation: Include detailed explanation
            
        Returns:
            ReadResult with retrieved memories and metadata
        """
        start_time = time.time()
        
        if with_explanation:
            # Use explainable retriever
            explanation = self.explainable_retriever.retrieve_with_explanation(query, k)
            result = explanation['results']
            explanation_data = explanation
        else:
            # Use standard retriever
            result = self.retriever.retrieve(query, k, with_trace)
            explanation_data = None
        
        # Update statistics
        self.total_reads += 1
        for item in result.items:
            item.access_count += 1
            item.last_access = time.time()
        
        return ReadResult(
            items=result.items,
            similarities=result.similarities,
            trace=result.trace if with_trace else [],
            confidence=result.confidence,
            total_time=time.time() - start_time,
            explanation=explanation_data
        )
    
    def set_policy(self, tile_pattern: str, **kwargs):
        """Set policy for tile pattern.
        
        Args:
            tile_pattern: Tile pattern (e.g., '/finance/**')
            **kwargs: Policy parameters (persist, encrypt, etc.)
        """
        self.policies[tile_pattern] = kwargs
        
        # Apply to matching tiles
        for tile_id, tile in self.quadtree.tiles.items():
            if self._matches_pattern(tile_id, tile_pattern):
                tile.policies.update(kwargs)
    
    def set_threshold(self, tile_pattern: str, tau: float, delta: Optional[float] = None):
        """Set persistence threshold for tile pattern.
        
        Args:
            tile_pattern: Tile pattern
            tau: Persistence threshold
            delta: Plasticity band width
        """
        for tile_id, tile in self.quadtree.tiles.items():
            if self._matches_pattern(tile_id, tile_pattern):
                self.adaptive_threshold.tile_thresholds[tile_id] = tau
                if delta is not None:
                    # Store delta in tile policies
                    tile.policies['delta'] = delta
    
    def get_statistics(self) -> Dict[str, Any]:
        """Get comprehensive system statistics."""
        quadtree_stats = self.quadtree.get_statistics()
        compression_stats = self.renorm_engine.get_compression_stats()
        training_stats = self.trainer.get_training_stats()
        
        # Memory distribution
        memory_dist = self._analyze_memory_distribution()
        
        # Performance metrics
        perf_metrics = {
            'total_writes': self.total_writes,
            'total_reads': self.total_reads,
            'avg_write_time': self._get_avg_write_time(),
            'avg_read_time': self._get_avg_read_time()
        }
        
        return {
            'quadtree': quadtree_stats,
            'compression': compression_stats,
            'training': training_stats,
            'memory_distribution': memory_dist,
            'performance': perf_metrics,
            'policies': len(self.policies)
        }
    
    def _should_renormalize(self, tile) -> bool:
        """Check if tile needs renormalization."""
        # Simple heuristic - can be made more sophisticated
        return (len(tile.slots) >= tile.max_slots * 0.9 or
                len(tile.buffer) >= tile.max_buffer * 0.9 or
                tile.stats.write_count % 1000 == 0)
    
    def _matches_pattern(self, tile_id: str, pattern: str) -> bool:
        """Check if tile ID matches pattern."""
        # Simple glob-style matching: '/**' matches the whole subtree,
        # '/*' matches direct children only.
        if pattern.endswith('/**'):
            prefix = pattern[:-2]  # keep the trailing slash, e.g. '/finance/'
            return tile_id == prefix.rstrip('/') or tile_id.startswith(prefix)
        elif pattern.endswith('/*'):
            prefix = pattern[:-1]  # keep the trailing slash so direct children match
            return tile_id.startswith(prefix) and '/' not in tile_id[len(prefix):]
        else:
            return tile_id == pattern
    
    def _analyze_memory_distribution(self) -> Dict[str, Any]:
        """Analyze memory distribution across tiles."""
        leaf_tiles = self.quadtree.get_leaf_tiles()
        
        slot_counts = [len(tile.slots) for tile in leaf_tiles]
        buffer_counts = [len(tile.buffer) for tile in leaf_tiles]
        plastic_counts = [len(tile.plastic) for tile in leaf_tiles]
        
        return {
            'total_leaf_tiles': len(leaf_tiles),
            'avg_slots_per_tile': np.mean(slot_counts) if slot_counts else 0,
            'avg_buffer_per_tile': np.mean(buffer_counts) if buffer_counts else 0,
            'avg_plastic_per_tile': np.mean(plastic_counts) if plastic_counts else 0,
            'max_slots_used': max(slot_counts) if slot_counts else 0,
            'tiles_at_capacity': sum(
                1 for tile in leaf_tiles if len(tile.slots) >= tile.max_slots * 0.9
            )
        }
    
    def _get_avg_write_time(self) -> float:
        """Get average write time (placeholder)."""
        # In a real implementation, you'd track this
        return 0.05  # 50ms average
    
    def _get_avg_read_time(self) -> float:
        """Get average read time (placeholder)."""
        # In a real implementation, you'd track this
        return 0.02  # 20ms average
    
    def export_memory_state(self, filepath: str):
        """Export current memory state to file."""
        state = {
            'quadtree_tiles': {},
            'statistics': self.get_statistics(),
            'policies': self.policies,
            'thresholds': self.adaptive_threshold.tile_thresholds
        }
        
        # Export tile contents (simplified)
        for tile_id, tile in self.quadtree.tiles.items():
            if tile.get_all_items():
                state['quadtree_tiles'][tile_id] = {
                    'attractor': tile.attractor.tolist(),
                    'item_count': len(tile.get_all_items()),
                    'bounds': [str(tile.bounds[0]), str(tile.bounds[1])],
                    'stats': tile.stats.__dict__
                }
        
        with open(filepath, 'w') as f:
            json.dump(state, f, indent=2, default=str)
    
    def safe_mode(self, enabled: bool = True):
        """Enable/disable safe mode (higher thresholds, less persistence)."""
        if enabled:
            # Raise all thresholds
            for tile_id in self.quadtree.tiles.keys():
                current_tau = self.adaptive_threshold.get_threshold(tile_id)
                self.adaptive_threshold.tile_thresholds[tile_id] = current_tau + 0.5
        else:
            # Reset to base thresholds
            self.adaptive_threshold.tile_thresholds.clear()


# Convenience functions for quick usage
def create_mandelmem(config: Optional[Dict[str, Any]] = None) -> MandelMem:
    """Create MandelMem instance with optional configuration."""
    if config is None:
        config = {}
    
    return MandelMem(
        depth=config.get('depth', 6),
        embedding_dim=config.get('embedding_dim', 768),
        max_slots=config.get('max_slots', 32),
        max_buffer=config.get('max_buffer', 64),
        base_tau=config.get('base_tau', 2.0),
        delta=config.get('delta', 0.3)
    )


def quick_demo() -> Dict[str, Any]:
    """Quick demonstration of MandelMem capabilities."""
    # Create system
    memory = create_mandelmem()
    
    # Write some memories
    important_info = "The quarterly revenue increased by 15% due to strong product sales"
    casual_note = "Remember to buy milk on the way home"
    technical_detail = "The algorithm uses gradient descent with learning rate 0.001"
    
    write_results = []
    write_results.append(memory.write(important_info, {'importance': 0.9, 'source': 'user'}))
    write_results.append(memory.write(casual_note, {'importance': 0.3, 'source': 'user'}))
    write_results.append(memory.write(technical_detail, {'importance': 0.7, 'source': 'system'}))
    
    # Read memories
    read_results = []
    read_results.append(memory.read("revenue sales", k=3, with_explanation=True))
    read_results.append(memory.read("algorithm learning", k=3, with_trace=True))
    read_results.append(memory.read("milk shopping", k=3))
    
    # Get statistics
    stats = memory.get_statistics()
    
    return {
        'write_results': write_results,
        'read_results': read_results,
        'statistics': stats,
        'demo_completed': True
    }
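

# Minimal manual check (a usage sketch only): running this module directly executes
# quick_demo() and prints a summary. Because the imports above are relative, this
# needs to run in package context, e.g. `python -m <package>.core` (the package name
# depends on the project layout).
if __name__ == "__main__":
    demo = quick_demo()
    print(f"Demo completed: {demo['demo_completed']}")
    for wr in demo['write_results']:
        print(f"  write -> tile={wr.tile_id} band={wr.band} persisted={wr.persisted}")
    for rr in demo['read_results']:
        print(f"  read  -> {len(rr.items)} items, confidence={rr.confidence:.3f}")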