| """ | |
| Training objectives and loss functions for MandelMem system. | |
| Implements the multi-objective training described in the blueprint. | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import numpy as np | |
| from typing import Dict, Any, List, Tuple, Optional | |
| from dataclasses import dataclass | |
| from .quadtree import QuadTree, MemoryItem | |
| from .dynamics import IterativeDynamics | |
| from .retrieval import MultiScaleRetriever | |


@dataclass
class TrainingBatch:
    """Batch of training data for MandelMem."""
    queries: List[str]
    targets: List[str]
    importance_labels: List[float]
    noise_items: List[str]
    metadata: List[Dict[str, Any]]


@dataclass
class LossComponents:
    """Individual loss components; the weighted total is kept as a scalar tensor."""
    retrieval_loss: float
    stability_loss: float
    escape_loss: float
    margin_loss: float
    renorm_loss: float
    total_loss: torch.Tensor


class MandelMemLoss(nn.Module):
    """Multi-objective loss function for MandelMem training."""

    def __init__(self, lambda_read: float = 1.0, lambda_stab: float = 0.5,
                 lambda_esc: float = 0.3, lambda_margin: float = 0.2,
                 lambda_renorm: float = 0.1):
        super().__init__()
        self.lambda_read = lambda_read
        self.lambda_stab = lambda_stab
        self.lambda_esc = lambda_esc
        self.lambda_margin = lambda_margin
        self.lambda_renorm = lambda_renorm

    def forward(self, batch: TrainingBatch, model_outputs: Dict[str, Any]) -> LossComponents:
        """Compute all loss components."""
        # L_read: retrieval loss (rank-based proxy for the CE/contrastive objective)
        retrieval_loss = self._compute_retrieval_loss(
            batch.queries, batch.targets, model_outputs['retrievals']
        )
        # L_stab: stability loss for important items
        stability_loss = self._compute_stability_loss(
            batch.importance_labels, model_outputs['stability_scores']
        )
        # L_esc: escape loss for noise items
        escape_loss = self._compute_escape_loss(
            batch.noise_items, model_outputs['escape_scores']
        )
        # L_margin: margin loss between persist vs. non-persist scores
        margin_loss = self._compute_margin_loss(
            model_outputs['persist_scores'], model_outputs['escape_scores']
        )
        # L_renorm: compression loss (preserve retrieval after merges)
        renorm_loss = self._compute_renorm_loss(
            model_outputs.get('pre_renorm_retrievals'),
            model_outputs.get('post_renorm_retrievals')
        )
        # Weighted total. Kept as a scalar tensor so the trainer can call
        # backward(); gradients flow only if the upstream scores are tensors.
        total_loss = (self.lambda_read * retrieval_loss +
                      self.lambda_stab * stability_loss +
                      self.lambda_esc * escape_loss +
                      self.lambda_margin * margin_loss +
                      self.lambda_renorm * renorm_loss)
        if not isinstance(total_loss, torch.Tensor):
            total_loss = torch.tensor(total_loss, dtype=torch.float32)
        return LossComponents(
            retrieval_loss=retrieval_loss,
            stability_loss=stability_loss,
            escape_loss=escape_loss,
            margin_loss=margin_loss,
            renorm_loss=renorm_loss,
            total_loss=total_loss
        )

    def _compute_retrieval_loss(self, queries: List[str], targets: List[str],
                                retrievals: List[List[MemoryItem]]) -> float:
        """Compute retrieval accuracy loss."""
        total_loss = 0.0
        valid_pairs = 0
        for query, target, retrieved_items in zip(queries, targets, retrievals):
            if not retrieved_items:
                continue
            # Find target in retrieved items
            target_found = False
            target_rank = len(retrieved_items)
            for i, item in enumerate(retrieved_items):
                if target.lower() in item.content.lower():
                    target_found = True
                    target_rank = i
                    break
            if target_found:
                # Ranking loss - penalize lower ranks
                rank_loss = np.log(1 + target_rank)
                total_loss += rank_loss
            else:
                # Target not found - high penalty
                total_loss += 10.0
            valid_pairs += 1
        return total_loss / max(valid_pairs, 1)

    def _compute_stability_loss(self, importance_labels: List[float],
                                stability_scores: List[float]) -> float:
        """Keep important items stable (low escape potential)."""
        if not importance_labels or not stability_scores:
            return 0.0
        loss = 0.0
        count = 0
        for importance, stability in zip(importance_labels, stability_scores):
            if importance > 0.7:  # Important items
                # Want the stability score to track the importance label
                loss += (stability - importance) ** 2
                count += 1
        return loss / max(count, 1)

    def _compute_escape_loss(self, noise_items: List[str],
                             escape_scores: List[float]) -> float:
        """Push noise items to escape (high potential)."""
        if not noise_items or not escape_scores:
            return 0.0
        loss = 0.0
        tau = 2.0  # Base threshold
        for escape_score in escape_scores:
            # Want escape scores above threshold
            if escape_score <= tau:
                loss += (tau - escape_score) ** 2
        return loss / len(escape_scores)

    def _compute_margin_loss(self, persist_scores: List[float],
                             escape_scores: List[float]) -> float:
        """Enlarge the gap between persist and escape scores near boundaries."""
        if not persist_scores or not escape_scores:
            return 0.0
        margin = 0.5
        loss = 0.0
        count = 0
        for p_score in persist_scores:
            for e_score in escape_scores:
                # Want persist_score + margin < escape_score
                gap = e_score - p_score
                if gap < margin:
                    loss += (margin - gap) ** 2
                count += 1
        return loss / max(count, 1)

    def _compute_renorm_loss(self, pre_renorm: Optional[List[List[MemoryItem]]],
                             post_renorm: Optional[List[List[MemoryItem]]]) -> float:
        """Preserve retrieval quality after renormalization."""
        if not pre_renorm or not post_renorm:
            return 0.0
        loss = 0.0
        valid_comparisons = 0
        for pre_items, post_items in zip(pre_renorm, post_renorm):
            if not pre_items or not post_items:
                continue
            # Compare top retrieved items
            pre_top = pre_items[0] if pre_items else None
            post_top = post_items[0] if post_items else None
            if pre_top and post_top:
                # Cosine similarity between top results
                similarity = F.cosine_similarity(
                    pre_top.vector.unsqueeze(0),
                    post_top.vector.unsqueeze(0)
                )
                loss += 1.0 - similarity.item()
                valid_comparisons += 1
        return loss / max(valid_comparisons, 1)


class MandelMemTrainer:
    """Training orchestrator for the MandelMem system."""

    def __init__(self, quadtree: QuadTree, dynamics: IterativeDynamics,
                 retriever: MultiScaleRetriever, loss_fn: MandelMemLoss):
        self.quadtree = quadtree
        self.dynamics = dynamics
        self.retriever = retriever
        self.loss_fn = loss_fn
        # Training state
        self.training_history: List[Dict[str, Any]] = []
        self.current_epoch = 0

    def train_epoch(self, train_batches: List[TrainingBatch],
                    learning_rate: float = 1e-4) -> Dict[str, float]:
        """Train for one epoch."""
        # Set up optimizer
        all_params = list(self.dynamics.parameters())
        if hasattr(self.retriever.encoder, 'parameters'):
            all_params.extend(self.retriever.encoder.parameters())
        optimizer = torch.optim.Adam(all_params, lr=learning_rate)
        epoch_losses = []
        for batch in train_batches:
            optimizer.zero_grad()
            # Forward pass
            model_outputs = self._forward_pass(batch)
            # Compute losses
            losses = self.loss_fn(batch, model_outputs)
            # Backward pass; only possible when the total loss is connected to
            # the parameter graph (i.e. the model outputs are tensors).
            if losses.total_loss.requires_grad:
                losses.total_loss.backward()
                # Gradient clipping
                torch.nn.utils.clip_grad_norm_(all_params, max_norm=1.0)
                optimizer.step()
            epoch_losses.append({
                'total': losses.total_loss.item(),
                'retrieval': losses.retrieval_loss,
                'stability': losses.stability_loss,
                'escape': losses.escape_loss,
                'margin': losses.margin_loss,
                'renorm': losses.renorm_loss
            })
        # Aggregate epoch results
        epoch_stats = self._aggregate_losses(epoch_losses)
        self.training_history.append(epoch_stats)
        self.current_epoch += 1
        return epoch_stats

    def _forward_pass(self, batch: TrainingBatch) -> Dict[str, Any]:
        """Forward pass through the model."""
        outputs = {
            'retrievals': [],
            'stability_scores': [],
            'escape_scores': [],
            'persist_scores': [],
            'pre_renorm_retrievals': None,
            'post_renorm_retrievals': None
        }
        # Process each query
        for query, target, importance, meta in zip(
            batch.queries, batch.targets, batch.importance_labels, batch.metadata
        ):
            # Encode and write target to memory
            encoding = self.retriever.encoder.encode(target, meta)
            # Route to tile and perform iterative write
            path = self.quadtree.route_to_leaf(encoding.complex_coord)
            leaf_tile = self.quadtree.tiles[path[-1]]
            # Get iteration result
            iter_result = self.dynamics.iterate_write(
                leaf_tile, encoding.vector, encoding.complex_coord, meta
            )
            # Store scores
            outputs['stability_scores'].append(1.0 - iter_result.max_potential / 3.0)
            outputs['escape_scores'].append(iter_result.max_potential)
            if iter_result.persist:
                outputs['persist_scores'].append(iter_result.max_potential)
            # Retrieve for query
            retrieval_result = self.retriever.retrieve(query, k=5)
            outputs['retrievals'].append(retrieval_result.items)
        # Process noise items
        for noise_item in batch.noise_items:
            noise_encoding = self.retriever.encoder.encode(
                noise_item, {'importance': 0.1, 'source': 'noise'}
            )
            path = self.quadtree.route_to_leaf(noise_encoding.complex_coord)
            leaf_tile = self.quadtree.tiles[path[-1]]
            iter_result = self.dynamics.iterate_write(
                leaf_tile, noise_encoding.vector, noise_encoding.complex_coord,
                {'importance': 0.1, 'source': 'noise'}
            )
            outputs['escape_scores'].append(iter_result.max_potential)
        return outputs

    def _aggregate_losses(self, epoch_losses: List[Dict[str, float]]) -> Dict[str, float]:
        """Aggregate losses across batches."""
        if not epoch_losses:
            return {}
        aggregated = {}
        for key in epoch_losses[0].keys():
            values = [loss[key] for loss in epoch_losses if key in loss]
            aggregated[f'avg_{key}'] = np.mean(values)
            aggregated[f'std_{key}'] = np.std(values)
        aggregated['epoch'] = self.current_epoch
        return aggregated

    def evaluate(self, test_batches: List[TrainingBatch]) -> Dict[str, float]:
        """Evaluate model on test data."""
        with torch.no_grad():
            test_losses = []
            for batch in test_batches:
                model_outputs = self._forward_pass(batch)
                losses = self.loss_fn(batch, model_outputs)
                test_losses.append({
                    'total': losses.total_loss.item(),
                    'retrieval': losses.retrieval_loss,
                    'stability': losses.stability_loss,
                    'escape': losses.escape_loss,
                    'margin': losses.margin_loss,
                    'renorm': losses.renorm_loss
                })
        return self._aggregate_losses(test_losses)

    def get_training_stats(self) -> Dict[str, Any]:
        """Get training statistics."""
        if not self.training_history:
            return {'message': 'No training history available'}
        latest = self.training_history[-1]
        # Compute trends
        if len(self.training_history) >= 2:
            prev = self.training_history[-2]
            trends = {
                f'{key}_trend': latest[key] - prev[key]
                for key in latest.keys()
                if key.startswith('avg_') and key in prev
            }
        else:
            trends = {}
        return {
            'current_epoch': self.current_epoch,
            'latest_losses': latest,
            'trends': trends,
            'total_epochs_trained': len(self.training_history)
        }


class SyntheticDataGenerator:
    """Generate synthetic training data for MandelMem."""

    def __init__(self, vocab_size: int = 1000):
        self.vocab_size = vocab_size
        self.word_list = self._generate_vocabulary()

    def _generate_vocabulary(self) -> List[str]:
        """Generate synthetic vocabulary."""
        categories = {
            'objects': ['car', 'house', 'tree', 'book', 'phone', 'computer', 'chair', 'table'],
            'actions': ['run', 'jump', 'read', 'write', 'think', 'learn', 'create', 'build'],
            'adjectives': ['big', 'small', 'red', 'blue', 'fast', 'slow', 'bright', 'dark'],
            'concepts': ['memory', 'knowledge', 'wisdom', 'truth', 'beauty', 'justice', 'freedom']
        }
        vocab = []
        for words in categories.values():
            vocab.extend(words)
        # Add numbers and common words
        vocab.extend([str(i) for i in range(100)])
        vocab.extend(['the', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'from'])
        return vocab[:self.vocab_size]

    def generate_batch(self, batch_size: int = 32) -> TrainingBatch:
        """Generate a training batch."""
        queries = []
        targets = []
        importance_labels = []
        noise_items = []
        metadata = []
        for _ in range(batch_size):
            # Generate query-target pair
            target_words = np.random.choice(self.word_list, size=np.random.randint(3, 8))
            target = ' '.join(target_words)
            # Query is partial target with some noise
            query_words = target_words[:np.random.randint(1, len(target_words))]
            if np.random.random() < 0.3:  # Add noise word
                noise_word = np.random.choice(self.word_list)
                query_words = np.append(query_words, noise_word)
            query = ' '.join(query_words)
            # Random importance
            importance = np.random.beta(2, 2)  # Bias toward middle values
            # Metadata
            meta = {
                'importance': importance,
                'source': np.random.choice(['user', 'system', 'external']),
                'pii': np.random.random() < 0.1,
                'repeats': np.random.poisson(1),
                'recency_weight': np.random.exponential(0.5)
            }
            queries.append(query)
            targets.append(target)
            importance_labels.append(importance)
            metadata.append(meta)
        # Generate noise items
        for _ in range(batch_size // 4):
            noise_words = np.random.choice(self.word_list, size=np.random.randint(2, 6))
            noise_items.append(' '.join(noise_words))
        return TrainingBatch(
            queries=queries,
            targets=targets,
            importance_labels=importance_labels,
            noise_items=noise_items,
            metadata=metadata
        )
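

# A minimal smoke test, assuming this module lives inside a package so the
# relative imports above resolve (e.g. run as `python -m <package>.training`);
# the package name is not fixed here. It only exercises pieces defined in this
# file: wiring up a full MandelMemTrainer additionally needs QuadTree,
# IterativeDynamics, and MultiScaleRetriever instances from the sibling modules.
if __name__ == "__main__":
    generator = SyntheticDataGenerator(vocab_size=200)
    batch = generator.generate_batch(batch_size=8)
    print(f"queries: {len(batch.queries)}, targets: {len(batch.targets)}, "
          f"noise items: {len(batch.noise_items)}")

    # Hand-built, non-differentiable model outputs: empty retrieval lists and
    # constant scores, just to check that the loss components combine.
    loss_fn = MandelMemLoss()
    dummy_outputs = {
        'retrievals': [[] for _ in batch.queries],
        'stability_scores': [0.8] * len(batch.targets),
        'escape_scores': [2.5] * len(batch.noise_items),
        'persist_scores': [0.5] * len(batch.targets),
    }
    losses = loss_fn(batch, dummy_outputs)
    print(f"total loss: {losses.total_loss.item():.4f}")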