"""
Iterative dynamics for MandelMem system.
Implements Mandelbrot-like iteration for persistence decisions.
"""
import torch
import torch.nn as nn
import numpy as np
from typing import Dict, Any, Tuple, Optional
from dataclasses import dataclass
from .quadtree import MemoryItem, Tile


@dataclass
class IterationResult:
    """Result of iterative dynamics."""
    persist: bool
    max_potential: float
    band: str  # 'stable', 'plastic', 'escape'
    trajectory: torch.Tensor
    final_state: torch.Tensor


class IterativeDynamics(nn.Module):
    """Implements the iterative map F_θ(z_k, [v, u, meta]) for persistence decisions."""

    def __init__(self, embedding_dim: int = 768, meta_dim: int = 8):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.meta_dim = meta_dim

        # Input dimensions: z + v + u (2D) + meta
        input_dim = embedding_dim + embedding_dim + 2 + meta_dim

        # Iterative map network
        self.map_network = nn.Sequential(
            nn.Linear(input_dim, embedding_dim * 2),
            nn.Tanh(),
            nn.Linear(embedding_dim * 2, embedding_dim),
            nn.Tanh()
        )

        # Potential function for escape detection
        self.potential_network = nn.Sequential(
            nn.Linear(embedding_dim, embedding_dim // 2),
            nn.ReLU(),
            nn.Linear(embedding_dim // 2, 1),
            nn.Softplus()  # Ensure positive potential
        )
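
        # Shape summary (restating the construction above): map_network maps
        # R^(2*embedding_dim + 2 + meta_dim) -> R^embedding_dim, and
        # potential_network maps R^embedding_dim -> a single positive scalar.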

    def encode_metadata(self, meta: Dict[str, Any]) -> torch.Tensor:
        """Encode metadata to fixed-size vector."""
        features = []

        # Importance
        features.append(meta.get('importance', 0.5))

        # Source (one-hot)
        source_types = ['user', 'system', 'external', 'generated']
        source = meta.get('source', 'user')
        for s in source_types:
            features.append(1.0 if s == source else 0.0)

        # PII flag
        features.append(1.0 if meta.get('pii', False) else 0.0)

        # Repetition count
        features.append(min(meta.get('repeats', 0) / 10.0, 1.0))

        # Recency (normalized)
        features.append(meta.get('recency_weight', 0.5))

        return torch.tensor(features[:self.meta_dim], dtype=torch.float32)
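
        # Worked example (sketch): meta = {'importance': 0.9, 'source': 'system', 'pii': True}
        # encodes, with the default meta_dim of 8, to
        #   [0.9, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.5]
        # (importance, 4-way source one-hot, PII flag, repeats, recency).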

    def iterate_step(self, z: torch.Tensor, v: torch.Tensor, u: complex,
                     meta: Dict[str, Any], tile_params: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Single iteration step of the map."""
        # Prepare input
        u_vec = torch.tensor([u.real, u.imag], dtype=torch.float32)
        meta_vec = self.encode_metadata(meta)

        # Concatenate all inputs
        input_vec = torch.cat([z, v, u_vec, meta_vec])

        # Apply map
        z_next = self.map_network(input_vec)

        # Apply tile-specific parameters if available (simplified)
        if tile_params is not None and tile_params.numel() > 0:
            # Simple additive bias instead of matrix multiplication
            bias = torch.mean(tile_params, dim=0) if tile_params.dim() > 1 else tile_params
            if bias.size(0) == z_next.size(0):
                z_next = z_next + bias

        return z_next

    def compute_potential(self, z: torch.Tensor) -> float:
        """Compute escape potential P(z)."""
        with torch.no_grad():
            potential = self.potential_network(z)
        return potential.item()

    def iterate_write(self, tile: Tile, v: torch.Tensor, u: complex,
                      meta: Dict[str, Any], K: int = 8, tau: float = 2.0,
                      delta: float = 0.3) -> IterationResult:
        """Full iterative write process as described in the blueprint."""
        # Initialize with tile attractor
        z = tile.attractor.clone()
        trajectory = [z.clone()]
        max_potential = 0.0

        # Adjust the escape threshold using importance, repetition, and recency
        effective_tau = self._adjust_threshold(tau, meta)

        for k in range(K):
            # Iteration step
            z = self.iterate_step(z, v, u, meta, tile.local_params)
            trajectory.append(z.clone())

            # Compute potential
            potential = self.compute_potential(z)
            max_potential = max(max_potential, potential)

            # Early escape detection
            if potential > effective_tau + delta:
                return IterationResult(
                    persist=False,
                    max_potential=max_potential,
                    band='escape',
                    trajectory=torch.stack(trajectory),
                    final_state=z
                )

        # Determine final state based on potential
        if max_potential <= effective_tau - delta:
            # Stable - commit to slots
            band = 'stable'
            persist = True
        elif max_potential <= effective_tau + delta:
            # Plastic boundary band
            band = 'plastic'
            persist = True
        else:
            # Escaped
            band = 'escape'
            persist = False

        return IterationResult(
            persist=persist,
            max_potential=max_potential,
            band=band,
            trajectory=torch.stack(trajectory),
            final_state=z
        )
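
    # Band classification used by iterate_write (summary of the logic above):
    # with P = max potential over the K steps, tau' = adjusted threshold and
    # half-width delta,
    #   P <= tau' - delta                 -> 'stable'  (persist)
    #   tau' - delta < P <= tau' + delta  -> 'plastic' (persist, boundary band)
    #   P >  tau' + delta                 -> 'escape'  (do not persist)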

    def _adjust_threshold(self, base_tau: float, meta: Dict[str, Any]) -> float:
        """Adjust the escape threshold based on importance, repetition, and recency."""
        tau = base_tau

        # Raise threshold for important items (harder to escape, easier to persist)
        importance = meta.get('importance', 0.5)
        tau += (importance - 0.5) * 0.5

        # Raise threshold for repeated items
        repeats = meta.get('repeats', 0)
        tau += min(repeats * 0.1, 0.3)

        # Raise threshold for recent items
        recency_weight = meta.get('recency_weight', 0.5)
        tau += (recency_weight - 0.5) * 0.2

        return max(tau, 0.5)  # Minimum threshold
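
    # Example of the adjustment (sketch): with base_tau=2.0, importance=0.9,
    # repeats=5 and recency_weight=0.8 the effective threshold is
    # 2.0 + 0.2 + 0.3 + 0.06 = 2.56, i.e. such an item escapes less readily.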

    def compute_stability_margin(self, z1: torch.Tensor, z2: torch.Tensor) -> float:
        """Compute margin between stable and unstable trajectories."""
        p1 = self.compute_potential(z1)
        p2 = self.compute_potential(z2)
        return abs(p1 - p2)


class BasinAnalyzer:
    """Analyzes basins of attraction and escape regions."""

    def __init__(self, dynamics: IterativeDynamics):
        self.dynamics = dynamics

    def sample_basin(self, tile: Tile, n_samples: int = 1000) -> Dict[str, Any]:
        """Sample points in tile to analyze basin structure."""
        bl, tr = tile.bounds

        # Generate random points in tile
        real_coords = np.random.uniform(bl.real, tr.real, n_samples)
        imag_coords = np.random.uniform(bl.imag, tr.imag, n_samples)

        stable_count = 0
        plastic_count = 0
        escape_count = 0

        for i in range(n_samples):
            u = complex(real_coords[i], imag_coords[i])

            # Create dummy memory item for testing
            v = torch.randn(self.dynamics.embedding_dim)
            meta = {'importance': 0.5, 'source': 'test'}

            result = self.dynamics.iterate_write(tile, v, u, meta)

            if result.band == 'stable':
                stable_count += 1
            elif result.band == 'plastic':
                plastic_count += 1
            else:
                escape_count += 1

        return {
            'stable_ratio': stable_count / n_samples,
            'plastic_ratio': plastic_count / n_samples,
            'escape_ratio': escape_count / n_samples,
            'total_samples': n_samples
        }

    def find_basin_boundary(self, tile: Tile, resolution: int = 50) -> np.ndarray:
        """Find approximate basin boundary in tile."""
        bl, tr = tile.bounds
        real_range = np.linspace(bl.real, tr.real, resolution)
        imag_range = np.linspace(bl.imag, tr.imag, resolution)

        boundary_map = np.zeros((resolution, resolution))

        for i, real_val in enumerate(real_range):
            for j, imag_val in enumerate(imag_range):
                u = complex(real_val, imag_val)
                v = torch.randn(self.dynamics.embedding_dim)
                meta = {'importance': 0.5, 'source': 'test'}

                result = self.dynamics.iterate_write(tile, v, u, meta)

                if result.band == 'stable':
                    boundary_map[i, j] = 1.0
                elif result.band == 'plastic':
                    boundary_map[i, j] = 0.5
                else:
                    boundary_map[i, j] = 0.0

        return boundary_map
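
    # The map encodes 'stable' as 1.0, 'plastic' as 0.5 and 'escape' as 0.0;
    # if a picture is wanted it can be rendered directly, e.g. with
    # matplotlib.pyplot.imshow(boundary_map, origin='lower').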


class AdaptiveThreshold:
    """Manages adaptive threshold adjustment per tile."""

    def __init__(self, base_tau: float = 2.0, adaptation_rate: float = 0.01):
        self.base_tau = base_tau
        self.adaptation_rate = adaptation_rate
        self.tile_thresholds: Dict[str, float] = {}
        self.tile_stats: Dict[str, Dict[str, float]] = {}

    def get_threshold(self, tile_id: str) -> float:
        """Get current threshold for tile."""
        return self.tile_thresholds.get(tile_id, self.base_tau)

    def update_threshold(self, tile_id: str, persist_rate: float, target_rate: float = 0.7):
        """Update threshold based on the observed persistence rate."""
        current_tau = self.get_threshold(tile_id)

        # Adjust threshold toward the target persistence rate. Items persist
        # while their potential stays below the threshold, so a lower threshold
        # lets fewer items persist and a higher threshold lets more persist.
        if persist_rate > target_rate:
            # Too many items persisting, lower threshold
            new_tau = current_tau - self.adaptation_rate
        else:
            # Too few items persisting, raise threshold
            new_tau = current_tau + self.adaptation_rate

        new_tau = max(0.5, min(5.0, new_tau))
        self.tile_thresholds[tile_id] = new_tau

        # Update stats
        if tile_id not in self.tile_stats:
            self.tile_stats[tile_id] = {}
        self.tile_stats[tile_id]['persist_rate'] = persist_rate
        self.tile_stats[tile_id]['threshold'] = new_tau

    def get_tile_stats(self, tile_id: str) -> Dict[str, float]:
        """Get statistics for tile."""
        return self.tile_stats.get(tile_id, {})
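

# ----------------------------------------------------------------------------
# Minimal usage sketch (an illustration, not part of the MandelMem API).
# The real Tile comes from .quadtree; here SimpleNamespace stands in for it,
# exposing only the attributes this module reads (attractor, local_params,
# bounds). Because of the relative import above, run this in package context,
# e.g. `python -m <package>.iterative_dynamics`.
# ----------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    dynamics = IterativeDynamics(embedding_dim=64, meta_dim=8)
    fake_tile = SimpleNamespace(
        attractor=torch.zeros(64),              # starting point of the iteration
        local_params=None,                      # no tile-specific bias
        bounds=(complex(-1.0, -1.0), complex(1.0, 1.0)),
    )
    meta = {'importance': 0.8, 'source': 'user', 'repeats': 2, 'recency_weight': 0.7}

    result = dynamics.iterate_write(fake_tile, torch.randn(64), complex(0.1, -0.2), meta)
    print(f"persist={result.persist} band={result.band} "
          f"max_potential={result.max_potential:.3f}")

    thresholds = AdaptiveThreshold(base_tau=2.0)
    thresholds.update_threshold('tile_0', persist_rate=0.9)   # above the 0.7 target
    print(f"tile_0 threshold -> {thresholds.get_threshold('tile_0'):.3f}")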