Spaces:
Sleeping
Sleeping
| """ | |
| Iterative dynamics for MandelMem system. | |
| Implements Mandelbrot-like iteration for persistence decisions. | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import numpy as np | |
| from typing import Dict, Any, Tuple, Optional | |
| from dataclasses import dataclass | |
| from .quadtree import MemoryItem, Tile | |
@dataclass
class IterationResult:
    """Result of iterative dynamics.

    Declared as a dataclass so it can be built with keyword arguments
    (``IterationResult(persist=..., max_potential=..., ...)``), which is
    how the write loop constructs it.
    """

    # Whether the item should be persisted.
    persist: bool
    # Largest potential observed over the iterations that were run.
    max_potential: float
    # Classification of the outcome: 'stable', 'plastic' or 'escape'.
    band: str
    # Stacked sequence of states visited, starting with the initial state.
    trajectory: torch.Tensor
    # State after the last iteration that was executed.
    final_state: torch.Tensor
class IterativeDynamics(nn.Module):
    """Implements the iterative map F_θ(z_k, [v, u, meta]) for persistence decisions.

    The module owns two small networks:

    * ``map_network`` maps the concatenation ``[z, v, u, meta]`` to the next
      state ``z``.
    * ``potential_network`` maps a state to a positive scalar "escape
      potential" (Softplus output) that is compared against a threshold.
    """

    def __init__(self, embedding_dim: int = 768, meta_dim: int = 8):
        """Build the map and potential networks.

        Args:
            embedding_dim: Size of the state ``z`` and of the item vector ``v``.
            meta_dim: Size of the encoded metadata vector fed to the map.
        """
        super().__init__()
        self.embedding_dim = embedding_dim
        self.meta_dim = meta_dim

        # Input dimensions: z + v + u (2D) + meta
        input_dim = embedding_dim + embedding_dim + 2 + meta_dim

        # Iterative map network
        self.map_network = nn.Sequential(
            nn.Linear(input_dim, embedding_dim * 2),
            nn.Tanh(),
            nn.Linear(embedding_dim * 2, embedding_dim),
            nn.Tanh()
        )

        # Potential function for escape detection
        self.potential_network = nn.Sequential(
            nn.Linear(embedding_dim, embedding_dim // 2),
            nn.ReLU(),
            nn.Linear(embedding_dim // 2, 1),
            nn.Softplus()  # Ensure positive potential
        )

    def encode_metadata(self, meta: Dict[str, Any]) -> torch.Tensor:
        """Encode metadata to a vector of exactly ``meta_dim`` entries.

        Feature order: importance, one-hot source (user / system / external /
        generated), PII flag, repetition count scaled to [0, 1], recency
        weight. The list is truncated when ``meta_dim`` is smaller than the
        number of features and zero-padded when it is larger, so the result
        always matches the input width the map network was built for.

        Args:
            meta: Metadata dictionary; missing keys fall back to defaults.

        Returns:
            A 1-D float32 tensor of length ``meta_dim``.
        """
        source_types = ['user', 'system', 'external', 'generated']
        source = meta.get('source', 'user')

        features = [meta.get('importance', 0.5)]
        # Source (one-hot); an unknown source encodes as all zeros.
        features.extend(1.0 if s == source else 0.0 for s in source_types)
        # PII flag
        features.append(1.0 if meta.get('pii', False) else 0.0)
        # Repetition count, saturating at 10 repeats
        features.append(min(meta.get('repeats', 0) / 10.0, 1.0))
        # Recency (normalized)
        features.append(meta.get('recency_weight', 0.5))

        # Fit to meta_dim: truncate extras, zero-pad any shortfall.
        features = features[:self.meta_dim]
        features.extend([0.0] * (self.meta_dim - len(features)))
        return torch.tensor(features, dtype=torch.float32)

    def iterate_step(self, z: torch.Tensor, v: torch.Tensor, u: complex,
                     meta: Dict[str, Any],
                     tile_params: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Single iteration step of the map.

        Args:
            z: Current state (1-D, length ``embedding_dim``).
            v: Item vector (1-D, length ``embedding_dim``).
            u: Complex coordinate; its real and imaginary parts are fed to
                the map as two features.
            meta: Metadata dictionary, encoded via ``encode_metadata``.
            tile_params: Optional tile-specific parameters. A 1-D tensor is
                used as an additive bias; a higher-rank tensor is averaged
                over its first dimension first. The bias is applied only
                when its shape matches the new state exactly.

        Returns:
            The next state.
        """
        # Build the auxiliary inputs on the same device as the state so the
        # concatenation also works when the module lives off-CPU.
        u_vec = torch.tensor([u.real, u.imag], dtype=torch.float32, device=z.device)
        meta_vec = self.encode_metadata(meta).to(z.device)

        # Concatenate all inputs and apply the map
        z_next = self.map_network(torch.cat([z, v, u_vec, meta_vec]))

        # Apply tile-specific parameters if available (simplified)
        if tile_params is not None and tile_params.numel() > 0:
            # Simple additive bias instead of matrix multiplication
            bias = torch.mean(tile_params, dim=0) if tile_params.dim() > 1 else tile_params
            # Comparing full shapes skips scalar (0-dim) and mis-shaped
            # parameters instead of raising or broadcasting the state.
            if bias.shape == z_next.shape:
                z_next = z_next + bias

        return z_next

    def compute_potential(self, z: torch.Tensor) -> float:
        """Compute escape potential P(z) as a Python float (no gradient)."""
        with torch.no_grad():
            potential = self.potential_network(z)
        return potential.item()

    def iterate_write(self, tile: "Tile", v: torch.Tensor, u: complex,
                      meta: Dict[str, Any], K: int = 8, tau: float = 2.0,
                      delta: float = 0.3) -> "IterationResult":
        """Full iterative write process as described in the blueprint.

        Starting from a copy of ``tile.attractor``, applies ``iterate_step``
        up to ``K`` times, tracking the maximum potential. Iteration stops
        early with band 'escape' as soon as a potential exceeds
        ``effective_tau + delta``.

        Args:
            tile: Tile providing ``attractor`` and ``local_params``.
            v: Item vector.
            u: Complex coordinate of the item.
            meta: Item metadata; also used to adjust the threshold.
            K: Maximum number of iterations.
            tau: Base threshold before metadata adjustment.
            delta: Half-width of the plastic band around the threshold.

        Returns:
            An ``IterationResult`` with the band, persistence decision,
            maximum potential, stacked trajectory and final state.
        """
        # Initialize with tile attractor
        z = tile.attractor.clone()
        trajectory = [z.clone()]
        max_potential = 0.0

        # Apply importance, repetition and recency adjustments to threshold
        effective_tau = self._adjust_threshold(tau, meta)

        for _ in range(K):
            # Iteration step
            z = self.iterate_step(z, v, u, meta, tile.local_params)
            trajectory.append(z.clone())

            # Compute potential
            potential = self.compute_potential(z)
            max_potential = max(max_potential, potential)

            # Early escape detection
            if potential > effective_tau + delta:
                return IterationResult(
                    persist=False,
                    max_potential=max_potential,
                    band='escape',
                    trajectory=torch.stack(trajectory),
                    final_state=z
                )

        # Determine final band based on the maximum potential seen
        if max_potential <= effective_tau - delta:
            # Stable - commit to slots
            band = 'stable'
            persist = True
        elif max_potential <= effective_tau + delta:
            # Plastic boundary band
            band = 'plastic'
            persist = True
        else:
            # Escaped
            band = 'escape'
            persist = False

        return IterationResult(
            persist=persist,
            max_potential=max_potential,
            band=band,
            trajectory=torch.stack(trajectory),
            final_state=z
        )

    def _adjust_threshold(self, base_tau: float, meta: Dict[str, Any]) -> float:
        """Adjust threshold based on importance, repetition, and recency.

        Each adjustment is subtracted from ``base_tau``; the result is never
        allowed to drop below 0.5.
        """
        tau = base_tau

        # Importance above 0.5 lowers the threshold, below 0.5 raises it
        importance = meta.get('importance', 0.5)
        tau -= (importance - 0.5) * 0.5

        # Repeated items lower the threshold, capped at 0.3
        repeats = meta.get('repeats', 0)
        tau -= min(repeats * 0.1, 0.3)

        # Recency weight above 0.5 lowers the threshold, below 0.5 raises it
        recency_weight = meta.get('recency_weight', 0.5)
        tau -= (recency_weight - 0.5) * 0.2

        return max(tau, 0.5)  # Minimum threshold

    def compute_stability_margin(self, z1: torch.Tensor, z2: torch.Tensor) -> float:
        """Return the absolute difference between the potentials of two states."""
        p1 = self.compute_potential(z1)
        p2 = self.compute_potential(z2)
        return abs(p1 - p2)
