Spaces:
Sleeping
Sleeping
| """ | |
| Core MandelMem system - main interface and orchestration. | |
| Brings together all components into a unified memory architecture. | |
| """ | |
| import torch | |
| import numpy as np | |
| import time | |
| from typing import Dict, Any, List, Optional, Union | |
| from dataclasses import dataclass | |
| from .encoders import MultiModalEncoder | |
| from .quadtree import QuadTree, MemoryItem | |
| from .dynamics import IterativeDynamics, AdaptiveThreshold | |
| from .retrieval import MultiScaleRetriever, ExplainableRetriever | |
| from .renormalization import RenormalizationEngine, MemorySummarizer | |
| from .training import MandelMemTrainer, MandelMemLoss | |
@dataclass
class WriteResult:
    """Result of a memory write operation, as returned by ``MandelMem.write``."""
    # Whether the iterative dynamics decided to keep the item.
    persisted: bool
    # ID of the leaf tile the content was routed to.
    tile_id: str
    # Stability score assigned to the stored MemoryItem.
    stability_score: float
    # Persistence band: 'stable', 'plastic' or 'escape'.
    band: str
    # Complex-plane coordinate produced by the encoder and used for routing.
    complex_coord: complex
    # Elapsed time of the write call, in seconds.
    write_time: float
@dataclass
class ReadResult:
    """Result of a memory read operation, as returned by ``MandelMem.read``."""
    # Retrieved memory items, as returned by the retriever.
    items: List[MemoryItem]
    # Retriever-reported similarities for ``items``.
    similarities: List[float]
    # Routing trace; empty unless the read was made with ``with_trace=True``.
    trace: List[str]
    # Retriever-reported confidence for this result set.
    confidence: float
    # Elapsed time of the read call, in seconds.
    total_time: float
    # Full explanation payload when the read used ``with_explanation=True``.
    explanation: Optional[Dict[str, Any]] = None
