# Source: Mandelmem / mandelmem/quadtree.py
# Uploaded by Kossisoroyce ("Upload 10 files"), commit c05fcc5 (verified)
"""
Quadtree implementation for fractal memory indexing.
Organizes memory in hierarchical tiles with attractors and slot banks.
"""
import torch
import numpy as np
import hnswlib
from typing import List, Dict, Any, Optional, Tuple, Set
from dataclasses import dataclass, field
from collections import defaultdict
import time
@dataclass(eq=False)
class MemoryItem:
    """Individual memory item stored in tiles.

    Equality and hashing are identity-based (``eq=False``). A generated
    field-wise ``__eq__`` would compare ``vector`` with ``torch.Tensor.__eq__``,
    which is element-wise; truth-testing that result raises for vectors with
    more than one element. That made ``in`` / ``list.remove`` on lists of items
    crash, and left the class unhashable.
    """
    vector: torch.Tensor  # embedding; shape/dtype are not enforced here
    content: str
    metadata: Dict[str, Any]
    timestamp: float  # Tile.apply_decay treats (current_time - timestamp) as age
    stability_score: float = 0.0
    access_count: int = 0  # incremented by Tile.search_local on each hit
    last_access: float = 0.0  # set to time.time() by Tile.search_local
@dataclass
class TileStats:
    """Per-tile counters and running statistics; every field starts at zero."""
    write_count: int = field(default=0)  # number of writes recorded
    read_count: int = field(default=0)  # number of reads recorded
    last_renorm: float = field(default=0.0)  # time of the most recent renormalization
    avg_stability: float = field(default=0.0)  # running average stability score
