"""
Training objectives and loss functions for MandelMem system.
Implements the multi-objective training described in the blueprint.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Dict, Any, List, Optional
from dataclasses import dataclass
from .quadtree import QuadTree, MemoryItem
from .dynamics import IterativeDynamics
from .retrieval import MultiScaleRetriever
@dataclass
class TrainingBatch:
"""Batch of training data for MandelMem."""
queries: List[str]
targets: List[str]
importance_labels: List[float]
noise_items: List[str]
metadata: List[Dict[str, Any]]
@dataclass
class LossComponents:
    """Individual loss components as floats, plus the weighted total as a tensor."""
    retrieval_loss: float
    stability_loss: float
    escape_loss: float
    margin_loss: float
    renorm_loss: float
    total_loss: torch.Tensor
class MandelMemLoss(nn.Module):
"""Multi-objective loss function for MandelMem training."""
def __init__(self, lambda_read: float = 1.0, lambda_stab: float = 0.5,
lambda_esc: float = 0.3, lambda_margin: float = 0.2,
lambda_renorm: float = 0.1):
super().__init__()
self.lambda_read = lambda_read
self.lambda_stab = lambda_stab
self.lambda_esc = lambda_esc
self.lambda_margin = lambda_margin
self.lambda_renorm = lambda_renorm
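    # forward() combines the five objectives below into a weighted sum:
    #   L = lambda_read   * L_read    (retrieval ranking)
    #     + lambda_stab   * L_stab    (keep important items stable)
    #     + lambda_esc    * L_esc     (push noise past the escape threshold)
    #     + lambda_margin * L_margin  (persist/escape separation)
    #     + lambda_renorm * L_renorm  (retrieval preserved across merges)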
def forward(self, batch: TrainingBatch, model_outputs: Dict[str, Any]) -> LossComponents:
"""Compute all loss components."""
# L_read: Retrieval loss (CE/contrastive on correct item)
retrieval_loss = self._compute_retrieval_loss(
batch.queries, batch.targets, model_outputs['retrievals']
)
# L_stab: Stability loss for important items
stability_loss = self._compute_stability_loss(
batch.importance_labels, model_outputs['stability_scores']
)
# L_esc: Escape loss for noise items
escape_loss = self._compute_escape_loss(
batch.noise_items, model_outputs['escape_scores']
)
# L_margin: Margin loss between persist vs non-persist
margin_loss = self._compute_margin_loss(
model_outputs['persist_scores'], model_outputs['escape_scores']
)
# L_renorm: Compression loss (preserve retrieval after merges)
renorm_loss = self._compute_renorm_loss(
model_outputs.get('pre_renorm_retrievals'),
model_outputs.get('post_renorm_retrievals')
)
        # Coerce every component to a tensor so the weighted total supports
        # .backward() in the trainer even when a component degenerates to 0.0
        retrieval_loss = torch.as_tensor(retrieval_loss, dtype=torch.float32)
        stability_loss = torch.as_tensor(stability_loss, dtype=torch.float32)
        escape_loss = torch.as_tensor(escape_loss, dtype=torch.float32)
        margin_loss = torch.as_tensor(margin_loss, dtype=torch.float32)
        renorm_loss = torch.as_tensor(renorm_loss, dtype=torch.float32)
        # Total loss
        total_loss = (self.lambda_read * retrieval_loss +
                      self.lambda_stab * stability_loss +
                      self.lambda_esc * escape_loss +
                      self.lambda_margin * margin_loss +
                      self.lambda_renorm * renorm_loss)
        return LossComponents(
            retrieval_loss=retrieval_loss.item(),
            stability_loss=stability_loss.item(),
            escape_loss=escape_loss.item(),
            margin_loss=margin_loss.item(),
            renorm_loss=renorm_loss.item(),
            total_loss=total_loss
        )
def _compute_retrieval_loss(self, queries: List[str], targets: List[str],
retrievals: List[List[MemoryItem]]) -> float:
"""Compute retrieval accuracy loss."""
total_loss = 0.0
valid_pairs = 0
for query, target, retrieved_items in zip(queries, targets, retrievals):
if not retrieved_items:
continue
# Find target in retrieved items
target_found = False
target_rank = len(retrieved_items)
for i, item in enumerate(retrieved_items):
if target.lower() in item.content.lower():
target_found = True
target_rank = i
break
if target_found:
                # Ranking loss: penalty grows the deeper the target sits
                rank_loss = np.log(1 + target_rank)
total_loss += rank_loss
else:
# Target not found - high penalty
total_loss += 10.0
valid_pairs += 1
return total_loss / max(valid_pairs, 1)
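    # NOTE: the ranking loss above is built from substring matches, so it is a
    # non-differentiable surrogate: it scores retrieval quality, but gradients
    # do not flow through it into the encoder.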
    def _compute_stability_loss(self, importance_labels: List[float],
                                stability_scores: List[float]) -> float:
        """Keep important items stable (low escape potential)."""
        if not importance_labels or not stability_scores:
            return 0.0
        loss = 0.0
        count = 0
        for importance, stability in zip(importance_labels, stability_scores):
            if importance > 0.7:  # Only important items are constrained
                # Pull the stability score toward the importance label
                loss += (stability - importance) ** 2
                count += 1
        return loss / max(count, 1)
def _compute_escape_loss(self, noise_items: List[str],
escape_scores: List[float]) -> float:
"""Push noise items to escape (high potential)."""
if not noise_items or not escape_scores:
return 0.0
loss = 0.0
tau = 2.0 # Base threshold
for escape_score in escape_scores:
# Want escape scores above threshold
if escape_score <= tau:
loss += (tau - escape_score) ** 2
return loss / len(escape_scores)
def _compute_margin_loss(self, persist_scores: List[float],
escape_scores: List[float]) -> float:
"""Enlarge gap between persist vs escape near boundaries."""
if not persist_scores or not escape_scores:
return 0.0
margin = 0.5
loss = 0.0
count = 0
for p_score in persist_scores:
for e_score in escape_scores:
# Want persist_score + margin < escape_score
gap = e_score - p_score
if gap < margin:
loss += (margin - gap) ** 2
count += 1
return loss / max(count, 1)
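    # The margin term is an all-pairs squared hinge, so its cost scales with
    # len(persist_scores) * len(escape_scores); cheap at these batch sizes.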
def _compute_renorm_loss(self, pre_renorm: Optional[List[List[MemoryItem]]],
post_renorm: Optional[List[List[MemoryItem]]]) -> float:
"""Preserve retrieval quality after renormalization."""
if not pre_renorm or not post_renorm:
return 0.0
loss = 0.0
valid_comparisons = 0
for pre_items, post_items in zip(pre_renorm, post_renorm):
if not pre_items or not post_items:
continue
            # Compare the top-ranked items before and after renormalization
            pre_top, post_top = pre_items[0], post_items[0]
            # Cosine similarity between the two top results
            similarity = F.cosine_similarity(
                pre_top.vector.unsqueeze(0),
                post_top.vector.unsqueeze(0)
            )
            loss += 1.0 - similarity.item()
            valid_comparisons += 1
return loss / max(valid_comparisons, 1)
class MandelMemTrainer:
"""Training orchestrator for MandelMem system."""
def __init__(self, quadtree: QuadTree, dynamics: IterativeDynamics,
retriever: MultiScaleRetriever, loss_fn: MandelMemLoss):
self.quadtree = quadtree
self.dynamics = dynamics
self.retriever = retriever
self.loss_fn = loss_fn
# Training state
self.training_history: List[Dict[str, Any]] = []
self.current_epoch = 0
def train_epoch(self, train_batches: List[TrainingBatch],
learning_rate: float = 1e-4) -> Dict[str, float]:
"""Train for one epoch."""
        # Set up optimizer (rebuilt each epoch, so Adam's moment estimates reset)
all_params = list(self.dynamics.parameters())
if hasattr(self.retriever.encoder, 'parameters'):
all_params.extend(self.retriever.encoder.parameters())
optimizer = torch.optim.Adam(all_params, lr=learning_rate)
epoch_losses = []
for batch in train_batches:
optimizer.zero_grad()
# Forward pass
model_outputs = self._forward_pass(batch)
# Compute losses
losses = self.loss_fn(batch, model_outputs)
            # Backward pass; the heuristic losses may carry no gradient graph
            # (e.g. when every score arrives as a plain float), so only
            # propagate when gradients are actually available
            if losses.total_loss.requires_grad:
                losses.total_loss.backward()
                # Gradient clipping
                torch.nn.utils.clip_grad_norm_(all_params, max_norm=1.0)
                optimizer.step()
epoch_losses.append({
'total': losses.total_loss.item(),
'retrieval': losses.retrieval_loss,
'stability': losses.stability_loss,
'escape': losses.escape_loss,
'margin': losses.margin_loss,
'renorm': losses.renorm_loss
})
# Aggregate epoch results
epoch_stats = self._aggregate_losses(epoch_losses)
self.training_history.append(epoch_stats)
self.current_epoch += 1
return epoch_stats
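    # Typical driver loop (a sketch; the QuadTree / IterativeDynamics /
    # MultiScaleRetriever constructors are assumed from the sibling modules):
    #
    #     trainer = MandelMemTrainer(quadtree, dynamics, retriever, MandelMemLoss())
    #     generator = SyntheticDataGenerator()
    #     for _ in range(num_epochs):
    #         batches = [generator.generate_batch() for _ in range(10)]
    #         stats = trainer.train_epoch(batches)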
def _forward_pass(self, batch: TrainingBatch) -> Dict[str, Any]:
"""Forward pass through the model."""
outputs = {
'retrievals': [],
'stability_scores': [],
'escape_scores': [],
'persist_scores': [],
'pre_renorm_retrievals': None,
'post_renorm_retrievals': None
}
        # Process each query/target pair (importance labels are consumed by
        # the loss via batch.importance_labels, not here)
        for query, target, _importance, meta in zip(
            batch.queries, batch.targets, batch.importance_labels, batch.metadata
        ):
# Encode and write target to memory
encoding = self.retriever.encoder.encode(target, meta)
# Route to tile and perform iterative write
path = self.quadtree.route_to_leaf(encoding.complex_coord)
leaf_tile = self.quadtree.tiles[path[-1]]
# Get iteration result
iter_result = self.dynamics.iterate_write(
leaf_tile, encoding.vector, encoding.complex_coord, meta
)
            # Store scores: stability is a heuristic inverse of the write
            # potential; only noise items feed escape_scores (below), matching
            # the contract of _compute_escape_loss
            outputs['stability_scores'].append(1.0 - iter_result.max_potential / 3.0)
            if iter_result.persist:
                outputs['persist_scores'].append(iter_result.max_potential)
# Retrieve for query
retrieval_result = self.retriever.retrieve(query, k=5)
outputs['retrievals'].append(retrieval_result.items)
# Process noise items
for noise_item in batch.noise_items:
noise_encoding = self.retriever.encoder.encode(
noise_item, {'importance': 0.1, 'source': 'noise'}
)
path = self.quadtree.route_to_leaf(noise_encoding.complex_coord)
leaf_tile = self.quadtree.tiles[path[-1]]
iter_result = self.dynamics.iterate_write(
leaf_tile, noise_encoding.vector, noise_encoding.complex_coord,
{'importance': 0.1, 'source': 'noise'}
)
outputs['escape_scores'].append(iter_result.max_potential)
return outputs
def _aggregate_losses(self, epoch_losses: List[Dict[str, float]]) -> Dict[str, float]:
"""Aggregate losses across batches."""
if not epoch_losses:
return {}
aggregated = {}
for key in epoch_losses[0].keys():
values = [loss[key] for loss in epoch_losses if key in loss]
aggregated[f'avg_{key}'] = np.mean(values)
aggregated[f'std_{key}'] = np.std(values)
aggregated['epoch'] = self.current_epoch
return aggregated
def evaluate(self, test_batches: List[TrainingBatch]) -> Dict[str, float]:
"""Evaluate model on test data."""
with torch.no_grad():
test_losses = []
for batch in test_batches:
model_outputs = self._forward_pass(batch)
losses = self.loss_fn(batch, model_outputs)
test_losses.append({
'total': losses.total_loss.item(),
'retrieval': losses.retrieval_loss,
'stability': losses.stability_loss,
'escape': losses.escape_loss,
'margin': losses.margin_loss,
'renorm': losses.renorm_loss
})
return self._aggregate_losses(test_losses)
def get_training_stats(self) -> Dict[str, Any]:
"""Get training statistics."""
if not self.training_history:
return {'message': 'No training history available'}
latest = self.training_history[-1]
# Compute trends
if len(self.training_history) >= 2:
prev = self.training_history[-2]
trends = {
f'{key}_trend': latest[key] - prev[key]
for key in latest.keys()
if key.startswith('avg_') and key in prev
}
else:
trends = {}
return {
'current_epoch': self.current_epoch,
'latest_losses': latest,
'trends': trends,
'total_epochs_trained': len(self.training_history)
}
class SyntheticDataGenerator:
"""Generate synthetic training data for MandelMem."""
def __init__(self, vocab_size: int = 1000):
self.vocab_size = vocab_size
self.word_list = self._generate_vocabulary()
def _generate_vocabulary(self) -> List[str]:
"""Generate synthetic vocabulary."""
categories = {
'objects': ['car', 'house', 'tree', 'book', 'phone', 'computer', 'chair', 'table'],
'actions': ['run', 'jump', 'read', 'write', 'think', 'learn', 'create', 'build'],
'adjectives': ['big', 'small', 'red', 'blue', 'fast', 'slow', 'bright', 'dark'],
'concepts': ['memory', 'knowledge', 'wisdom', 'truth', 'beauty', 'justice', 'freedom']
}
        vocab = []
        for words in categories.values():
            vocab.extend(words)
# Add numbers and common words
vocab.extend([str(i) for i in range(100)])
vocab.extend(['the', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'from'])
return vocab[:self.vocab_size]
def generate_batch(self, batch_size: int = 32) -> TrainingBatch:
"""Generate a training batch."""
queries = []
targets = []
importance_labels = []
noise_items = []
metadata = []
for _ in range(batch_size):
# Generate query-target pair
target_words = np.random.choice(self.word_list, size=np.random.randint(3, 8))
target = ' '.join(target_words)
# Query is partial target with some noise
query_words = target_words[:np.random.randint(1, len(target_words))]
if np.random.random() < 0.3: # Add noise word
noise_word = np.random.choice(self.word_list)
query_words = np.append(query_words, noise_word)
query = ' '.join(query_words)
# Random importance
importance = np.random.beta(2, 2) # Bias toward middle values
# Metadata
meta = {
'importance': importance,
'source': np.random.choice(['user', 'system', 'external']),
'pii': np.random.random() < 0.1,
'repeats': np.random.poisson(1),
'recency_weight': np.random.exponential(0.5)
}
queries.append(query)
targets.append(target)
importance_labels.append(importance)
metadata.append(meta)
# Generate noise items
for _ in range(batch_size // 4):
noise_words = np.random.choice(self.word_list, size=np.random.randint(2, 6))
noise_items.append(' '.join(noise_words))
return TrainingBatch(
queries=queries,
targets=targets,
importance_labels=importance_labels,
noise_items=noise_items,
metadata=metadata
)
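if __name__ == "__main__":
    # Smoke-test sketch: exercises the synthetic generator and the loss with
    # stand-in model outputs (no QuadTree/retriever needed). The score values
    # are illustrative placeholders, not outputs of a trained MandelMem
    # system. Run as `python -m mandelmem.training` so the relative imports
    # at the top of the file resolve.
    generator = SyntheticDataGenerator(vocab_size=200)
    batch = generator.generate_batch(batch_size=8)
    loss_fn = MandelMemLoss()
    stub_outputs = {
        'retrievals': [[] for _ in batch.queries],       # nothing retrieved
        'stability_scores': [0.8] * len(batch.queries),  # placeholder scores
        'escape_scores': [2.5] * len(batch.noise_items),
        'persist_scores': [0.4] * len(batch.queries),
        'pre_renorm_retrievals': None,
        'post_renorm_retrievals': None,
    }
    losses = loss_fn(batch, stub_outputs)
    print(f"total loss on stub outputs: {losses.total_loss.item():.4f}")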