Spaces:
Sleeping
Sleeping
| """ | |
| Quadtree implementation for fractal memory indexing. | |
| Organizes memory in hierarchical tiles with attractors and slot banks. | |
| """ | |
| import torch | |
| import numpy as np | |
| import hnswlib | |
| from typing import List, Dict, Any, Optional, Tuple, Set | |
| from dataclasses import dataclass, field | |
| from collections import defaultdict | |
| import time | |
@dataclass(eq=False)
class MemoryItem:
    """Individual memory item stored in tiles.

    Declared as a dataclass so it can be built by keyword, e.g.
    ``MemoryItem(vector=v, content="...", metadata={}, timestamp=t)``.

    ``eq=False`` keeps identity-based equality (and hashing).  A generated
    field-wise ``__eq__`` would compare the ``vector`` tensors with ``==``,
    producing a tensor whose truth value is ambiguous, so list operations
    such as ``item in pool`` or ``pool.remove(item)`` would raise for
    multi-element vectors.
    """

    vector: torch.Tensor          # embedding indexed in HNSW and blended into the tile attractor
    content: str
    metadata: Dict[str, Any]
    timestamp: float              # used for oldest-first buffer eviction and decay age
    stability_score: float = 0.0  # eviction key for tile slots; scaled down by tile decay
    access_count: int = 0         # incremented on every local search hit
    last_access: float = 0.0      # time.time() of the most recent search hit
@dataclass
class TileStats:
    """Per-tile counters and running statistics.

    Declared as a dataclass so every instance owns its fields and can be
    created with keyword overrides, e.g. ``TileStats(write_count=3)``.
    """

    write_count: int = 0
    read_count: int = 0
    last_renorm: float = 0.0    # presumably a timestamp of the last renormalisation -- confirm with callers
    avg_stability: float = 0.0
