# Provenance (hosting-page residue): uploaded by Kossisoroyce, "Upload 10 files", commit c05fcc5 (verified).
"""
Core MandelMem system - main interface and orchestration.
Brings together all components into a unified memory architecture.
"""
import torch
import numpy as np
import time
from typing import Dict, Any, List, Optional, Union
from dataclasses import dataclass
from .encoders import MultiModalEncoder
from .quadtree import QuadTree, MemoryItem
from .dynamics import IterativeDynamics, AdaptiveThreshold
from .retrieval import MultiScaleRetriever, ExplainableRetriever
from .renormalization import RenormalizationEngine, MemorySummarizer
from .training import MandelMemTrainer, MandelMemLoss
@dataclass
class WriteResult:
    """Outcome of a single ``MandelMem.write`` call."""

    # True when the write dynamics decided the item should persist.
    persisted: bool
    # Identifier of the leaf tile the content was routed to.
    tile_id: str
    # Stability score attached to the memory item built for this write.
    stability_score: float
    # Persistence band reported by the dynamics.
    band: str  # 'stable', 'plastic', 'escape'
    # Complex-plane coordinate produced by the encoder for the content.
    complex_coord: complex
    # Elapsed time of the write call, in seconds.
    write_time: float
@dataclass
class ReadResult:
    """Outcome of a single ``MandelMem.read`` call."""

    # Retrieved memory items.
    items: List[MemoryItem]
    # Similarity values reported by the retriever alongside the items.
    similarities: List[float]
    # Routing trace; left empty when the caller did not ask for one.
    trace: List[str]
    # Confidence reported by the retriever for this result set.
    confidence: float
    # Elapsed time of the read call, in seconds.
    total_time: float
    # Explanation payload; only populated for explained reads.
    explanation: Optional[Dict[str, Any]] = None