class MandelMem:
    """Main MandelMem system interface.

    Wires the encoder, quadtree, iterative dynamics, retrievers,
    renormalization engine, adaptive thresholds and trainer together and
    exposes a small write/read/policy/statistics API on top of them.
    """

    # Fraction of a tile's capacity at which it counts as "near full". Shared
    # by the renormalization trigger and the capacity statistics so the two
    # always agree.
    _NEAR_CAPACITY_FRACTION = 0.9
    # A tile is also renormalized every this many writes to it.
    _RENORM_WRITE_INTERVAL = 1000
    # Amount added to every tile threshold while safe mode is enabled.
    _SAFE_MODE_TAU_MARGIN = 0.5

    def __init__(self,
                 depth: int = 6,
                 embedding_dim: int = 768,
                 max_slots: int = 32,
                 max_buffer: int = 64,
                 base_tau: float = 2.0,
                 delta: float = 0.3):
        """Initialize MandelMem system.

        Args:
            depth: Quadtree depth
            embedding_dim: Vector embedding dimension
            max_slots: Maximum persistent slots per tile
            max_buffer: Maximum buffer items per tile
            base_tau: Base persistence threshold
            delta: Default plasticity band width; a per-tile ``delta`` set
                through ``set_threshold`` takes precedence in ``write``.
        """
        self.depth = depth
        self.embedding_dim = embedding_dim
        # NOTE(review): recorded for introspection only. They are not
        # forwarded to QuadTree(depth, embedding_dim) below; capacity checks
        # in this class read tile.max_slots / tile.max_buffer instead.
        self.max_slots = max_slots
        self.max_buffer = max_buffer
        self.base_tau = base_tau
        self.delta = delta

        # Initialize core components
        self.encoder = MultiModalEncoder(embedding_dim)
        self.quadtree = QuadTree(depth, embedding_dim)
        self.dynamics = IterativeDynamics(embedding_dim)
        self.retriever = MultiScaleRetriever(self.quadtree, self.encoder)
        self.explainable_retriever = ExplainableRetriever(self.retriever)

        # Renormalization system
        self.summarizer = MemorySummarizer(embedding_dim)
        self.renorm_engine = RenormalizationEngine(self.quadtree, self.summarizer)

        # Adaptive thresholds
        self.adaptive_threshold = AdaptiveThreshold(base_tau)

        # Training system
        self.loss_fn = MandelMemLoss()
        self.trainer = MandelMemTrainer(
            self.quadtree, self.dynamics, self.retriever, self.loss_fn
        )

        # System state
        self.total_writes = 0
        self.total_reads = 0
        # Cumulative elapsed seconds spent in write()/read(); used to report
        # real averages from get_statistics().
        self._total_write_time = 0.0
        self._total_read_time = 0.0
        self.policies: Dict[str, Dict[str, Any]] = {}
        # Snapshot of adaptive_threshold.tile_thresholds taken when safe mode
        # is enabled; None while safe mode is off.
        self._safe_mode_backup: Optional[Dict[str, float]] = None

    def write(self, content: str, meta: Optional[Dict[str, Any]] = None) -> WriteResult:
        """Write content to memory with persistence dynamics.

        Args:
            content: Content to store
            meta: Optional metadata (importance, source, etc.). The mapping is
                copied, so the caller's object is never mutated and stored
                items never share one metadata dict.

        Returns:
            WriteResult with persistence decision and details
        """
        started = time.perf_counter()
        start_time = time.time()  # wall-clock timestamp stored on the item

        # Copy before annotating: otherwise 'timestamp'/'recency_weight' leak
        # into the caller's dict, and items written with a reused dict would
        # all alias (and overwrite) the same metadata.
        if meta is None:
            meta = {'importance': 0.5, 'source': 'user'}
        else:
            meta = dict(meta)
        meta['timestamp'] = start_time
        meta['recency_weight'] = 1.0

        # Encode content
        encoding = self.encoder.encode(content, meta, start_time)

        # Route to leaf tile
        path = self.quadtree.route_to_leaf(encoding.complex_coord)
        leaf_tile = self.quadtree.tiles[path[-1]]

        # Per-tile threshold and band width; set_threshold() stores a tile's
        # delta in tile.policies, so honor it and fall back to the default.
        tile_tau = self.adaptive_threshold.get_threshold(leaf_tile.tile_id)
        tile_delta = leaf_tile.policies.get('delta', self.delta)

        # Perform iterative write
        iter_result = self.dynamics.iterate_write(
            leaf_tile, encoding.vector, encoding.complex_coord,
            meta, tau=tile_tau, delta=tile_delta
        )

        # Create memory item
        memory_item = MemoryItem(
            vector=encoding.vector,
            content=content,
            metadata=meta,
            timestamp=start_time,
            stability_score=1.0 - iter_result.max_potential / 3.0
        )

        # Store based on persistence decision: stable -> slots,
        # plastic -> plastic area, anything else that persists -> buffer.
        if iter_result.persist:
            if iter_result.band == 'stable':
                leaf_tile.add_to_slots(memory_item)
            elif iter_result.band == 'plastic':
                leaf_tile.add_to_plastic(memory_item)
            else:
                leaf_tile.add_to_buffer(memory_item)

        # Update statistics
        self.total_writes += 1
        leaf_tile.stats.write_count += 1

        # Check if renormalization is needed
        if self._should_renormalize(leaf_tile):
            self.renorm_engine.renormalize_tile(leaf_tile)

        write_time = time.perf_counter() - started
        self._total_write_time += write_time
        return WriteResult(
            persisted=iter_result.persist,
            tile_id=leaf_tile.tile_id,
            stability_score=memory_item.stability_score,
            band=iter_result.band,
            complex_coord=encoding.complex_coord,
            write_time=write_time
        )

    def read(self, query: str, k: int = 5, with_trace: bool = False,
             with_explanation: bool = False) -> ReadResult:
        """Read memories matching query.

        Args:
            query: Search query
            k: Number of results to return
            with_trace: Include routing trace
            with_explanation: Include detailed explanation (uses the
                explainable retriever instead of the standard one)

        Returns:
            ReadResult with retrieved memories and metadata
        """
        started = time.perf_counter()

        if with_explanation:
            explanation_data = self.explainable_retriever.retrieve_with_explanation(query, k)
            result = explanation_data['results']
        else:
            result = self.retriever.retrieve(query, k, with_trace)
            explanation_data = None

        # Update statistics; every item returned by this read shares one
        # access timestamp.
        self.total_reads += 1
        now = time.time()
        for item in result.items:
            item.access_count += 1
            item.last_access = now

        total_time = time.perf_counter() - started
        self._total_read_time += total_time
        return ReadResult(
            items=result.items,
            similarities=result.similarities,
            trace=result.trace if with_trace else [],
            confidence=result.confidence,
            total_time=total_time,
            explanation=explanation_data
        )

    def set_policy(self, tile_pattern: str, **kwargs):
        """Set policy for tile pattern.

        The kwargs are recorded under ``tile_pattern`` and merged into the
        ``policies`` dict of every tile currently in ``quadtree.tiles`` whose
        ID matches (see ``_matches_pattern``).

        Args:
            tile_pattern: Tile pattern (e.g., '/finance/**')
            **kwargs: Policy parameters (persist, encrypt, etc.)
        """
        self.policies[tile_pattern] = kwargs

        # Apply to matching tiles
        for tile_id, tile in self.quadtree.tiles.items():
            if self._matches_pattern(tile_id, tile_pattern):
                tile.policies.update(kwargs)

    def set_threshold(self, tile_pattern: str, tau: float, delta: Optional[float] = None):
        """Set persistence threshold for tile pattern.

        Args:
            tile_pattern: Tile pattern
            tau: Persistence threshold, stored in
                ``adaptive_threshold.tile_thresholds`` for each matching tile
            delta: Plasticity band width, stored as ``tile.policies['delta']``
                and used by ``write`` for that tile
        """
        for tile_id, tile in self.quadtree.tiles.items():
            if self._matches_pattern(tile_id, tile_pattern):
                self.adaptive_threshold.tile_thresholds[tile_id] = tau
                if delta is not None:
                    tile.policies['delta'] = delta

    def get_statistics(self) -> Dict[str, Any]:
        """Get comprehensive system statistics."""
        quadtree_stats = self.quadtree.get_statistics()
        compression_stats = self.renorm_engine.get_compression_stats()
        training_stats = self.trainer.get_training_stats()

        # Memory distribution
        memory_dist = self._analyze_memory_distribution()

        # Performance metrics (averages are measured, 0.0 before any call)
        perf_metrics = {
            'total_writes': self.total_writes,
            'total_reads': self.total_reads,
            'avg_write_time': self._get_avg_write_time(),
            'avg_read_time': self._get_avg_read_time()
        }

        return {
            'quadtree': quadtree_stats,
            'compression': compression_stats,
            'training': training_stats,
            'memory_distribution': memory_dist,
            'performance': perf_metrics,
            'policies': len(self.policies)
        }

    def _should_renormalize(self, tile) -> bool:
        """Check if tile needs renormalization.

        True when slots or buffer are at or above ``_NEAR_CAPACITY_FRACTION``
        of the tile's capacity, or when the tile's write count is a multiple
        of ``_RENORM_WRITE_INTERVAL``.
        """
        near = self._NEAR_CAPACITY_FRACTION
        return (len(tile.slots) >= tile.max_slots * near or
                len(tile.buffer) >= tile.max_buffer * near or
                tile.stats.write_count % self._RENORM_WRITE_INTERVAL == 0)

    def _matches_pattern(self, tile_id: str, pattern: str) -> bool:
        """Check if tile ID matches pattern.

        - ``'<prefix>/**'``: tile ID starts with ``<prefix>``.
        - ``'<prefix>/*'``: tile ID starts with ``<prefix>`` and the remainder
          after it contains no ``'/'``.
        - anything else: exact match.
        """
        # NOTE(review): the '/**' check has no separator boundary, so
        # '/fin/**' also matches '/finance'. With '/'-separated IDs, the '/*'
        # remainder of a direct child ('/x' for '/fin/x') contains '/' and
        # would NOT match. Confirm against the QuadTree tile ID format.
        if pattern.endswith('/**'):
            prefix = pattern[:-3]
            return tile_id.startswith(prefix)
        elif pattern.endswith('/*'):
            prefix = pattern[:-2]
            return tile_id.startswith(prefix) and '/' not in tile_id[len(prefix):]
        else:
            return tile_id == pattern

    def _analyze_memory_distribution(self) -> Dict[str, Any]:
        """Analyze memory distribution across leaf tiles.

        ``tiles_at_capacity`` counts leaf tiles whose slot usage is at or
        above ``_NEAR_CAPACITY_FRACTION`` of that tile's own ``max_slots``,
        i.e. the same condition that triggers renormalization on slots.
        """
        leaf_tiles = self.quadtree.get_leaf_tiles()

        slot_counts = [len(tile.slots) for tile in leaf_tiles]
        buffer_counts = [len(tile.buffer) for tile in leaf_tiles]
        plastic_counts = [len(tile.plastic) for tile in leaf_tiles]

        tiles_at_capacity = sum(
            1 for tile, count in zip(leaf_tiles, slot_counts)
            if count >= tile.max_slots * self._NEAR_CAPACITY_FRACTION
        )

        return {
            'total_leaf_tiles': len(leaf_tiles),
            'avg_slots_per_tile': np.mean(slot_counts) if slot_counts else 0,
            'avg_buffer_per_tile': np.mean(buffer_counts) if buffer_counts else 0,
            'avg_plastic_per_tile': np.mean(plastic_counts) if plastic_counts else 0,
            'max_slots_used': max(slot_counts) if slot_counts else 0,
            'tiles_at_capacity': tiles_at_capacity
        }

    def _get_avg_write_time(self) -> float:
        """Return the mean measured write() duration in seconds (0.0 if none)."""
        if not self.total_writes:
            return 0.0
        return self._total_write_time / self.total_writes

    def _get_avg_read_time(self) -> float:
        """Return the mean measured read() duration in seconds (0.0 if none)."""
        if not self.total_reads:
            return 0.0
        return self._total_read_time / self.total_reads

    def export_memory_state(self, filepath: str):
        """Export current memory state to a JSON file.

        The file contains statistics, policies, per-tile thresholds and, for
        every tile that holds at least one item, its attractor, item count,
        bounds and stats. Values json cannot encode are stringified via
        ``default=str``.
        """
        import json

        state = {
            'quadtree_tiles': {},
            'statistics': self.get_statistics(),
            'policies': self.policies,
            'thresholds': self.adaptive_threshold.tile_thresholds
        }

        # Export tile contents (simplified)
        for tile_id, tile in self.quadtree.tiles.items():
            items = tile.get_all_items()
            if not items:
                continue
            state['quadtree_tiles'][tile_id] = {
                'attractor': tile.attractor.tolist(),
                'item_count': len(items),
                'bounds': [str(tile.bounds[0]), str(tile.bounds[1])],
                'stats': tile.stats.__dict__
            }

        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(state, f, indent=2, default=str)

    def safe_mode(self, enabled: bool = True):
        """Enable/disable safe mode (higher thresholds, less persistence).

        Enabling snapshots the current per-tile thresholds and raises every
        tile's threshold by ``_SAFE_MODE_TAU_MARGIN``. Enabling again while
        already enabled is a no-op, so the margin is not stacked. Disabling
        restores the snapshot, which preserves thresholds set via
        ``set_threshold``. If safe mode was never enabled, disabling clears
        all per-tile thresholds (back to base).
        """
        thresholds = self.adaptive_threshold.tile_thresholds
        if enabled:
            if self._safe_mode_backup is not None:
                return  # already in safe mode; don't raise thresholds twice
            self._safe_mode_backup = dict(thresholds)
            for tile_id in self.quadtree.tiles:
                current_tau = self.adaptive_threshold.get_threshold(tile_id)
                thresholds[tile_id] = current_tau + self._SAFE_MODE_TAU_MARGIN
        else:
            backup = self._safe_mode_backup
            thresholds.clear()
            if backup is not None:
                thresholds.update(backup)
                self._safe_mode_backup = None