class Tile:
    """Individual tile in the quadtree storing memories and local parameters.

    A tile keeps three item pools (``slots``, ``buffer``, ``plastic``) and an
    HNSW index over every item currently added to any pool, which
    :meth:`search_local` queries by cosine similarity.
    """

    def __init__(self, tile_id: str, bounds: Tuple[complex, complex],
                 embedding_dim: int = 768, max_slots: int = 32,
                 max_buffer: int = 64):
        """Create an empty tile.

        Args:
            tile_id: Identifier of the tile (e.g. ``"root"`` or ``"root03"``).
            bounds: ``(bottom_left, top_right)`` corners in the complex plane.
            embedding_dim: Dimensionality of stored vectors.
            max_slots: Capacity of the persistent ``slots`` pool.
            max_buffer: Capacity of the short-term ``buffer`` pool.
        """
        self.tile_id = tile_id
        self.bounds = bounds  # (bottom_left, top_right) complex coordinates
        self.embedding_dim = embedding_dim
        self.max_slots = max_slots
        self.max_buffer = max_buffer

        # Core storage
        self.attractor = torch.randn(embedding_dim) * 0.1  # Persistent prototype
        self.slots: List[MemoryItem] = []  # Persistent items
        self.buffer: List[MemoryItem] = []  # Short-term items
        self.plastic: List[MemoryItem] = []  # Boundary band items (unbounded)

        # Local parameters for iterative map
        self.local_params = torch.randn(embedding_dim * 2, embedding_dim) * 0.01
        self.bias = torch.randn(embedding_dim) * 0.01

        # HNSW index for fast retrieval. Labels given to hnswlib are
        # monotonically increasing integers resolved through _hnsw_labels;
        # list positions must not be used as labels because they shift
        # whenever an item is removed.
        self.hnsw_index = hnswlib.Index(space='cosine', dim=embedding_dim)
        self.hnsw_index.init_index(max_elements=max_slots + max_buffer,
                                   ef_construction=200, M=16)
        self.hnsw_items: List[MemoryItem] = []  # items currently indexed (live)
        self._hnsw_labels: Dict[int, MemoryItem] = {}  # hnswlib label -> item
        self._next_hnsw_label = 0

        # Statistics and policies
        self.stats = TileStats()
        self.policies: Dict[str, Any] = {}
        # NOTE(review): apply_decay uses this as tau in exp(-age / tau), i.e.
        # an e-folding time rather than a true half-life.
        self.half_life = 86400.0  # 1 day default

    def get_center(self) -> complex:
        """Return the midpoint of the tile's bounds."""
        bl, tr = self.bounds
        return (bl + tr) / 2

    def contains(self, coord: complex) -> bool:
        """Return True if ``coord`` lies within the bounds (edges inclusive)."""
        bl, tr = self.bounds
        return (bl.real <= coord.real <= tr.real and
                bl.imag <= coord.imag <= tr.imag)

    def add_to_slots(self, item: MemoryItem) -> bool:
        """Add ``item`` to the persistent slots, evicting the least stable item if full.

        The item is also indexed in HNSW and blended into the attractor.
        Always returns True.
        """
        if self.slots and len(self.slots) >= self.max_slots:
            # Remove least stable item (stable sort: ties evict the earliest).
            self.slots.sort(key=lambda x: x.stability_score)
            removed = self.slots.pop(0)
            self._remove_from_hnsw(removed)
        self.slots.append(item)
        self._add_to_hnsw(item)
        self._update_attractor(item.vector)
        return True

    def add_to_buffer(self, item: MemoryItem) -> bool:
        """Add ``item`` to the short-term buffer, evicting the oldest item if full.

        Always returns True.
        """
        if self.buffer and len(self.buffer) >= self.max_buffer:
            # Evict by position, not list.remove(): remove() compares with ==,
            # which must not depend on MemoryItem.__eq__ (a field-wise
            # dataclass __eq__ compares tensors element-wise and raises).
            oldest_pos = min(range(len(self.buffer)),
                             key=lambda i: self.buffer[i].timestamp)
            oldest = self.buffer.pop(oldest_pos)
            self._remove_from_hnsw(oldest)
        self.buffer.append(item)
        self._add_to_hnsw(item)
        return True

    def add_to_plastic(self, item: MemoryItem) -> bool:
        """Add ``item`` to the plastic boundary band (no capacity limit).

        Always returns True.
        """
        self.plastic.append(item)
        self._add_to_hnsw(item)
        return True

    @staticmethod
    def _vector_to_numpy(vector: torch.Tensor) -> np.ndarray:
        """Detach ``vector`` from autograd, move it to CPU and flatten it to 1-D."""
        return vector.detach().cpu().numpy().reshape(-1)

    def _rebuild_hnsw(self, capacity: int):
        """Rebuild the HNSW index from the live items with room for ``capacity`` elements.

        Items are relabelled 0..n-1, so elements that were only marked deleted
        stop occupying capacity. On failure the old index is left in place and
        the exception propagates.
        """
        capacity = max(capacity, self.max_slots + self.max_buffer,
                       len(self.hnsw_items) + 1)
        index = hnswlib.Index(space='cosine', dim=self.embedding_dim)
        index.init_index(max_elements=capacity, ef_construction=200, M=16)
        if self.hnsw_items:
            data = np.stack([self._vector_to_numpy(it.vector) for it in self.hnsw_items])
            index.add_items(data, np.arange(len(self.hnsw_items)))
        self.hnsw_index = index
        self._hnsw_labels = dict(enumerate(self.hnsw_items))
        self._next_hnsw_label = len(self.hnsw_items)

    def _add_to_hnsw(self, item: MemoryItem):
        """Index ``item`` in HNSW under a fresh label.

        This is best effort: errors are printed and the item is then not indexed.
        """
        try:
            if self.hnsw_index.get_current_count() >= self.hnsw_index.get_max_elements():
                # Marked-deleted elements still count towards capacity:
                # compact the index and leave headroom before adding.
                self._rebuild_hnsw(2 * (len(self.hnsw_items) + 1))
            label = self._next_hnsw_label
            vector_np = self._vector_to_numpy(item.vector).reshape(1, -1)
            self.hnsw_index.add_items(vector_np, [label])
        except Exception as e:
            print(f"HNSW add error: {e}")
            return
        self._next_hnsw_label += 1
        self._hnsw_labels[label] = item
        self.hnsw_items.append(item)

    def _remove_from_hnsw(self, item: MemoryItem):
        """Remove ``item`` (matched by identity) from the HNSW index and ``hnsw_items``.

        Items that are not indexed are ignored. An index error is printed, and
        the item is still dropped from the label map, so searches never return it.
        """
        label = next((lbl for lbl, indexed in self._hnsw_labels.items()
                      if indexed is item), None)
        if label is None:
            return
        del self._hnsw_labels[label]
        for pos, indexed in enumerate(self.hnsw_items):
            if indexed is item:
                del self.hnsw_items[pos]
                break
        try:
            self.hnsw_index.mark_deleted(label)
        except Exception as e:
            print(f"HNSW remove error: {e}")

    def _update_attractor(self, vector: torch.Tensor, alpha: float = 0.1):
        """Move the attractor towards ``vector``: ``(1 - alpha) * attractor + alpha * vector``."""
        self.attractor = (1 - alpha) * self.attractor + alpha * vector

    def search_local(self, query_vector: torch.Tensor, k: int = 5) -> List[Tuple[MemoryItem, float]]:
        """Return up to ``k`` ``(item, similarity)`` pairs nearest to ``query_vector``.

        Similarity is ``1 - cosine distance`` as reported by hnswlib. Each
        returned item has ``access_count`` incremented and ``last_access`` set
        to the current time. Index errors are printed, and an empty list is
        returned.
        """
        live = len(self._hnsw_labels)
        if live == 0 or k <= 0:
            return []
        k_eff = min(k, live)  # hnswlib raises if asked for more than it holds
        try:
            self.hnsw_index.set_ef(max(k_eff * 2, 50))
            query_np = self._vector_to_numpy(query_vector).reshape(1, -1)
            labels, distances = self.hnsw_index.knn_query(query_np, k=k_eff)
        except Exception as e:
            print(f"HNSW search error: {e}")
            return []

        results = []
        now = time.time()
        for label, dist in zip(labels[0], distances[0]):
            item = self._hnsw_labels.get(int(label))
            if item is None:  # e.g. mark_deleted failed for a removed item
                continue
            item.access_count += 1
            item.last_access = now
            results.append((item, 1.0 - float(dist)))  # Convert to similarity
        return results

    def get_all_items(self) -> List[MemoryItem]:
        """Return a new list of slot, buffer and plastic items, in that order."""
        return self.slots + self.buffer + self.plastic

    def apply_decay(self, current_time: float):
        """Apply temporal decay to buffer items.

        For each buffer item the factor ``exp(-(current_time - timestamp) / half_life)``
        is computed. Items with a factor below 0.1 are removed from the buffer
        and the index; the rest have ``stability_score`` multiplied by it.
        """
        # NOTE(review): the factor is recomputed from total age on every call
        # and multiplied in, so repeated calls compound -- confirm intended.
        kept = []
        for item in self.buffer:
            age = current_time - item.timestamp
            decay_factor = np.exp(-age / self.half_life)
            if decay_factor < 0.1:  # Remove very old items
                self._remove_from_hnsw(item)
            else:
                item.stability_score *= decay_factor
                kept.append(item)
        self.buffer = kept