class MandelMem:
    """Main MandelMem system interface.

    Wires together the encoder, quadtree, write dynamics, retrievers,
    renormalization engine, adaptive thresholds and trainer, and exposes
    ``write`` / ``read`` plus policy, threshold and statistics helpers.
    """

    # Fill ratio at which a tile counts as "nearly full". Used both to
    # trigger renormalization and to report ``tiles_at_capacity``.
    _NEAR_CAPACITY_RATIO = 0.9
    # Renormalize a tile every this many writes regardless of fill level.
    _RENORM_WRITE_INTERVAL = 1000
    # Amount added to every tile threshold when safe mode is switched on.
    _SAFE_MODE_TAU_BOOST = 0.5

    def __init__(self,
                 depth: int = 6,
                 embedding_dim: int = 768,
                 max_slots: int = 32,
                 max_buffer: int = 64,
                 base_tau: float = 2.0,
                 delta: float = 0.3):
        """Initialize MandelMem system.

        Args:
            depth: Quadtree depth
            embedding_dim: Vector embedding dimension
            max_slots: Maximum persistent slots per tile
            max_buffer: Maximum buffer items per tile
            base_tau: Base persistence threshold
            delta: Plasticity band width
        """
        self.depth = depth
        self.embedding_dim = embedding_dim
        # Previously accepted and silently dropped; kept on the instance so
        # the configured values stay inspectable.
        # NOTE(review): these are not forwarded to QuadTree here - confirm
        # whether tiles obtain their capacities some other way.
        self.max_slots = max_slots
        self.max_buffer = max_buffer
        self.base_tau = base_tau
        self.delta = delta

        # Initialize core components
        self.encoder = MultiModalEncoder(embedding_dim)
        self.quadtree = QuadTree(depth, embedding_dim)
        self.dynamics = IterativeDynamics(embedding_dim)
        self.retriever = MultiScaleRetriever(self.quadtree, self.encoder)
        self.explainable_retriever = ExplainableRetriever(self.retriever)

        # Renormalization system
        self.summarizer = MemorySummarizer(embedding_dim)
        self.renorm_engine = RenormalizationEngine(self.quadtree, self.summarizer)

        # Adaptive thresholds
        self.adaptive_threshold = AdaptiveThreshold(base_tau)

        # Training system
        self.loss_fn = MandelMemLoss()
        self.trainer = MandelMemTrainer(
            self.quadtree, self.dynamics, self.retriever, self.loss_fn
        )

        # System state
        self.total_writes = 0
        self.total_reads = 0
        # Accumulated durations backing the reported average timings.
        self._total_write_time = 0.0
        self._total_read_time = 0.0
        # Guards against stacking the safe-mode boost on repeated enables.
        self._safe_mode_enabled = False
        self.policies: Dict[str, Dict[str, Any]] = {}

    def write(self, content: str, meta: Optional[Dict[str, Any]] = None) -> WriteResult:
        """Write content to memory with persistence dynamics.

        Args:
            content: Content to store
            meta: Optional metadata (importance, source, etc.). The mapping
                is copied; the caller's object is never modified.

        Returns:
            WriteResult with persistence decision and details
        """
        start_time = time.time()           # wall-clock stamp stored with the item
        timer_start = time.perf_counter()  # monotonic clock for the duration

        if meta is None:
            meta = {'importance': 0.5, 'source': 'user'}
        else:
            # Work on a copy so the bookkeeping keys added below do not leak
            # into a dict the caller may reuse for other writes.
            meta = dict(meta)

        # Add timestamp and recency
        meta['timestamp'] = start_time
        meta['recency_weight'] = 1.0

        # Encode content
        encoding = self.encoder.encode(content, meta, start_time)

        # Route to leaf tile (last entry of the routing path)
        path = self.quadtree.route_to_leaf(encoding.complex_coord)
        leaf_tile = self.quadtree.tiles[path[-1]]

        # Per-tile threshold, and per-tile band width when ``set_threshold``
        # stored one in the tile policies; otherwise the system-wide delta.
        tile_tau = self.adaptive_threshold.get_threshold(leaf_tile.tile_id)
        tile_delta = leaf_tile.policies.get('delta', self.delta)

        # Perform iterative write
        iter_result = self.dynamics.iterate_write(
            leaf_tile, encoding.vector, encoding.complex_coord,
            meta, tau=tile_tau, delta=tile_delta
        )

        # Create memory item
        memory_item = MemoryItem(
            vector=encoding.vector,
            content=content,
            metadata=meta,
            timestamp=start_time,
            stability_score=1.0 - iter_result.max_potential / 3.0
        )

        # Store based on persistence decision. A persisted item whose band is
        # neither 'stable' nor 'plastic' is not stored anywhere.
        if iter_result.persist:
            if iter_result.band == 'stable':
                leaf_tile.add_to_slots(memory_item)
            elif iter_result.band == 'plastic':
                leaf_tile.add_to_plastic(memory_item)
        else:
            leaf_tile.add_to_buffer(memory_item)

        # Update statistics
        self.total_writes += 1
        leaf_tile.stats.write_count += 1

        # Check if renormalization is needed
        if self._should_renormalize(leaf_tile):
            self.renorm_engine.renormalize_tile(leaf_tile)

        write_time = time.perf_counter() - timer_start
        self._total_write_time += write_time

        return WriteResult(
            persisted=iter_result.persist,
            tile_id=leaf_tile.tile_id,
            stability_score=memory_item.stability_score,
            band=iter_result.band,
            complex_coord=encoding.complex_coord,
            write_time=write_time
        )

    def read(self, query: str, k: int = 5, with_trace: bool = False,
             with_explanation: bool = False) -> ReadResult:
        """Read memories matching query.

        Args:
            query: Search query
            k: Number of results to return
            with_trace: Include routing trace
            with_explanation: Include detailed explanation

        Returns:
            ReadResult with retrieved memories and metadata
        """
        timer_start = time.perf_counter()

        if with_explanation:
            # Explainable retriever returns a dict; the retrieval result
            # itself lives under its 'results' key.
            explanation_data = self.explainable_retriever.retrieve_with_explanation(query, k)
            result = explanation_data['results']
        else:
            # Use standard retriever
            result = self.retriever.retrieve(query, k, with_trace)
            explanation_data = None

        # Update statistics
        self.total_reads += 1
        accessed_at = time.time()
        for item in result.items:
            item.access_count += 1
            item.last_access = accessed_at

        total_time = time.perf_counter() - timer_start
        self._total_read_time += total_time

        return ReadResult(
            items=result.items,
            similarities=result.similarities,
            trace=result.trace if with_trace else [],
            confidence=result.confidence,
            total_time=total_time,
            explanation=explanation_data
        )

    def set_policy(self, tile_pattern: str, **kwargs):
        """Set policy for tile pattern.

        The policy is recorded under ``tile_pattern`` and merged into the
        policies of every tile that currently matches the pattern.

        Args:
            tile_pattern: Tile pattern (e.g., '/finance/**')
            **kwargs: Policy parameters (persist, encrypt, etc.)
        """
        self.policies[tile_pattern] = kwargs

        # Apply to matching tiles
        for tile_id, tile in self.quadtree.tiles.items():
            if self._matches_pattern(tile_id, tile_pattern):
                tile.policies.update(kwargs)

    def set_threshold(self, tile_pattern: str, tau: float, delta: Optional[float] = None):
        """Set persistence threshold for tile pattern.

        Args:
            tile_pattern: Tile pattern
            tau: Persistence threshold
            delta: Plasticity band width; when given it is stored in each
                matching tile's policies under 'delta' and used by ``write``
                for that tile.
        """
        for tile_id, tile in self.quadtree.tiles.items():
            if not self._matches_pattern(tile_id, tile_pattern):
                continue
            self.adaptive_threshold.tile_thresholds[tile_id] = tau
            if delta is not None:
                # Store delta in tile policies
                tile.policies['delta'] = delta

    def get_statistics(self) -> Dict[str, Any]:
        """Get comprehensive system statistics.

        Returns:
            Dict with 'quadtree', 'compression', 'training',
            'memory_distribution', 'performance' and 'policies' (the number
            of registered policy patterns).
        """
        quadtree_stats = self.quadtree.get_statistics()
        compression_stats = self.renorm_engine.get_compression_stats()
        training_stats = self.trainer.get_training_stats()

        # Memory distribution
        memory_dist = self._analyze_memory_distribution()

        # Performance metrics
        perf_metrics = {
            'total_writes': self.total_writes,
            'total_reads': self.total_reads,
            'avg_write_time': self._get_avg_write_time(),
            'avg_read_time': self._get_avg_read_time()
        }

        return {
            'quadtree': quadtree_stats,
            'compression': compression_stats,
            'training': training_stats,
            'memory_distribution': memory_dist,
            'performance': perf_metrics,
            'policies': len(self.policies)
        }

    def _should_renormalize(self, tile) -> bool:
        """Check if tile needs renormalization.

        True when slots or buffer reach the near-capacity ratio of the tile's
        own limits, or when the tile's write count is a multiple of the
        periodic interval.
        """
        # Simple heuristic - can be made more sophisticated
        ratio = self._NEAR_CAPACITY_RATIO
        return (len(tile.slots) >= tile.max_slots * ratio or
                len(tile.buffer) >= tile.max_buffer * ratio or
                tile.stats.write_count % self._RENORM_WRITE_INTERVAL == 0)

    def _matches_pattern(self, tile_id: str, pattern: str) -> bool:
        """Check if tile ID matches pattern.

        Supported forms:
            'base/**' - any tile id that starts with ``base``
            'base/*'  - tile ids that start with ``base`` and have at most
                        one further path segment (``base/child`` matches,
                        ``base/child/leaf`` does not)
            otherwise - exact equality
        """
        if pattern.endswith('/**'):
            return tile_id.startswith(pattern[:-3])

        if pattern.endswith('/*'):
            prefix = pattern[:-2]
            if not tile_id.startswith(prefix):
                return False
            remainder = tile_id[len(prefix):]
            # Drop the separator that introduces the child segment. Without
            # this, 'base/child' left '/child' as the remainder and could
            # never match, so '/*' never matched any direct child.
            if remainder.startswith('/'):
                remainder = remainder[1:]
            return '/' not in remainder

        return tile_id == pattern

    def _analyze_memory_distribution(self) -> Dict[str, Any]:
        """Analyze memory distribution across tiles.

        Returns:
            Dict with the leaf tile count, the mean number of slot, buffer
            and plastic items per leaf tile, the largest slot count, and the
            number of tiles whose slots reached the near-capacity ratio of
            their own ``max_slots``.
        """
        leaf_tiles = self.quadtree.get_leaf_tiles()

        slot_counts = [len(tile.slots) for tile in leaf_tiles]
        buffer_counts = [len(tile.buffer) for tile in leaf_tiles]
        plastic_counts = [len(tile.plastic) for tile in leaf_tiles]

        # Measured against each tile's own capacity rather than a fixed
        # count, so the figure stays meaningful for any max_slots.
        near_full = sum(
            1 for tile in leaf_tiles
            if len(tile.slots) >= tile.max_slots * self._NEAR_CAPACITY_RATIO
        )

        return {
            'total_leaf_tiles': len(leaf_tiles),
            'avg_slots_per_tile': np.mean(slot_counts) if slot_counts else 0,
            'avg_buffer_per_tile': np.mean(buffer_counts) if buffer_counts else 0,
            'avg_plastic_per_tile': np.mean(plastic_counts) if plastic_counts else 0,
            'max_slots_used': max(slot_counts) if slot_counts else 0,
            'tiles_at_capacity': near_full
        }

    def _get_avg_write_time(self) -> float:
        """Get the measured mean duration of ``write`` calls in seconds (0.0 before any write)."""
        if not self.total_writes:
            return 0.0
        return self._total_write_time / self.total_writes

    def _get_avg_read_time(self) -> float:
        """Get the measured mean duration of ``read`` calls in seconds (0.0 before any read)."""
        if not self.total_reads:
            return 0.0
        return self._total_read_time / self.total_reads

    def export_memory_state(self, filepath: str):
        """Export current memory state to file.

        Writes a JSON document with per-tile summaries (only tiles that hold
        items), system statistics, policies and thresholds. Values that JSON
        cannot encode natively are written via ``str``.

        Args:
            filepath: Destination path; an existing file is overwritten.
        """
        import json

        state = {
            'quadtree_tiles': {},
            'statistics': self.get_statistics(),
            'policies': self.policies,
            'thresholds': self.adaptive_threshold.tile_thresholds
        }

        # Export tile contents (simplified)
        for tile_id, tile in self.quadtree.tiles.items():
            items = tile.get_all_items()
            if not items:
                continue
            state['quadtree_tiles'][tile_id] = {
                'attractor': tile.attractor.tolist(),
                'item_count': len(items),
                'bounds': [str(tile.bounds[0]), str(tile.bounds[1])],
                'stats': tile.stats.__dict__
            }

        # Explicit encoding keeps the output independent of the platform's
        # default locale.
        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(state, f, indent=2, default=str)

    def safe_mode(self, enabled: bool = True):
        """Enable/disable safe mode (higher thresholds, less persistence).

        Enabling raises the threshold of every current tile by a fixed boost.
        Enabling again while already enabled is a no-op, so the boost is not
        applied twice. Disabling clears all per-tile thresholds.

        Args:
            enabled: True to raise thresholds, False to reset them.
        """
        if not enabled:
            # Reset to base thresholds
            self.adaptive_threshold.tile_thresholds.clear()
            self._safe_mode_enabled = False
            return

        if self._safe_mode_enabled:
            return

        # Raise all thresholds
        for tile_id in self.quadtree.tiles:
            current_tau = self.adaptive_threshold.get_threshold(tile_id)
            self.adaptive_threshold.tile_thresholds[tile_id] = (
                current_tau + self._SAFE_MODE_TAU_BOOST
            )
        self._safe_mode_enabled = True