class Tile:
    """Individual tile in the quadtree storing memories and local parameters.

    A tile holds three item pools:

    * ``slots`` -- persistent items, capped at ``max_slots``; the least
      stable item is evicted first.
    * ``buffer`` -- short-term items, capped at ``max_buffer``; the oldest
      item is evicted first.
    * ``plastic`` -- an uncapped boundary band.

    It also holds an ``attractor`` vector, randomly initialised local map
    parameters, and one cosine-space HNSW index over every item added to
    any pool.

    HNSW bookkeeping: each indexed item gets a fresh, never-reused integer
    label.  ``_hnsw_label_map`` resolves search hits (label -> item), so
    results stay correct after removals.  Removed items are
    ``mark_deleted`` in the index so they are never returned again.
    ``hnsw_items`` lists the currently indexed items in insertion order.
    """

    def __init__(self, tile_id: str, bounds: Tuple[complex, complex],
                 embedding_dim: int = 768, max_slots: int = 32,
                 max_buffer: int = 64):
        self.tile_id = tile_id
        self.bounds = bounds  # (bottom_left, top_right) complex coordinates
        self.embedding_dim = embedding_dim
        self.max_slots = max_slots
        self.max_buffer = max_buffer

        # Core storage
        self.attractor = torch.randn(embedding_dim) * 0.1  # Persistent prototype
        self.slots: List["MemoryItem"] = []    # Persistent items
        self.buffer: List["MemoryItem"] = []   # Short-term items
        self.plastic: List["MemoryItem"] = []  # Boundary band items

        # Local parameters for iterative map
        self.local_params = torch.randn(embedding_dim * 2, embedding_dim) * 0.01
        self.bias = torch.randn(embedding_dim) * 0.01

        # HNSW index for fast retrieval.  The initial capacity is also the
        # minimum capacity used whenever the index is compacted.
        self._hnsw_capacity = max(1, max_slots + max_buffer)
        self.hnsw_index = self._new_hnsw_index(self._hnsw_capacity)
        self.hnsw_items: List["MemoryItem"] = []
        self._hnsw_label_map: Dict[int, "MemoryItem"] = {}
        self._next_label = 0

        # Statistics and policies
        self.stats = TileStats()
        self.policies: Dict[str, Any] = {}
        self.half_life = 86400.0  # seconds (1 day default)

    def get_center(self) -> complex:
        """Get center coordinate of tile."""
        bl, tr = self.bounds
        return (bl + tr) / 2

    def contains(self, coord: complex) -> bool:
        """Check if coordinate is within tile bounds (edges inclusive)."""
        bl, tr = self.bounds
        return (bl.real <= coord.real <= tr.real and
                bl.imag <= coord.imag <= tr.imag)

    def add_to_slots(self, item: "MemoryItem") -> bool:
        """Add ``item`` to persistent slots, evicting the least stable when full.

        The item is also indexed in HNSW, and the attractor is pulled toward
        its vector.  Always returns True.
        """
        if self.slots and len(self.slots) >= self.max_slots:
            # Evict by position so exactly that object is removed; ties go
            # to the earliest-positioned item.
            weakest = min(range(len(self.slots)),
                          key=lambda i: self.slots[i].stability_score)
            self._remove_from_hnsw(self.slots.pop(weakest))
        self.slots.append(item)
        self._add_to_hnsw(item)
        self._update_attractor(item.vector)
        return True

    def add_to_buffer(self, item: "MemoryItem") -> bool:
        """Add ``item`` to the short-term buffer, evicting the oldest when full.

        Always returns True.
        """
        if self.buffer and len(self.buffer) >= self.max_buffer:
            # Positional pop instead of list.remove: remove() matches by ==,
            # which may hit a different (equal) item or compare tensors.
            oldest = min(range(len(self.buffer)),
                         key=lambda i: self.buffer[i].timestamp)
            self._remove_from_hnsw(self.buffer.pop(oldest))
        self.buffer.append(item)
        self._add_to_hnsw(item)
        return True

    def add_to_plastic(self, item: "MemoryItem") -> bool:
        """Add item to the (uncapped) plastic boundary band.  Always returns True."""
        self.plastic.append(item)
        self._add_to_hnsw(item)
        return True

    def _new_hnsw_index(self, capacity: int):
        """Create an empty cosine HNSW index sized for ``capacity`` elements."""
        index = hnswlib.Index(space='cosine', dim=self.embedding_dim)
        index.init_index(max_elements=capacity, ef_construction=200, M=16)
        return index

    @staticmethod
    def _to_numpy(vector: torch.Tensor) -> np.ndarray:
        """Return ``vector`` as a (1, dim) float32 row, detached and on the CPU."""
        return vector.detach().cpu().numpy().astype(np.float32).reshape(1, -1)

    def _ensure_hnsw_capacity(self) -> None:
        """Guarantee room for one more element in the HNSW index.

        Deleted elements still occupy capacity, and the plastic band is
        unbounded.  So once the index is full, it is rebuilt from the live
        items only (keeping their labels), with capacity at least double the
        live count.
        """
        if self.hnsw_index.get_current_count() < self.hnsw_index.get_max_elements():
            return
        live = list(self._hnsw_label_map.items())
        capacity = max(self._hnsw_capacity, 2 * (len(live) + 1))
        index = self._new_hnsw_index(capacity)
        if live:
            vectors = np.vstack([self._to_numpy(stored.vector) for _, stored in live])
            index.add_items(vectors, [label for label, _ in live])
        # Swap in only after a successful rebuild.
        self.hnsw_index = index

    def _add_to_hnsw(self, item: "MemoryItem"):
        """Index ``item`` under a fresh label (best effort: errors are printed)."""
        try:
            self._ensure_hnsw_capacity()
            label = self._next_label
            self.hnsw_index.add_items(self._to_numpy(item.vector), [label])
        except Exception as e:
            print(f"HNSW add error: {e}")
            return
        self._next_label += 1
        self._hnsw_label_map[label] = item
        self.hnsw_items.append(item)

    def _remove_from_hnsw(self, item: "MemoryItem"):
        """Un-index one occurrence of ``item`` (matched by identity).

        The item's oldest label is dropped from the label map and marked
        deleted in the index.  The first identical entry is removed from
        ``hnsw_items``.  Items that were never indexed are ignored.
        """
        label = next((lbl for lbl, stored in self._hnsw_label_map.items()
                      if stored is item), None)
        if label is None:
            return
        del self._hnsw_label_map[label]
        for pos, stored in enumerate(self.hnsw_items):
            if stored is item:
                del self.hnsw_items[pos]
                break
        try:
            self.hnsw_index.mark_deleted(label)
        except Exception as e:
            print(f"HNSW remove error: {e}")

    def _update_attractor(self, vector: torch.Tensor, alpha: float = 0.1):
        """Move the attractor toward ``vector`` (exponential moving average, weight ``alpha``)."""
        self.attractor = (1 - alpha) * self.attractor + alpha * vector

    def search_local(self, query_vector: torch.Tensor,
                     k: int = 5) -> List[Tuple["MemoryItem", float]]:
        """Return up to ``k`` ``(item, cosine_similarity)`` pairs from this tile.

        Each hit has ``access_count`` incremented and ``last_access`` set to
        the current time.  Returns [] when nothing is indexed or the query
        fails (the error is printed).
        """
        live = len(self._hnsw_label_map)
        if live == 0 or k <= 0:
            return []
        k_eff = min(k, live)  # hnswlib raises if asked for more than the live elements
        try:
            self.hnsw_index.set_ef(max(k_eff * 2, 50))
            labels, distances = self.hnsw_index.knn_query(
                self._to_numpy(query_vector), k=k_eff
            )
        except Exception as e:
            print(f"HNSW search error: {e}")
            return []
        now = time.time()
        results = []
        for label, dist in zip(labels[0], distances[0]):
            item = self._hnsw_label_map.get(int(label))
            if item is None:
                continue
            item.access_count += 1
            item.last_access = now
            results.append((item, 1.0 - float(dist)))  # cosine distance -> similarity
        return results

    def get_all_items(self) -> List["MemoryItem"]:
        """Get all items in tile (slots, then buffer, then plastic)."""
        return self.slots + self.buffer + self.plastic

    def apply_decay(self, current_time: float):
        """Apply temporal decay to buffer items, dropping those that are too old.

        ``decay_factor = exp(-age / half_life)``.  Items whose factor falls
        below 0.1 (age > ~2.3 * half_life) are removed from the buffer and
        the index; the others have ``stability_score`` multiplied by it.

        NOTE(review): ``half_life`` acts as an e-folding time (factor ~0.37,
        not 0.5, at age == half_life).  Because the factor uses the item's
        full age and is multiplied in on every call, repeated calls compound
        the decay.  Confirm both are intended.
        """
        kept = []
        for item in self.buffer:
            decay_factor = np.exp(-(current_time - item.timestamp) / self.half_life)
            if decay_factor < 0.1:  # Remove very old items
                self._remove_from_hnsw(item)
            else:
                item.stability_score *= decay_factor
                kept.append(item)
        self.buffer = kept