class BasinAnalyzer:
    """Analyzes basins of attraction and escape regions.

    Both analyses probe a tile by running the dynamics' ``iterate_write``
    at many coordinates with a random item vector and neutral test
    metadata, then tallying the resulting bands.
    """

    def __init__(self, dynamics: "IterativeDynamics"):
        """Store the dynamics object used for every probe."""
        self.dynamics = dynamics

    def _probe_band(self, tile: "Tile", u: complex) -> str:
        """Run one write at coordinate ``u`` and return the resulting band."""
        # Create dummy memory item for testing
        v = torch.randn(self.dynamics.embedding_dim)
        meta = {'importance': 0.5, 'source': 'test'}
        # Only the band label is used, so skip autograd bookkeeping.
        with torch.no_grad():
            return self.dynamics.iterate_write(tile, v, u, meta).band

    def sample_basin(self, tile: "Tile", n_samples: int = 1000) -> Dict[str, Any]:
        """Sample random points in the tile to estimate its basin structure.

        Args:
            tile: Tile whose ``bounds`` give the (bottom-left, top-right)
                complex corners of the sampled rectangle.
            n_samples: Number of uniformly drawn sample points.

        Returns:
            Dict with 'stable_ratio', 'plastic_ratio', 'escape_ratio' and
            'total_samples'. With ``n_samples == 0`` every ratio is 0.0.

        Raises:
            ValueError: If ``n_samples`` is negative.
        """
        if n_samples < 0:
            raise ValueError("n_samples must be non-negative")
        if n_samples == 0:
            # Nothing to sample; avoid dividing by zero below.
            return {
                'stable_ratio': 0.0,
                'plastic_ratio': 0.0,
                'escape_ratio': 0.0,
                'total_samples': 0
            }

        bl, tr = tile.bounds

        # Generate random points in tile
        real_coords = np.random.uniform(bl.real, tr.real, n_samples)
        imag_coords = np.random.uniform(bl.imag, tr.imag, n_samples)

        counts = {'stable': 0, 'plastic': 0, 'escape': 0}
        for re_val, im_val in zip(real_coords, imag_coords):
            band = self._probe_band(tile, complex(re_val, im_val))
            # Any band other than stable/plastic is tallied as escape.
            counts[band if band in ('stable', 'plastic') else 'escape'] += 1

        return {
            'stable_ratio': counts['stable'] / n_samples,
            'plastic_ratio': counts['plastic'] / n_samples,
            'escape_ratio': counts['escape'] / n_samples,
            'total_samples': n_samples
        }

    def find_basin_boundary(self, tile: "Tile", resolution: int = 50) -> np.ndarray:
        """Find approximate basin boundary in tile.

        Probes a regular ``resolution`` x ``resolution`` grid over the tile
        bounds.

        Returns:
            Array indexed ``[real_index, imag_index]`` holding 1.0 for
            stable cells, 0.5 for plastic cells and 0.0 otherwise.
        """
        bl, tr = tile.bounds
        real_range = np.linspace(bl.real, tr.real, resolution)
        imag_range = np.linspace(bl.imag, tr.imag, resolution)

        band_values = {'stable': 1.0, 'plastic': 0.5}
        boundary_map = np.zeros((resolution, resolution))
        for i, real_val in enumerate(real_range):
            for j, imag_val in enumerate(imag_range):
                band = self._probe_band(tile, complex(real_val, imag_val))
                boundary_map[i, j] = band_values.get(band, 0.0)

        return boundary_map