# Convenience functions for quick usage
def create_mandelmem(config: Optional[Dict[str, Any]] = None) -> MandelMem:
    """Build a MandelMem instance from an optional configuration mapping.

    Args:
        config: Optional settings; recognised keys are 'depth',
            'embedding_dim', 'max_slots', 'max_buffer', 'base_tau' and
            'delta'. Missing keys fall back to the defaults listed below.

    Returns:
        A MandelMem constructed with the resolved settings.
    """
    settings = {} if config is None else config

    # (constructor keyword, fallback used when the key is absent)
    fallbacks = (
        ('depth', 6),
        ('embedding_dim', 768),
        ('max_slots', 32),
        ('max_buffer', 64),
        ('base_tau', 2.0),
        ('delta', 0.3),
    )
    resolved = {name: settings.get(name, fallback) for name, fallback in fallbacks}
    return MandelMem(**resolved)
def quick_demo() -> Dict[str, Any]:
    """Run a small end-to-end demo: three writes, three reads, then statistics.

    Returns:
        Dict with 'write_results', 'read_results', 'statistics' and a
        'demo_completed' flag set to True.
    """
    memory = create_mandelmem()

    # (content, metadata) pairs, written in this order.
    samples = [
        ("The quarterly revenue increased by 15% due to strong product sales",
         {'importance': 0.9, 'source': 'user'}),
        ("Remember to buy milk on the way home",
         {'importance': 0.3, 'source': 'user'}),
        ("The algorithm uses gradient descent with learning rate 0.001",
         {'importance': 0.7, 'source': 'system'}),
    ]
    write_results = [memory.write(text, meta) for text, meta in samples]

    # (query, extra read options) pairs, read in this order with k=3.
    lookups = [
        ("revenue sales", {'with_explanation': True}),
        ("algorithm learning", {'with_trace': True}),
        ("milk shopping", {}),
    ]
    read_results = [memory.read(query, k=3, **options) for query, options in lookups]

    # Statistics are gathered after all writes and reads have run.
    stats = memory.get_statistics()

    return {
        'write_results': write_results,
        'read_results': read_results,
        'statistics': stats,
        'demo_completed': True
    }