class QuadTree:
    """Hierarchical quadtree for fractal memory organization.

    Every node, internal or leaf, is a :class:`Tile`. Tile ids are formed by
    appending one digit per level to the parent id, starting from ``"root"``:
    ``0`` = bottom-left, ``1`` = bottom-right, ``2`` = top-left,
    ``3`` = top-right quadrant.
    """

    def __init__(self, depth: int = 6, embedding_dim: int = 768,
                 bounds: Tuple[complex, complex] = (complex(-2, -2), complex(2, 2))):
        """Build a complete quadtree.

        Args:
            depth: Number of subdivision levels below the root. The tree holds
                ``(4 ** (depth + 1) - 1) // 3`` tiles.
            embedding_dim: Vector dimensionality passed to every tile.
            bounds: ``(bottom_left, top_right)`` corners of the root tile.
        """
        self.depth = depth
        self.embedding_dim = embedding_dim
        self.root_bounds = bounds
        self.tiles: Dict[str, Tile] = {}
        self.tile_hierarchy: Dict[str, List[str]] = defaultdict(list)  # parent id -> child ids
        # Build tree structure.
        # NOTE(review): every tile, internal ones included, allocates its own
        # parameters and HNSW index, so memory grows as 4 ** depth.
        self._build_tree()

    def _build_tree(self):
        """Create every tile (pre-order, depth-first) and record parent/child links."""
        def build_recursive(tile_id: str, bounds: Tuple[complex, complex], level: int):
            self.tiles[tile_id] = Tile(tile_id, bounds, self.embedding_dim)
            if level >= self.depth:
                return
            bl, tr = bounds
            mid_real = (bl.real + tr.real) / 2
            mid_imag = (bl.imag + tr.imag) / 2
            mid = complex(mid_real, mid_imag)
            children = [
                (f"{tile_id}0", (bl, mid)),  # Bottom-left
                (f"{tile_id}1", (complex(mid_real, bl.imag), complex(tr.real, mid_imag))),  # Bottom-right
                # Top-left: the top edge is tr.imag (previously tr.real by mistake).
                (f"{tile_id}2", (complex(bl.real, mid_imag), complex(mid_real, tr.imag))),
                (f"{tile_id}3", (mid, tr)),  # Top-right
            ]
            for child_id, child_bounds in children:
                self.tile_hierarchy[tile_id].append(child_id)
                build_recursive(child_id, child_bounds, level + 1)

        build_recursive("root", self.root_bounds, 0)

    def route_to_leaf(self, coord: complex) -> List[str]:
        """Return the tile ids visited while descending from the root towards ``coord``.

        At each level the first child (in quadrant-digit order) whose bounds
        contain ``coord`` is chosen; bounds are edge-inclusive. If no child
        contains it (``coord`` outside the root bounds), the child with the
        nearest center is used, with ties going to the smallest id.
        """
        path = ["root"]
        current = "root"
        for _ in range(self.depth):
            # .get avoids inserting empty entries into the defaultdict.
            children = self.tile_hierarchy.get(current, [])
            if not children:
                break
            for child_id in children:
                if self.tiles[child_id].contains(coord):
                    current = child_id
                    break
            else:
                # Coordinate outside bounds, use closest
                _, current = min(
                    (abs(coord - self.tiles[child_id].get_center()), child_id)
                    for child_id in children
                )
            path.append(current)
        return path

    def get_leaf_tiles(self) -> List[Tile]:
        """Return all tiles that have no children, in tile-creation order."""
        return [tile for tile_id, tile in self.tiles.items()
                if not self.tile_hierarchy.get(tile_id)]

    def get_neighbors(self, tile_id: str, radius: float = 1) -> List[Tile]:
        """Return every other tile whose center lies within ``radius`` of this tile's center.

        Distance is Euclidean in the complex plane. Tiles at all levels,
        including ancestors and descendants, are considered.

        Raises:
            KeyError: if ``tile_id`` is not a known tile.
        """
        center = self.tiles[tile_id].get_center()
        return [other_tile for other_id, other_tile in self.tiles.items()
                if other_id != tile_id and abs(center - other_tile.get_center()) <= radius]

    def get_tile_path_to_root(self, tile_id: str) -> List[str]:
        """Return the ids from ``"root"`` down to ``tile_id`` (root first).

        The path is derived purely from the id string: each parent is the id
        with its last quadrant digit removed. The ids are not checked against
        ``self.tiles``.
        """
        path = []
        current = tile_id
        while current:
            path.append(current)
            if current == "root":
                break
            # Parent is the prefix; anything of length <= 4 maps to "root".
            current = current[:-1] if len(current) > 4 else "root"
        return path[::-1]

    def get_statistics(self) -> Dict[str, Any]:
        """Return summary counts for the tree.

        ``total_items`` counts items in every tile (internal tiles included),
        and ``avg_items_per_leaf`` divides that total by the number of leaves.
        """
        total_items = sum(len(tile.get_all_items()) for tile in self.tiles.values())
        leaf_tiles = self.get_leaf_tiles()
        return {
            "total_tiles": len(self.tiles),
            "leaf_tiles": len(leaf_tiles),
            "total_items": total_items,
            "avg_items_per_leaf": total_items / len(leaf_tiles) if leaf_tiles else 0,
            "depth": self.depth
        }