class QuadTree:
    """Hierarchical quadtree for fractal memory organization.

    Tile ids encode the path from the root.  The root is ``"root"``, and
    each child id appends one quadrant digit to its parent's id (0 =
    bottom-left, 1 = bottom-right, 2 = top-left, 3 = top-right).
    ``tile_hierarchy`` maps each internal tile id to its four child ids.
    """

    def __init__(self, depth: int = 6, embedding_dim: int = 768,
                 bounds: Tuple[complex, complex] = (complex(-2, -2), complex(2, 2))):
        """Build a complete tree with ``depth`` subdivision levels below the root.

        Args:
            depth: number of subdivision levels; leaves sit at this level.
            embedding_dim: embedding size passed to every Tile.
            bounds: (bottom_left, top_right) corners of the root tile.

        NOTE(review): the tree is fully materialised, i.e.
        (4**(depth + 1) - 1) / 3 tiles (5461 at depth 6).  Every Tile
        allocates its own parameter tensors and HNSW index, so confirm the
        memory cost is acceptable for the chosen depth and embedding_dim.
        """
        self.depth = depth
        self.embedding_dim = embedding_dim
        self.root_bounds = bounds
        self.tiles: Dict[str, "Tile"] = {}
        self.tile_hierarchy: Dict[str, List[str]] = defaultdict(list)
        # Build tree structure
        self._build_tree()

    @staticmethod
    def _child_bounds(tile_id: str, bounds: Tuple[complex, complex]
                      ) -> List[Tuple[str, Tuple[complex, complex]]]:
        """Return ``(child_id, (bottom_left, top_right))`` for the four quadrants."""
        bl, tr = bounds
        mid_real = (bl.real + tr.real) / 2
        mid_imag = (bl.imag + tr.imag) / 2
        mid = complex(mid_real, mid_imag)
        return [
            (f"{tile_id}0", (bl, mid)),                                                 # Bottom-left
            (f"{tile_id}1", (complex(mid_real, bl.imag), complex(tr.real, mid_imag))),  # Bottom-right
            # Top-left: the upper edge is tr.imag.  Using tr.real was only
            # correct when the top-right corner had equal real/imag parts.
            (f"{tile_id}2", (complex(bl.real, mid_imag), complex(mid_real, tr.imag))),
            (f"{tile_id}3", (mid, tr)),                                                 # Top-right
        ]

    def _build_tree(self):
        """Create every tile depth-first: parent before children, in quadrant order."""
        def build_recursive(tile_id: str, bounds: Tuple[complex, complex], level: int):
            self.tiles[tile_id] = Tile(tile_id, bounds, self.embedding_dim)
            if level >= self.depth:
                return
            for child_id, child_bounds in self._child_bounds(tile_id, bounds):
                self.tile_hierarchy[tile_id].append(child_id)
                build_recursive(child_id, child_bounds, level + 1)

        build_recursive("root", self.root_bounds, 0)

    def route_to_leaf(self, coord: complex) -> List[str]:
        """Return the tile ids from the root down to the leaf containing ``coord``.

        At each level the first child whose (edge-inclusive) bounds contain
        ``coord`` is taken.  If none does, the child with the nearest center
        is taken, so out-of-bounds coordinates still reach a leaf.
        """
        path = ["root"]
        current = "root"
        for _ in range(self.depth):
            # .get() so routing never inserts empty entries into the defaultdict.
            children = self.tile_hierarchy.get(current)
            if not children:
                break
            nxt = next((cid for cid in children if self.tiles[cid].contains(coord)), None)
            if nxt is None:
                nxt = min(children, key=lambda cid: abs(coord - self.tiles[cid].get_center()))
            path.append(nxt)
            current = nxt
        return path

    def get_leaf_tiles(self) -> List["Tile"]:
        """Get all tiles that have no children."""
        return [tile for tile_id, tile in self.tiles.items()
                if not self.tile_hierarchy.get(tile_id)]

    def get_neighbors(self, tile_id: str, radius: int = 1) -> List["Tile"]:
        """Get neighboring tiles (Julia-neighbors).

        Returns every other tile, at any level (including ancestors and
        descendants), whose center lies within Euclidean distance ``radius``
        of this tile's center, in coordinate units.
        """
        center = self.tiles[tile_id].get_center()
        return [other for other_id, other in self.tiles.items()
                if other_id != tile_id and abs(center - other.get_center()) <= radius]

    def get_tile_path_to_root(self, tile_id: str) -> List[str]:
        """Return the ids from ``"root"`` down to ``tile_id`` (inclusive).

        A parent id is obtained by dropping the last character.  Any id of
        length <= 4 other than ``"root"`` is treated as a direct child of
        the root.
        """
        path = []
        current = tile_id
        while current:
            path.append(current)
            if current == "root":
                break
            current = current[:-1] if len(current) > 4 else "root"
        return list(reversed(path))

    def get_statistics(self) -> Dict[str, Any]:
        """Get tree statistics.

        ``total_items`` counts items in every tile, internal ones included,
        while ``avg_items_per_leaf`` divides that total by the leaf count.
        """
        total_items = sum(len(tile.get_all_items()) for tile in self.tiles.values())
        leaf_tiles = self.get_leaf_tiles()
        return {
            "total_tiles": len(self.tiles),
            "leaf_tiles": len(leaf_tiles),
            "total_items": total_items,
            "avg_items_per_leaf": total_items / len(leaf_tiles) if leaf_tiles else 0,
            "depth": self.depth
        }