class AdaptiveThreshold:
    """Manages adaptive threshold adjustment per tile.

    Each tile starts at ``base_tau``; every update nudges its threshold by
    ``adaptation_rate`` towards a target persistence rate, keeping it
    within [MIN_TAU, MAX_TAU].
    """

    # Bounds applied to every stored threshold.
    MIN_TAU = 0.5
    MAX_TAU = 5.0

    def __init__(self, base_tau: float = 2.0, adaptation_rate: float = 0.01):
        """Set the default threshold and the per-update step size."""
        self.base_tau = base_tau
        self.adaptation_rate = adaptation_rate
        self.tile_thresholds: Dict[str, float] = {}
        self.tile_stats: Dict[str, Dict[str, float]] = {}

    def get_threshold(self, tile_id: str) -> float:
        """Get current threshold for tile (``base_tau`` if never updated)."""
        return self.tile_thresholds.get(tile_id, self.base_tau)

    def update_threshold(self, tile_id: str, persist_rate: float, target_rate: float = 0.7):
        """Update threshold based on persistence rate.

        Args:
            tile_id: Identifier of the tile to update.
            persist_rate: Observed persistence rate for the tile.
            target_rate: Desired persistence rate.
        """
        current_tau = self.get_threshold(tile_id)

        # Adjust threshold to achieve target persistence rate
        if persist_rate > target_rate:
            # Too many items persisting, raise threshold
            step = self.adaptation_rate
        else:
            # Persistence at or below target, lower threshold
            step = -self.adaptation_rate

        new_tau = max(self.MIN_TAU, min(self.MAX_TAU, current_tau + step))
        self.tile_thresholds[tile_id] = new_tau

        # Record the clamped value so the stats always agree with
        # what get_threshold() reports.
        stats = self.tile_stats.setdefault(tile_id, {})
        stats['persist_rate'] = persist_rate
        stats['threshold'] = new_tau

    def get_tile_stats(self, tile_id: str) -> Dict[str, float]:
        """Get statistics for tile (empty dict if the tile was never updated)."""
        return self.tile_stats.get(tile_id, {})