# Convenience functions for quick usage
def create_mandelmem(config: Optional[Dict[str, Any]] = None) -> MandelMem:
    """Build a ``MandelMem`` from an optional configuration mapping.

    Only the keys ``depth``, ``embedding_dim``, ``max_slots``, ``max_buffer``,
    ``base_tau`` and ``delta`` are consulted. Any other key is ignored, and a
    missing key falls back to the value listed in ``fallbacks`` below.
    """
    fallbacks = {
        'depth': 6,
        'embedding_dim': 768,
        'max_slots': 32,
        'max_buffer': 64,
        'base_tau': 2.0,
        'delta': 0.3,
    }
    source = {} if config is None else config
    chosen = {name: source.get(name, default) for name, default in fallbacks.items()}
    return MandelMem(**chosen)
def quick_demo() -> Dict[str, Any]:
    """Run a short end-to-end demonstration of MandelMem.

    Creates a default-configured system and writes three sample memories with
    different importance/source metadata. It then issues three reads (one with
    an explanation, one with a routing trace, one plain) and finally collects
    the system statistics.

    Returns:
        Dict with the write results, read results, statistics and a
        ``demo_completed`` flag.
    """
    memory = create_mandelmem()

    samples = [
        ("The quarterly revenue increased by 15% due to strong product sales",
         {'importance': 0.9, 'source': 'user'}),
        ("Remember to buy milk on the way home",
         {'importance': 0.3, 'source': 'user'}),
        ("The algorithm uses gradient descent with learning rate 0.001",
         {'importance': 0.7, 'source': 'system'}),
    ]
    write_results = [memory.write(text, meta) for text, meta in samples]

    queries = [
        ("revenue sales", {'with_explanation': True}),
        ("algorithm learning", {'with_trace': True}),
        ("milk shopping", {}),
    ]
    read_results = [memory.read(query, k=3, **options) for query, options in queries]

    return {
        'write_results': write_results,
        'read_results': read_results,
        'statistics': memory.get_statistics(),
        'demo_completed': True,
    }