"""
Training objectives and loss functions for MandelMem system.
Implements the multi-objective training described in the blueprint.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Dict, Any, List, Optional
from dataclasses import dataclass

from .quadtree import QuadTree, MemoryItem
from .dynamics import IterativeDynamics
from .retrieval import MultiScaleRetriever


@dataclass
class TrainingBatch:
    """Batch of training data for MandelMem."""
    queries: List[str]
    targets: List[str]
    importance_labels: List[float]
    noise_items: List[str]
    metadata: List[Dict[str, Any]]


@dataclass
class LossComponents:
    """Individual loss components."""
    retrieval_loss: float
    stability_loss: float
    escape_loss: float
    margin_loss: float
    renorm_loss: float
    total_loss: torch.Tensor  # kept as a tensor so .item()/.backward() work in the trainer


class MandelMemLoss(nn.Module):
    """Multi-objective loss function for MandelMem training."""
    
    def __init__(self, lambda_read: float = 1.0, lambda_stab: float = 0.5,
                 lambda_esc: float = 0.3, lambda_margin: float = 0.2,
                 lambda_renorm: float = 0.1):
        super().__init__()
        self.lambda_read = lambda_read
        self.lambda_stab = lambda_stab
        self.lambda_esc = lambda_esc
        self.lambda_margin = lambda_margin
        self.lambda_renorm = lambda_renorm
        
    def forward(self, batch: TrainingBatch, model_outputs: Dict[str, Any]) -> LossComponents:
        """Compute all loss components."""
        
        # L_read: retrieval loss (a log-rank surrogate for CE/contrastive on the correct item)
        retrieval_loss = self._compute_retrieval_loss(
            batch.queries, batch.targets, model_outputs['retrievals']
        )
        
        # L_stab: Stability loss for important items
        stability_loss = self._compute_stability_loss(
            batch.importance_labels, model_outputs['stability_scores']
        )
        
        # L_esc: Escape loss for noise items
        escape_loss = self._compute_escape_loss(
            batch.noise_items, model_outputs['escape_scores']
        )
        
        # L_margin: Margin loss between persist vs non-persist
        margin_loss = self._compute_margin_loss(
            model_outputs['persist_scores'], model_outputs['escape_scores']
        )
        
        # L_renorm: Compression loss (preserve retrieval after merges)
        renorm_loss = self._compute_renorm_loss(
            model_outputs.get('pre_renorm_retrievals'),
            model_outputs.get('post_renorm_retrievals')
        )
        
        # Total loss, kept as a tensor so the trainer's .item() calls work.
        # The components are plain Python floats computed from detached
        # scores, so this tensor carries no gradient by itself.
        total_loss = torch.as_tensor(
            self.lambda_read * retrieval_loss +
            self.lambda_stab * stability_loss +
            self.lambda_esc * escape_loss +
            self.lambda_margin * margin_loss +
            self.lambda_renorm * renorm_loss
        )
        
        return LossComponents(
            retrieval_loss=retrieval_loss,
            stability_loss=stability_loss,
            escape_loss=escape_loss,
            margin_loss=margin_loss,
            renorm_loss=renorm_loss,
            total_loss=total_loss
        )
    
    def _compute_retrieval_loss(self, queries: List[str], targets: List[str],
                              retrievals: List[List[MemoryItem]]) -> float:
        """Compute retrieval accuracy loss."""
        total_loss = 0.0
        valid_pairs = 0
        
        for query, target, retrieved_items in zip(queries, targets, retrievals):
            if not retrieved_items:
                continue
                
            # Find target in retrieved items
            target_found = False
            target_rank = len(retrieved_items)
            
            for i, item in enumerate(retrieved_items):
                if target.lower() in item.content.lower():
                    target_found = True
                    target_rank = i
                    break
            
            if target_found:
                # Ranking loss - penalize targets found further down the list
                rank_loss = np.log(1 + target_rank)
                total_loss += rank_loss
            else:
                # Target not found - high penalty
                total_loss += 10.0
            
            valid_pairs += 1
        
        return total_loss / max(valid_pairs, 1)
    
    def _compute_stability_loss(self, importance_labels: List[float],
                              stability_scores: List[float]) -> float:
        """Keep important items stable (low escape potential)."""
        if not importance_labels or not stability_scores:
            return 0.0
        
        loss = 0.0
        n_important = 0
        for importance, stability in zip(importance_labels, stability_scores):
            if importance > 0.7:  # Important items
                # Want stability to track importance (squared error on plain
                # floats, matching the declared float return type)
                loss += (stability - importance) ** 2
                n_important += 1
        
        return loss / max(n_important, 1)
    
    def _compute_escape_loss(self, noise_items: List[str],
                           escape_scores: List[float]) -> float:
        """Push noise items to escape (high potential)."""
        if not noise_items or not escape_scores:
            return 0.0
        
        loss = 0.0
        tau = 2.0  # Base threshold
        
        for escape_score in escape_scores:
            # Want escape scores above threshold
            if escape_score <= tau:
                loss += (tau - escape_score) ** 2
        
        return loss / len(escape_scores)
    
    def _compute_margin_loss(self, persist_scores: List[float],
                           escape_scores: List[float]) -> float:
        """Enlarge gap between persist vs escape near boundaries."""
        if not persist_scores or not escape_scores:
            return 0.0
        
        margin = 0.5
        loss = 0.0
        count = 0
        
        for p_score in persist_scores:
            for e_score in escape_scores:
                # Want persist_score + margin < escape_score
                gap = e_score - p_score
                if gap < margin:
                    loss += (margin - gap) ** 2
                count += 1
        
        return loss / max(count, 1)
    
    def _compute_renorm_loss(self, pre_renorm: Optional[List[List[MemoryItem]]],
                           post_renorm: Optional[List[List[MemoryItem]]]) -> float:
        """Preserve retrieval quality after renormalization."""
        if not pre_renorm or not post_renorm:
            return 0.0
        
        loss = 0.0
        valid_comparisons = 0
        
        for pre_items, post_items in zip(pre_renorm, post_renorm):
            if not pre_items or not post_items:
                continue
            
            # Compare top retrieved items
            pre_top = pre_items[0] if pre_items else None
            post_top = post_items[0] if post_items else None
            
            if pre_top and post_top:
                # Cosine similarity between top results
                similarity = F.cosine_similarity(
                    pre_top.vector.unsqueeze(0),
                    post_top.vector.unsqueeze(0)
                )
                loss += 1.0 - similarity.item()
                valid_comparisons += 1
        
        return loss / max(valid_comparisons, 1)


class MandelMemTrainer:
    """Training orchestrator for MandelMem system."""
    
    def __init__(self, quadtree: QuadTree, dynamics: IterativeDynamics,
                 retriever: MultiScaleRetriever, loss_fn: MandelMemLoss):
        self.quadtree = quadtree
        self.dynamics = dynamics
        self.retriever = retriever
        self.loss_fn = loss_fn
        
        # Training state
        self.training_history: List[Dict[str, Any]] = []
        self.current_epoch = 0
        
    def train_epoch(self, train_batches: List[TrainingBatch],
                   learning_rate: float = 1e-4) -> Dict[str, float]:
        """Train for one epoch."""
        
        # Set up optimizer
        all_params = list(self.dynamics.parameters())
        if hasattr(self.retriever.encoder, 'parameters'):
            all_params.extend(self.retriever.encoder.parameters())
        
        optimizer = torch.optim.Adam(all_params, lr=learning_rate)
        
        epoch_losses = []
        
        for batch in train_batches:
            optimizer.zero_grad()
            
            # Forward pass
            model_outputs = self._forward_pass(batch)
            
            # Compute losses
            losses = self.loss_fn(batch, model_outputs)
            
            # Backward pass. The heuristic losses are computed from detached
            # Python floats, so the total carries gradients only if upstream
            # modules contributed grad-tracked tensors; guard so training
            # still runs in the detached case.
            if losses.total_loss.requires_grad:
                losses.total_loss.backward()
                
                # Gradient clipping
                torch.nn.utils.clip_grad_norm_(all_params, max_norm=1.0)
                
                optimizer.step()
            
            epoch_losses.append({
                'total': losses.total_loss.item(),
                'retrieval': losses.retrieval_loss,
                'stability': losses.stability_loss,
                'escape': losses.escape_loss,
                'margin': losses.margin_loss,
                'renorm': losses.renorm_loss
            })
        
        # Aggregate epoch results
        epoch_stats = self._aggregate_losses(epoch_losses)
        self.training_history.append(epoch_stats)
        self.current_epoch += 1
        
        return epoch_stats
    
    def _forward_pass(self, batch: TrainingBatch) -> Dict[str, Any]:
        """Forward pass through the model."""
        outputs = {
            'retrievals': [],
            'stability_scores': [],
            'escape_scores': [],
            'persist_scores': [],
            'pre_renorm_retrievals': None,
            'post_renorm_retrievals': None
        }
        
        # Process each query: write the target to memory, then retrieve it
        for query, target, importance, meta in zip(
            batch.queries, batch.targets, batch.importance_labels, batch.metadata
        ):
            # Encode and write target to memory
            encoding = self.retriever.encoder.encode(target, meta)
            
            # Route to tile and perform iterative write
            path = self.quadtree.route_to_leaf(encoding.complex_coord)
            leaf_tile = self.quadtree.tiles[path[-1]]
            
            # Get iteration result
            iter_result = self.dynamics.iterate_write(
                leaf_tile, encoding.vector, encoding.complex_coord, meta
            )
            
            # Store scores; 3.0 is the normalization ceiling that maps the
            # max potential into a rough [0, 1] stability score
            outputs['stability_scores'].append(1.0 - iter_result.max_potential / 3.0)
            outputs['escape_scores'].append(iter_result.max_potential)
            
            if iter_result.persist:
                outputs['persist_scores'].append(iter_result.max_potential)
            
            # Retrieve for query
            retrieval_result = self.retriever.retrieve(query, k=5)
            outputs['retrievals'].append(retrieval_result.items)
        
        # Process noise items
        for noise_item in batch.noise_items:
            noise_encoding = self.retriever.encoder.encode(
                noise_item, {'importance': 0.1, 'source': 'noise'}
            )
            
            path = self.quadtree.route_to_leaf(noise_encoding.complex_coord)
            leaf_tile = self.quadtree.tiles[path[-1]]
            
            iter_result = self.dynamics.iterate_write(
                leaf_tile, noise_encoding.vector, noise_encoding.complex_coord,
                {'importance': 0.1, 'source': 'noise'}
            )
            
            outputs['escape_scores'].append(iter_result.max_potential)
        
        return outputs
    
    def _aggregate_losses(self, epoch_losses: List[Dict[str, float]]) -> Dict[str, float]:
        """Aggregate losses across batches."""
        if not epoch_losses:
            return {}
        
        aggregated = {}
        for key in epoch_losses[0].keys():
            values = [loss[key] for loss in epoch_losses if key in loss]
            aggregated[f'avg_{key}'] = float(np.mean(values))
            aggregated[f'std_{key}'] = float(np.std(values))
        
        aggregated['epoch'] = self.current_epoch
        return aggregated
    
    def evaluate(self, test_batches: List[TrainingBatch]) -> Dict[str, float]:
        """Evaluate model on test data.

        Note: _forward_pass performs iterative writes, so evaluation also
        mutates memory state; reset or snapshot the quadtree if that matters.
        """
        with torch.no_grad():
            test_losses = []
            
            for batch in test_batches:
                model_outputs = self._forward_pass(batch)
                losses = self.loss_fn(batch, model_outputs)
                
                test_losses.append({
                    'total': losses.total_loss.item(),
                    'retrieval': losses.retrieval_loss,
                    'stability': losses.stability_loss,
                    'escape': losses.escape_loss,
                    'margin': losses.margin_loss,
                    'renorm': losses.renorm_loss
                })
        
        return self._aggregate_losses(test_losses)
    
    def get_training_stats(self) -> Dict[str, Any]:
        """Get training statistics."""
        if not self.training_history:
            return {'message': 'No training history available'}
        
        latest = self.training_history[-1]
        
        # Compute trends
        if len(self.training_history) >= 2:
            prev = self.training_history[-2]
            trends = {
                f'{key}_trend': latest[key] - prev[key]
                for key in latest.keys()
                if key.startswith('avg_') and key in prev
            }
        else:
            trends = {}
        
        return {
            'current_epoch': self.current_epoch,
            'latest_losses': latest,
            'trends': trends,
            'total_epochs_trained': len(self.training_history)
        }


class SyntheticDataGenerator:
    """Generate synthetic training data for MandelMem."""
    
    def __init__(self, vocab_size: int = 1000):
        self.vocab_size = vocab_size
        self.word_list = self._generate_vocabulary()
        
    def _generate_vocabulary(self) -> List[str]:
        """Generate synthetic vocabulary."""
        categories = {
            'objects': ['car', 'house', 'tree', 'book', 'phone', 'computer', 'chair', 'table'],
            'actions': ['run', 'jump', 'read', 'write', 'think', 'learn', 'create', 'build'],
            'adjectives': ['big', 'small', 'red', 'blue', 'fast', 'slow', 'bright', 'dark'],
            'concepts': ['memory', 'knowledge', 'wisdom', 'truth', 'beauty', 'justice', 'freedom']
        }
        
        vocab = []
        for words in categories.values():
            vocab.extend(words)
        
        # Add numbers and common words
        vocab.extend([str(i) for i in range(100)])
        vocab.extend(['the', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'from'])
        
        # The assembled list (~140 entries) may be shorter than vocab_size,
        # in which case the full list is returned.
        return vocab[:self.vocab_size]
    
    def generate_batch(self, batch_size: int = 32) -> TrainingBatch:
        """Generate a training batch."""
        queries = []
        targets = []
        importance_labels = []
        noise_items = []
        metadata = []
        
        for _ in range(batch_size):
            # Generate query-target pair
            target_words = np.random.choice(self.word_list, size=np.random.randint(3, 8))
            target = ' '.join(target_words)
            
            # Query is partial target with some noise
            query_words = target_words[:np.random.randint(1, len(target_words))]
            if np.random.random() < 0.3:  # Add noise word
                noise_word = np.random.choice(self.word_list)
                query_words = np.append(query_words, noise_word)
            
            query = ' '.join(query_words)
            
            # Random importance
            importance = np.random.beta(2, 2)  # Bias toward middle values
            
            # Metadata
            meta = {
                'importance': importance,
                'source': np.random.choice(['user', 'system', 'external']),
                'pii': np.random.random() < 0.1,
                'repeats': np.random.poisson(1),
                'recency_weight': np.random.exponential(0.5)
            }
            
            queries.append(query)
            targets.append(target)
            importance_labels.append(importance)
            metadata.append(meta)
        
        # Generate noise items
        for _ in range(batch_size // 4):
            noise_words = np.random.choice(self.word_list, size=np.random.randint(2, 6))
            noise_items.append(' '.join(noise_words))
        
        return TrainingBatch(
            queries=queries,
            targets=targets,
            importance_labels=importance_labels,
            noise_items=noise_items,
            metadata=metadata
        )
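

# ---------------------------------------------------------------------------
# Minimal smoke test: a sketch that exercises only the classes defined in
# this module (SyntheticDataGenerator + MandelMemLoss) with placeholder
# model outputs. The score values below are arbitrary assumptions, and the
# empty retrieval lists sidestep constructing MemoryItem instances, so the
# retrieval loss contributes zero here.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    generator = SyntheticDataGenerator()
    batch = generator.generate_batch(batch_size=4)
    loss_fn = MandelMemLoss()

    n = len(batch.queries)
    dummy_outputs = {
        'retrievals': [[] for _ in range(n)],  # empty lists are skipped
        'stability_scores': [0.8] * n,         # placeholder stability scores
        'escape_scores': [2.5] * (n + len(batch.noise_items)),
        'persist_scores': [1.0] * n,
        'pre_renorm_retrievals': None,         # renorm comparison disabled
        'post_renorm_retrievals': None,
    }

    components = loss_fn(batch, dummy_outputs)
    print(f"total loss on synthetic batch: {float(components.total_loss):.4f}")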