Spaces:
Sleeping
Sleeping
| """ | |
| Renormalization and compression system for MandelMem. | |
| Handles periodic compression, merging, and summarization of memories. | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import numpy as np | |
| from typing import List, Dict, Any, Tuple, Optional | |
| from dataclasses import dataclass | |
| import time | |
| from collections import defaultdict | |
| from .quadtree import QuadTree, Tile, MemoryItem | |
@dataclass
class CompressionResult:
    """Result of one tile renormalization pass.

    Built by ``RenormalizationEngine.renormalize_tile`` with keyword
    arguments, which requires the ``@dataclass``-generated ``__init__``.
    """

    # Slot items absorbed into a near-duplicate (and removed from the slots).
    items_merged: int
    # Slot items replaced by sketch items.
    items_summarized: int
    # Buffer items moved into the slots.
    items_promoted: int
    # Buffer / plastic items discarded as unstable.
    items_pruned: int
    # Item count after renormalization divided by item count before it
    # (1.0 for a tile that started empty).
    compression_ratio: float
    # Diversity-based quality score in [0, 1] (1.0 when not measured).
    quality_preserved: float
class MemorySummarizer(nn.Module):
    """Encoder/decoder pair that compresses memory vectors into compact sketches.

    The encoder maps ``embedding_dim`` -> ``embedding_dim // 2`` -> ``sketch_dim``
    and the decoder mirrors it back to ``embedding_dim``. Both stacks end in
    ``Tanh``, so sketches and reconstructions lie in [-1, 1].
    """

    def __init__(self, embedding_dim: int = 768, sketch_dim: int = 256):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.sketch_dim = sketch_dim
        hidden_dim = embedding_dim // 2

        # Encoder for creating sketches (attribute name kept for state_dict compatibility)
        self.sketch_encoder = nn.Sequential(
            nn.Linear(embedding_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, sketch_dim),
            nn.Tanh(),
        )
        # Decoder for reconstructing from sketches
        self.sketch_decoder = nn.Sequential(
            nn.Linear(sketch_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, embedding_dim),
            nn.Tanh(),
        )

    def create_sketch(self, vectors: torch.Tensor) -> torch.Tensor:
        """Mean-pool ``vectors`` and encode the pooled vector into one sketch.

        Args:
            vectors: One vector of shape ``(embedding_dim,)`` or a stack of
                shape ``(n, embedding_dim)``.

        Returns:
            A sketch of shape ``(sketch_dim,)``.

        Raises:
            ValueError: If ``vectors`` holds no rows (the mean would be NaN).
        """
        if vectors.dim() == 1:
            vectors = vectors.unsqueeze(0)
        if vectors.shape[0] == 0:
            raise ValueError("create_sketch() needs at least one vector")
        pooled = torch.mean(vectors, dim=0, keepdim=True)
        return self.sketch_encoder(pooled).squeeze(0)

    def reconstruct_from_sketch(self, sketch: torch.Tensor) -> torch.Tensor:
        """Decode ``sketch`` back to an ``embedding_dim`` vector.

        A 1-D sketch of shape ``(sketch_dim,)`` yields a vector of shape
        ``(embedding_dim,)``.
        """
        if sketch.dim() == 1:
            sketch = sketch.unsqueeze(0)
        return self.sketch_decoder(sketch).squeeze(0)

    def compute_reconstruction_loss(self, original: torch.Tensor,
                                    reconstructed: torch.Tensor) -> Tuple[float, float]:
        """Score how well ``reconstructed`` matches ``original``.

        Returns:
            ``(mse, cosine)``: the mean squared error, and the cosine
            similarity of the two tensors after each is flattened to a single
            vector. For 1-D inputs this is the ordinary vector cosine.
        """
        mse = torch.nn.functional.mse_loss(original, reconstructed).item()
        # Flattening lets batched / multi-dim inputs yield one scalar instead
        # of failing in .item() on a multi-element result.
        cosine = torch.nn.functional.cosine_similarity(
            original.reshape(1, -1), reconstructed.reshape(1, -1)
        ).item()
        return mse, cosine
class RenormalizationEngine:
    """Compress and reorganise the memories held in quadtree tiles.

    A renormalization pass on a tile merges near-duplicate slot items, folds
    the long tail of rarely used slot items into summary "sketch" items,
    promotes repeatedly accessed buffer items into free slots, prunes unstable
    buffer / plastic items, and finally re-centres the tile attractor.
    """

    def __init__(self, quadtree: QuadTree, summarizer: MemorySummarizer,
                 merge_threshold: float = 0.9, sketch_threshold: int = 100):
        """Create an engine.

        Args:
            quadtree: Tree whose ``tiles`` mapping is scanned by
                ``schedule_renormalization``.
            summarizer: Encoder/decoder used to build sketch items.
            merge_threshold: Cosine similarity above which two slot items are
                merged.
            sketch_threshold: Slot count at or above which long-tail items are
                sketched; the top ``sketch_threshold // 2`` items are kept
                verbatim.
        """
        self.quadtree = quadtree
        self.summarizer = summarizer
        self.merge_threshold = merge_threshold
        self.sketch_threshold = sketch_threshold
        # Renormalization records per tile_id (appended by _record_renormalization).
        self.renorm_history: Dict[str, List[Dict[str, Any]]] = defaultdict(list)

    def renormalize_tile(self, tile: Tile, preserve_quality: bool = True) -> CompressionResult:
        """Run every renormalization step on ``tile``, modifying it in place.

        Args:
            tile: Tile whose slots, buffer, plastic band and attractor are
                rewritten.
            preserve_quality: When True, compute a diversity-based quality
                score for the result; otherwise report 1.0.

        Returns:
            A CompressionResult summarising what each step did.
        """
        start = time.perf_counter()  # monotonic clock for the duration only
        initial_count = len(tile.get_all_items())

        merged_count = self._merge_duplicates(tile)        # 1. near-duplicate slots
        summarized_count = self._create_sketches(tile)     # 2. long tail -> sketches
        promoted_count = self._promote_stable_items(tile)  # 3. buffer -> slots
        pruned_count = self._prune_unstable_items(tile)    # 4. drop unstable entries
        self._update_attractor_post_renorm(tile)           # 5. re-centre attractor

        final_count = len(tile.get_all_items())
        compression_ratio = final_count / initial_count if initial_count > 0 else 1.0
        quality_preserved = (self._measure_quality_preservation(tile)
                             if preserve_quality else 1.0)

        result = CompressionResult(
            items_merged=merged_count,
            items_summarized=summarized_count,
            items_promoted=promoted_count,
            items_pruned=pruned_count,
            compression_ratio=compression_ratio,
            quality_preserved=quality_preserved,
        )
        self._record_renormalization(tile, result, time.perf_counter() - start)
        return result

    def _merge_duplicates(self, tile: Tile) -> int:
        """Merge near-duplicate items in ``tile.slots``.

        Each surviving item absorbs every later item whose cosine similarity
        to it exceeds ``merge_threshold``. Comparisons and merges use the
        running merged item, so a chain of duplicates collapses into a single
        item that carries all of their contributions.

        Returns:
            Number of items absorbed (and therefore removed from the slots).
        """
        slots = list(tile.slots)
        if len(slots) < 2:
            return 0

        absorbed = set()
        for i, survivor in enumerate(slots):
            if i in absorbed:
                continue
            for j in range(i + 1, len(slots)):
                if j in absorbed:
                    continue
                candidate = slots[j]
                similarity = torch.cosine_similarity(
                    survivor.vector.unsqueeze(0),
                    candidate.vector.unsqueeze(0)
                ).item()
                if similarity > self.merge_threshold:
                    # Fold into the running result so earlier merges made in
                    # this pass are not discarded.
                    survivor = self._merge_items(survivor, candidate)
                    absorbed.add(j)
            slots[i] = survivor

        tile.slots = [item for k, item in enumerate(slots) if k not in absorbed]
        return len(absorbed)

    def _merge_items(self, item1: MemoryItem, item2: MemoryItem) -> MemoryItem:
        """Combine two similar items into a new MemoryItem.

        - vector: stability-weighted average (plain average when the summed
          stability is not positive);
        - content / base metadata: from the item with the higher
          ``metadata['importance']`` (default 0.5, ties favour ``item1``);
        - metadata gains ``merged_from`` (first 50 chars of each content) and
          ``merge_timestamp``;
        - timestamp and stability_score: the max of the two; access_count: summed.
        """
        w1 = item1.stability_score
        w2 = item2.stability_score
        total_weight = w1 + w2
        if total_weight > 0:
            merged_vector = (w1 * item1.vector + w2 * item2.vector) / total_weight
        else:
            merged_vector = (item1.vector + item2.vector) / 2

        keep_first = (item1.metadata.get('importance', 0.5)
                      >= item2.metadata.get('importance', 0.5))
        primary = item1 if keep_first else item2
        primary_meta = primary.metadata.copy()
        primary_meta['merged_from'] = [item1.content[:50], item2.content[:50]]
        primary_meta['merge_timestamp'] = time.time()

        return MemoryItem(
            vector=merged_vector,
            content=primary.content,
            metadata=primary_meta,
            timestamp=max(item1.timestamp, item2.timestamp),
            stability_score=max(item1.stability_score, item2.stability_score),
            access_count=item1.access_count + item2.access_count
        )

    def _create_sketches(self, tile: Tile) -> int:
        """Fold the long tail of ``tile.slots`` into sketch items.

        Does nothing while the tile holds fewer than ``sketch_threshold``
        slots. Otherwise the ``sketch_threshold // 2`` items with the highest
        ``(access_count, stability_score)`` are kept verbatim; the rest are
        grouped by similarity, each group of two or more items is replaced by
        one sketch item, and single-item groups are kept unchanged.

        Returns:
            Number of original items replaced by sketches.
        """
        if len(tile.slots) < self.sketch_threshold:
            return 0

        ranked = sorted(tile.slots,
                        key=lambda x: (x.access_count, x.stability_score),
                        reverse=True)
        keep_count = self.sketch_threshold // 2
        new_slots = ranked[:keep_count]
        tail = ranked[keep_count:]
        if not tail:
            return 0

        sketched_count = 0
        for group in self._group_for_sketching(tail):
            if len(group) < 2:
                # Nothing to summarize; keep the lone item rather than drop it.
                new_slots.extend(group)
                continue
            new_slots.append(self._build_sketch_item(group))
            sketched_count += len(group)

        tile.slots = new_slots
        return sketched_count

    def _build_sketch_item(self, group: List[MemoryItem]) -> MemoryItem:
        """Summarize ``group`` (two or more items) into a single sketch MemoryItem."""
        # Sketches are stored data, not training inputs: avoid keeping the
        # summarizer's autograd graph alive inside stored memories.
        with torch.no_grad():
            vectors = torch.stack([item.vector for item in group])
            sketch = self.summarizer.create_sketch(vectors)
            # Decode back to embedding_dim so the stored vector can be compared
            # with the other items (merging, attractor update, quality score).
            sketch_vector = self.summarizer.reconstruct_from_sketch(sketch)

        sketch_content = (
            f"[SKETCH of {len(group)} items: "
            + ", ".join(item.content[:20] for item in group[:3])
            + ("..." if len(group) > 3 else "")
            + "]"
        )
        sketch_meta = {
            'type': 'sketch',
            'original_count': len(group),
            'sketch_timestamp': time.time(),
            'original_items': [item.content[:50] for item in group]
        }
        return MemoryItem(
            vector=sketch_vector,
            content=sketch_content,
            metadata=sketch_meta,
            timestamp=max(item.timestamp for item in group),
            stability_score=float(np.mean([item.stability_score for item in group]))
        )

    def _group_for_sketching(self, items: List[MemoryItem],
                             similarity_threshold: float = 0.7) -> List[List[MemoryItem]]:
        """Greedily cluster ``items`` by cosine similarity to a group centroid.

        Each group is seeded with the first remaining item; every later
        remaining item whose cosine similarity to the group's current centroid
        exceeds ``similarity_threshold`` joins it, and the centroid is updated
        after each addition. Input order is preserved within each group.
        """
        groups: List[List[MemoryItem]] = []
        remaining = list(items)
        while remaining:
            seed = remaining.pop(0)
            group = [seed]
            vector_sum = seed.vector  # running sum; centroid = sum / size
            i = 0
            while i < len(remaining):
                candidate = remaining[i]
                centroid = vector_sum / len(group)
                similarity = torch.cosine_similarity(
                    candidate.vector.unsqueeze(0),
                    centroid.unsqueeze(0)
                ).item()
                if similarity > similarity_threshold:
                    group.append(remaining.pop(i))
                    vector_sum = vector_sum + candidate.vector
                else:
                    i += 1
            groups.append(group)
        return groups

    def _promote_stable_items(self, tile: Tile) -> int:
        """Move stable buffer items into ``tile.slots`` while slots have room.

        A buffer item qualifies when ``access_count >= 3`` and
        ``stability_score > 0.6``. Capacity (``len(tile.slots) <
        tile.max_slots``) is re-checked before each promotion so the slots
        never exceed ``max_slots``. Unpromoted items keep their buffer order.

        Returns:
            Number of items promoted.
        """
        promoted_count = 0
        remaining_buffer = []
        for item in tile.buffer:
            if (len(tile.slots) < tile.max_slots
                    and item.access_count >= 3
                    and item.stability_score > 0.6):
                tile.slots.append(item)
                promoted_count += 1
            else:
                remaining_buffer.append(item)
        # Rebuilding the list (instead of buffer.remove) avoids O(n^2) scans
        # and equality comparisons between items, which may compare tensors
        # if MemoryItem defines field-wise __eq__ — TODO confirm.
        tile.buffer = remaining_buffer
        return promoted_count

    def _prune_unstable_items(self, tile: Tile) -> int:
        """Drop unstable items from the buffer and the plastic band.

        Buffer items are pruned when ``stability_score < 0.2``, they are older
        than ``tile.half_life`` seconds, and ``access_count < 2``. Plastic
        items are kept only if ``stability_score > 0.3`` or
        ``access_count >= 2`` (no age condition applies).

        Returns:
            Total number of items pruned.
        """
        now = time.time()

        def buffer_item_is_unstable(item: MemoryItem) -> bool:
            return (item.stability_score < 0.2
                    and now - item.timestamp > tile.half_life
                    and item.access_count < 2)

        kept_buffer = [item for item in tile.buffer if not buffer_item_is_unstable(item)]
        kept_plastic = [item for item in tile.plastic
                        if item.stability_score > 0.3 or item.access_count >= 2]

        pruned_count = ((len(tile.buffer) - len(kept_buffer))
                        + (len(tile.plastic) - len(kept_plastic)))
        tile.buffer = kept_buffer
        tile.plastic = kept_plastic
        return pruned_count

    def _update_attractor_post_renorm(self, tile: Tile):
        """Re-centre ``tile.attractor`` on the tile's remaining items.

        The new attractor is the average of item vectors weighted by
        ``stability_score * (1 + log(1 + access_count))``. It is left
        unchanged when the tile is empty or the total weight is not positive.
        """
        all_items = tile.get_all_items()
        if not all_items:
            return

        total_weight = 0.0
        weighted_sum = torch.zeros_like(tile.attractor)
        for item in all_items:
            # Stability dominates; access count contributes logarithmically.
            weight = float(item.stability_score) * (1.0 + float(np.log1p(item.access_count)))
            weighted_sum += weight * item.vector
            total_weight += weight

        if total_weight > 0:
            tile.attractor = weighted_sum / total_weight

    def _measure_quality_preservation(self, tile: Tile) -> float:
        """Return a diversity score in [0, 1] for the tile's remaining items.

        The score is ``1 - mean pairwise cosine similarity``, clipped to
        [0, 1]; an empty tile scores 1.0 and a single-item tile 0.5. This is
        a simplified proxy: it does not compare retrieval before vs. after.
        """
        all_items = tile.get_all_items()
        if not all_items:
            return 1.0
        n = len(all_items)
        if n < 2:
            return 0.5

        with torch.no_grad():
            vectors = torch.stack([item.vector for item in all_items]).reshape(n, -1).double()
            unit = torch.nn.functional.normalize(vectors, dim=1, eps=1e-8)
            # Sum of cosines over ordered pairs i != j equals
            # ||sum_i u_i||^2 - sum_i ||u_i||^2; dividing by n(n-1) gives the
            # mean over unordered pairs, in O(n*d) instead of O(n^2) calls.
            pair_total = unit.sum(dim=0).pow(2).sum() - unit.pow(2).sum()
            avg_similarity = (pair_total / (n * (n - 1))).item()

        diversity_score = 1.0 - avg_similarity
        return max(0.0, min(1.0, diversity_score))

    def _record_renormalization(self, tile: Tile, result: CompressionResult, duration: float):
        """Append a history record for ``tile`` and stamp ``tile.stats.last_renorm``."""
        stamp = time.time()
        self.renorm_history[tile.tile_id].append({
            'timestamp': stamp,
            'duration': duration,
            'result': result,
            'tile_stats': {
                'total_items': len(tile.get_all_items()),
                'slots': len(tile.slots),
                'buffer': len(tile.buffer),
                'plastic': len(tile.plastic)
            }
        })
        tile.stats.last_renorm = stamp

    def schedule_renormalization(self, cadence_hours: int = 24,
                                 write_threshold: int = 10000) -> List[str]:
        """Return the ids of tiles that are due for renormalization.

        A tile is due if any of these hold:
        - more than ``cadence_hours`` have passed since ``tile.stats.last_renorm``;
        - ``tile.stats.write_count`` exceeds ``write_threshold``;
        - its slots or buffer are at least 90% of their maximum size.

        NOTE(review): write_count is not reset in this module; if it is a
        cumulative counter, a tile past write_threshold is returned on every
        call — confirm against the QuadTree/Tile implementation.
        """
        now = time.time()
        cadence_seconds = cadence_hours * 3600
        due: List[str] = []
        for tile_id, tile in self.quadtree.tiles.items():
            stale = now - tile.stats.last_renorm > cadence_seconds
            write_heavy = tile.stats.write_count > write_threshold
            nearly_full = (len(tile.slots) >= tile.max_slots * 0.9
                           or len(tile.buffer) >= tile.max_buffer * 0.9)
            if stale or write_heavy or nearly_full:
                due.append(tile_id)
        return due

    def get_compression_stats(self) -> Dict[str, Any]:
        """Aggregate statistics over every recorded renormalization."""
        records = [record
                   for history in self.renorm_history.values()
                   for record in history]
        if not records:
            return {'message': 'No renormalizations performed yet'}

        results = [record['result'] for record in records]
        return {
            'total_renormalizations': len(records),
            'total_items_merged': sum(r.items_merged for r in results),
            'total_items_sketched': sum(r.items_summarized for r in results),
            'average_compression_ratio': np.mean([r.compression_ratio for r in results]),
            'average_quality_preserved': np.mean([r.quality_preserved for r in results]),
            'tiles_with_history': len(self.renorm_history)
        }