"""
Iterative dynamics for MandelMem system.
Implements Mandelbrot-like iteration for persistence decisions.
"""
import torch
import torch.nn as nn
import numpy as np
from typing import Dict, Any, Tuple, Optional
from dataclasses import dataclass
from .quadtree import Tile
@dataclass
class IterationResult:
"""Result of iterative dynamics."""
persist: bool
max_potential: float
band: str # 'stable', 'plastic', 'escape'
trajectory: torch.Tensor
final_state: torch.Tensor
class IterativeDynamics(nn.Module):
"""Implements the iterative map F_θ(z_k, [v, u, meta]) for persistence decisions."""
def __init__(self, embedding_dim: int = 768, meta_dim: int = 8):
super().__init__()
self.embedding_dim = embedding_dim
self.meta_dim = meta_dim
# Input dimensions: z + v + u (2D) + meta
input_dim = embedding_dim + embedding_dim + 2 + meta_dim
# Iterative map network
self.map_network = nn.Sequential(
nn.Linear(input_dim, embedding_dim * 2),
nn.Tanh(),
nn.Linear(embedding_dim * 2, embedding_dim),
nn.Tanh()
)
# Potential function for escape detection
self.potential_network = nn.Sequential(
nn.Linear(embedding_dim, embedding_dim // 2),
nn.ReLU(),
nn.Linear(embedding_dim // 2, 1),
nn.Softplus() # Ensure positive potential
)
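        # Note: the learned potential P(z) plays the role of |z| in the classical
        # Mandelbrot escape test; trajectories whose potential exceeds the
        # adjusted threshold tau (plus the band width delta) are treated as
        # escaping and are not persisted.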
def encode_metadata(self, meta: Dict[str, Any]) -> torch.Tensor:
"""Encode metadata to fixed-size vector."""
features = []
# Importance
features.append(meta.get('importance', 0.5))
# Source (one-hot)
source_types = ['user', 'system', 'external', 'generated']
source = meta.get('source', 'user')
for s in source_types:
features.append(1.0 if s == source else 0.0)
# PII flag
features.append(1.0 if meta.get('pii', False) else 0.0)
# Repetition count
features.append(min(meta.get('repeats', 0) / 10.0, 1.0))
# Recency (normalized)
features.append(meta.get('recency_weight', 0.5))
        # Pad with zeros (or truncate) so the vector always has length meta_dim.
        if len(features) < self.meta_dim:
            features.extend([0.0] * (self.meta_dim - len(features)))
        return torch.tensor(features[:self.meta_dim], dtype=torch.float32)
def iterate_step(self, z: torch.Tensor, v: torch.Tensor, u: complex,
meta: Dict[str, Any], tile_params: Optional[torch.Tensor] = None) -> torch.Tensor:
"""Single iteration step of the map."""
# Prepare input
        u_vec = torch.tensor([u.real, u.imag], dtype=torch.float32, device=z.device)
        meta_vec = self.encode_metadata(meta).to(z.device)
# Concatenate all inputs
input_vec = torch.cat([z, v, u_vec, meta_vec])
# Apply map
z_next = self.map_network(input_vec)
# Apply tile-specific parameters if available (simplified)
if tile_params is not None and tile_params.numel() > 0:
# Simple additive bias instead of matrix multiplication
bias = torch.mean(tile_params, dim=0) if tile_params.dim() > 1 else tile_params
if bias.size(0) == z_next.size(0):
z_next = z_next + bias
return z_next
def compute_potential(self, z: torch.Tensor) -> float:
"""Compute escape potential P(z)."""
with torch.no_grad():
potential = self.potential_network(z)
return potential.item()
def iterate_write(self, tile: Tile, v: torch.Tensor, u: complex,
meta: Dict[str, Any], K: int = 8, tau: float = 2.0,
delta: float = 0.3) -> IterationResult:
"""Full iterative write process as described in the blueprint."""
# Initialize with tile attractor
z = tile.attractor.clone()
trajectory = [z.clone()]
max_potential = 0.0
# Apply repetition and recency adjustments to threshold
effective_tau = self._adjust_threshold(tau, meta)
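        # Band semantics used below: max potential <= tau - delta -> 'stable' (persist),
        # within (tau - delta, tau + delta]                       -> 'plastic' boundary band (persist),
        # above tau + delta                                       -> 'escape' (do not persist).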
for k in range(K):
# Iteration step
z = self.iterate_step(z, v, u, meta, tile.local_params)
trajectory.append(z.clone())
# Compute potential
potential = self.compute_potential(z)
max_potential = max(max_potential, potential)
# Early escape detection
if potential > effective_tau + delta:
return IterationResult(
persist=False,
max_potential=max_potential,
band='escape',
trajectory=torch.stack(trajectory),
final_state=z
)
# Determine final state based on potential
if max_potential <= effective_tau - delta:
# Stable - commit to slots
band = 'stable'
persist = True
elif max_potential <= effective_tau + delta:
# Plastic boundary band
band = 'plastic'
persist = True
else:
# Escaped
band = 'escape'
persist = False
return IterationResult(
persist=persist,
max_potential=max_potential,
band=band,
trajectory=torch.stack(trajectory),
final_state=z
)
def _adjust_threshold(self, base_tau: float, meta: Dict[str, Any]) -> float:
"""Adjust threshold based on importance, repetition, and recency."""
tau = base_tau
# Lower threshold for important items (easier to persist)
importance = meta.get('importance', 0.5)
tau -= (importance - 0.5) * 0.5
# Lower threshold for repeated items
repeats = meta.get('repeats', 0)
tau -= min(repeats * 0.1, 0.3)
# Lower threshold for recent items
recency_weight = meta.get('recency_weight', 0.5)
tau -= (recency_weight - 0.5) * 0.2
return max(tau, 0.5) # Minimum threshold
def compute_stability_margin(self, z1: torch.Tensor, z2: torch.Tensor) -> float:
"""Compute margin between stable and unstable trajectories."""
p1 = self.compute_potential(z1)
p2 = self.compute_potential(z2)
return abs(p1 - p2)
class BasinAnalyzer:
"""Analyzes basins of attraction and escape regions."""
def __init__(self, dynamics: IterativeDynamics):
self.dynamics = dynamics
def sample_basin(self, tile: Tile, n_samples: int = 1000) -> Dict[str, Any]:
"""Sample points in tile to analyze basin structure."""
bl, tr = tile.bounds
# Generate random points in tile
real_coords = np.random.uniform(bl.real, tr.real, n_samples)
imag_coords = np.random.uniform(bl.imag, tr.imag, n_samples)
stable_count = 0
plastic_count = 0
escape_count = 0
for i in range(n_samples):
u = complex(real_coords[i], imag_coords[i])
# Create dummy memory item for testing
v = torch.randn(self.dynamics.embedding_dim)
meta = {'importance': 0.5, 'source': 'test'}
result = self.dynamics.iterate_write(tile, v, u, meta)
if result.band == 'stable':
stable_count += 1
elif result.band == 'plastic':
plastic_count += 1
else:
escape_count += 1
return {
'stable_ratio': stable_count / n_samples,
'plastic_ratio': plastic_count / n_samples,
'escape_ratio': escape_count / n_samples,
'total_samples': n_samples
}
def find_basin_boundary(self, tile: Tile, resolution: int = 50) -> np.ndarray:
"""Find approximate basin boundary in tile."""
bl, tr = tile.bounds
real_range = np.linspace(bl.real, tr.real, resolution)
imag_range = np.linspace(bl.imag, tr.imag, resolution)
boundary_map = np.zeros((resolution, resolution))
for i, real_val in enumerate(real_range):
for j, imag_val in enumerate(imag_range):
u = complex(real_val, imag_val)
v = torch.randn(self.dynamics.embedding_dim)
meta = {'importance': 0.5, 'source': 'test'}
result = self.dynamics.iterate_write(tile, v, u, meta)
if result.band == 'stable':
boundary_map[i, j] = 1.0
elif result.band == 'plastic':
boundary_map[i, j] = 0.5
else:
boundary_map[i, j] = 0.0
return boundary_map
class AdaptiveThreshold:
"""Manages adaptive threshold adjustment per tile."""
def __init__(self, base_tau: float = 2.0, adaptation_rate: float = 0.01):
self.base_tau = base_tau
self.adaptation_rate = adaptation_rate
self.tile_thresholds: Dict[str, float] = {}
self.tile_stats: Dict[str, Dict[str, float]] = {}
def get_threshold(self, tile_id: str) -> float:
"""Get current threshold for tile."""
return self.tile_thresholds.get(tile_id, self.base_tau)
def update_threshold(self, tile_id: str, persist_rate: float, target_rate: float = 0.7):
"""Update threshold based on persistence rate."""
current_tau = self.get_threshold(tile_id)
# Adjust threshold to achieve target persistence rate
if persist_rate > target_rate:
# Too many items persisting, raise threshold
new_tau = current_tau + self.adaptation_rate
else:
# Too few items persisting, lower threshold
new_tau = current_tau - self.adaptation_rate
        new_tau = max(0.5, min(5.0, new_tau))
        self.tile_thresholds[tile_id] = new_tau
        # Update stats (record the clamped threshold that is actually applied)
        if tile_id not in self.tile_stats:
            self.tile_stats[tile_id] = {}
        self.tile_stats[tile_id]['persist_rate'] = persist_rate
        self.tile_stats[tile_id]['threshold'] = new_tau
def get_tile_stats(self, tile_id: str) -> Dict[str, float]:
"""Get statistics for tile."""
return self.tile_stats.get(tile_id, {})
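

if __name__ == "__main__":
    # Minimal smoke test / usage sketch. The inputs below are synthetic
    # placeholders (random embedding, illustrative metadata and parameter
    # coordinate); a real caller would pass a Tile from the quadtree index
    # to iterate_write instead of calling iterate_step directly.
    dynamics = IterativeDynamics(embedding_dim=768, meta_dim=8)
    z0 = torch.zeros(768)      # neutral starting state
    v = torch.randn(768)       # dummy content embedding
    u = complex(-0.5, 0.3)     # illustrative parameter-plane coordinate
    meta = {'importance': 0.8, 'source': 'user', 'repeats': 2, 'recency_weight': 0.9}
    z1 = dynamics.iterate_step(z0, v, u, meta)
    print('potential after one step:', dynamics.compute_potential(z1))

    # Per-tile adaptive thresholding: raise tau when too many writes persist,
    # lower it when too few do.
    thresholds = AdaptiveThreshold(base_tau=2.0, adaptation_rate=0.01)
    thresholds.update_threshold('tile-0', persist_rate=0.9)  # above target -> tau increases
    print('tile-0 threshold:', thresholds.get_threshold('tile-0'))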