# Mandelmem / mandelmem/retrieval.py
# Uploaded by Kossisoroyce ("Upload 10 files"), commit c05fcc5 (verified).
"""
Multi-scale retrieval system for MandelMem.
Implements zoom and hop functionality for hierarchical memory access.
"""
import torch
import numpy as np
from typing import List, Dict, Any, Tuple, Optional
from dataclasses import dataclass
import time
from .quadtree import QuadTree, Tile, MemoryItem
from .encoders import MultiModalEncoder
@dataclass
class RetrievalResult:
    """Outcome of a single memory retrieval call.

    Fields, in constructor order:
        items: Retrieved memory items, best match first.
        similarities: Score of each entry in ``items`` (same order).
        trace: Tile ids chosen while routing down the quadtree.
        stability_scores: ``stability_score`` of each entry in ``items``.
        hops: Labels for candidates gathered via Julia-neighbor hops.
        confidence: Confidence estimate for the local (leaf-tile) results.
        total_time: Elapsed time of the retrieval, in seconds.
    """

    # --- ranked output -------------------------------------------------
    items: List[MemoryItem]
    similarities: List[float]

    # --- provenance ----------------------------------------------------
    trace: List[str]  # routing path through the tile hierarchy
    stability_scores: List[float]
    hops: List[str]  # Julia-neighbor hop labels

    # --- summary -------------------------------------------------------
    confidence: float
    total_time: float
@dataclass
class RoutingPotential:
    """Score used to pick a child tile while routing a query.

    Fields, in constructor order:
        tile_id: Id of the candidate tile.
        spatial_distance: Distance between the query coordinate and the tile center.
        semantic_similarity: Cosine similarity between the query and the tile attractor.
        combined_potential: Weighted mix of the two terms; lower is preferred.
    """

    tile_id: str
    spatial_distance: float  # |u_q - mu_t|
    semantic_similarity: float  # cos(v_q, A_t)
    combined_potential: float  # routing picks the minimum
class MultiScaleRetriever:
    """Implements multi-scale retrieval with zoom and hop functionality.

    A query is encoded into a vector and a complex coordinate. It is routed
    ("zoom") from the ``"root"`` tile down ``quadtree.tile_hierarchy`` by
    choosing the child with minimum routing potential. The tile where routing
    stops is searched locally. When the local results look unconfident,
    neighbouring tiles returned by ``quadtree.get_neighbors`` are searched as
    well ("hop").
    """

    def __init__(self, quadtree: QuadTree, encoder: MultiModalEncoder,
                 alpha: float = 1.0, beta: float = 1.0):
        """Create a retriever.

        Args:
            quadtree: Tile store providing ``tiles``, ``tile_hierarchy`` and
                ``get_neighbors``.
            encoder: Encoder whose ``encode`` result exposes ``vector`` and
                ``complex_coord``.
            alpha: Weight of the spatial-distance term of the routing potential.
            beta: Weight of the semantic (attractor) term of the routing potential.
        """
        self.quadtree = quadtree
        self.encoder = encoder
        self.alpha = alpha  # Spatial weight
        self.beta = beta  # Semantic weight

    def retrieve(self, query: str, k: int = 5, with_trace: bool = True,
                 confidence_threshold: float = 0.3, max_hops: int = 3) -> RetrievalResult:
        """Retrieve the top-``k`` memory items for ``query``.

        Args:
            query: Query text to encode.
            k: Maximum number of items returned.
            with_trace: When False, the result's ``trace`` is left empty.
            confidence_threshold: Neighbour hops are attempted only when the
                confidence of the local results is below this value.
            max_hops: Maximum number of neighbour-expansion rounds.

        Returns:
            A RetrievalResult whose items are sorted by descending similarity.
            ``hops`` holds one ``"hop_<i>"`` label per candidate contributed by
            neighbour tiles, counted before truncation to ``k``.
        """
        # perf_counter is monotonic, unlike time.time(), so the measured
        # duration cannot be distorted by wall-clock adjustments.
        start_time = time.perf_counter()

        encoding = self.encoder.encode(query)
        query_vector = encoding.vector
        query_coord = encoding.complex_coord

        # Zoom: descend the hierarchy. An empty path means the root has no
        # children (or is not in the hierarchy), so the root itself is searched
        # instead of failing on an empty trace.
        routing_path = self._route_to_leaf(query_coord, query_vector)
        leaf_id = routing_path[-1].tile_id if routing_path else "root"
        leaf_tile = self.quadtree.tiles[leaf_id]
        trace = [step.tile_id for step in routing_path] if with_trace else []

        # Over-fetch (2k) locally. Copy so the extend/sort below never mutates
        # a list the tile might own.
        candidates = list(leaf_tile.search_local(query_vector, k * 2))
        confidence = self._calculate_confidence(candidates)

        # Hop: widen the search to neighbouring tiles when confidence is low.
        hops: List[str] = []
        if confidence < confidence_threshold and max_hops > 0:
            hop_results = self._perform_hops(leaf_tile, query_vector, k, max_hops)
            candidates.extend(hop_results)
            hops = [f"hop_{i}" for i in range(len(hop_results))]

        candidates.sort(key=lambda pair: pair[1], reverse=True)
        final_results = candidates[:k]

        items = [item for item, _ in final_results]
        similarities = [sim for _, sim in final_results]
        stability_scores = [item.stability_score for item in items]

        total_time = time.perf_counter() - start_time
        return RetrievalResult(
            items=items,
            similarities=similarities,
            trace=trace,
            stability_scores=stability_scores,
            hops=hops,
            confidence=confidence,
            total_time=total_time,
        )

    def _route_to_leaf(self, query_coord: complex, query_vector: torch.Tensor) -> List[RoutingPotential]:
        """Route from ``"root"`` toward a leaf by greedily minimising routing potential.

        Returns:
            The RoutingPotential of each tile chosen, in descent order. The root
            itself is not included, so the list is empty when the root has no
            children or is absent from ``tile_hierarchy``.
        """
        path: List[RoutingPotential] = []
        current_tile_id = "root"
        while current_tile_id in self.quadtree.tile_hierarchy:
            children = self.quadtree.tile_hierarchy[current_tile_id]
            if not children:
                break
            potentials = [
                self._calculate_routing_potential(
                    self.quadtree.tiles[child_id], query_coord, query_vector
                )
                for child_id in children
            ]
            # Lower potential = closer in space and more similar to the attractor.
            best_potential = min(potentials, key=lambda p: p.combined_potential)
            path.append(best_potential)
            current_tile_id = best_potential.tile_id
        return path

    def _calculate_routing_potential(self, tile: Tile, query_coord: complex,
                                     query_vector: torch.Tensor) -> RoutingPotential:
        """Calculate the routing potential of ``tile`` for a query.

        The potential is
        ``alpha * |u_q - mu_t| + beta * (1 - cos(v_q, A_t))``, where ``mu_t`` is
        the tile center and ``A_t`` the tile attractor. Lower is better.
        """
        # Spatial component
        tile_center = tile.get_center()
        spatial_distance = abs(query_coord - tile_center)

        # Semantic component (cosine similarity with the tile attractor)
        semantic_similarity = torch.cosine_similarity(
            query_vector.unsqueeze(0),
            tile.attractor.unsqueeze(0)
        ).item()

        # Similarity is turned into a distance so both terms are "lower is better".
        combined_potential = (self.alpha * spatial_distance +
                              self.beta * (1.0 - semantic_similarity))

        return RoutingPotential(
            tile_id=tile.tile_id,
            spatial_distance=spatial_distance,
            semantic_similarity=semantic_similarity,
            combined_potential=combined_potential
        )

    def _calculate_confidence(self, results: List[Tuple[MemoryItem, float]]) -> float:
        """Estimate retrieval confidence from ``(item, similarity)`` pairs.

        Computes ``max_sim * (1 - std / mean)``, clamped to ``[0, 1]``. The
        score is higher for a strong best match and a tight score
        distribution. Returns 0.0 for no results.
        """
        if not results:
            return 0.0

        similarities = [sim for _, sim in results]
        max_sim = max(similarities)
        mean_sim = np.mean(similarities)
        std_sim = np.std(similarities) if len(similarities) > 1 else 0.0

        # 1e-6 guards against division by zero when the mean is exactly zero.
        confidence = max_sim * (1.0 - std_sim / (mean_sim + 1e-6))
        # Plain float rather than a numpy scalar, for callers and serialisation.
        return float(min(1.0, max(0.0, confidence)))

    def _perform_hops(self, leaf_tile: Tile, query_vector: torch.Tensor,
                      k: int, max_hops: int,
                      neighbors_per_hop: int = 2) -> List[Tuple[MemoryItem, float]]:
        """Perform Julia-neighbor hops for cross-cluster recall.

        Each round collects the not-yet-visited neighbours of every visited
        tile. It ranks them by affinity to ``leaf_tile`` and searches the best
        ``neighbors_per_hop`` of them. Those tiles are then marked visited.

        Returns:
            ``(item, similarity * affinity)`` pairs from all searched neighbours.
        """
        hop_results: List[Tuple[MemoryItem, float]] = []
        visited_tiles = {leaf_tile.tile_id}
        # k // 2 is 0 for k == 1, which would make every hop return nothing.
        per_tile_k = max(1, k // 2)

        for _ in range(max_hops):
            # Unvisited neighbours keyed by tile id, so a tile bordering several
            # visited tiles is considered (and searched) only once. Iterating
            # visited ids in sorted order keeps tie-breaking deterministic.
            frontier: Dict[str, Tile] = {}
            for tile_id in sorted(visited_tiles):
                for neighbor in self.quadtree.get_neighbors(tile_id):
                    if neighbor.tile_id not in visited_tiles:
                        frontier.setdefault(neighbor.tile_id, neighbor)

            if not frontier:
                break

            scored = [
                (neighbor, self._calculate_cross_tile_affinity(leaf_tile, neighbor, query_vector))
                for neighbor in frontier.values()
            ]
            scored.sort(key=lambda pair: pair[1], reverse=True)

            for neighbor, affinity in scored[:neighbors_per_hop]:
                # Down-weight hop results by how related the neighbour is.
                hop_results.extend(
                    (item, sim * affinity)
                    for item, sim in neighbor.search_local(query_vector, per_tile_k)
                )
                visited_tiles.add(neighbor.tile_id)

        return hop_results

    def _calculate_cross_tile_affinity(self, source_tile: Tile, target_tile: Tile,
                                       query_vector: torch.Tensor) -> float:
        """Score how worthwhile hopping from ``source_tile`` to ``target_tile`` is.

        Weighted mix: 0.4 * attractor similarity + 0.4 * query relevance of the
        target attractor + 0.2 * spatial proximity ``1 / (1 + distance)``.
        Negative totals are clamped to 0.
        """
        attractor_sim = torch.cosine_similarity(
            source_tile.attractor.unsqueeze(0),
            target_tile.attractor.unsqueeze(0)
        ).item()

        query_relevance = torch.cosine_similarity(
            query_vector.unsqueeze(0),
            target_tile.attractor.unsqueeze(0)
        ).item()

        spatial_distance = abs(source_tile.get_center() - target_tile.get_center())
        spatial_weight = 1.0 / (1.0 + spatial_distance)

        affinity = 0.4 * attractor_sim + 0.4 * query_relevance + 0.2 * spatial_weight
        return max(0.0, affinity)
class ContextualRetriever:
    """Wrap a MultiScaleRetriever with a sliding window of recent context.

    Each context string is remembered together with its encoding (at most 10
    of each). The last three strings are appended to the query text. The mean
    of the stored encodings is used to re-score the base retriever's results.
    """

    def __init__(self, base_retriever: MultiScaleRetriever):
        self.base_retriever = base_retriever
        # Parallel lists: context_embeddings[i] is the encoding of context_history[i].
        self.context_history: List[str] = []
        self.context_embeddings: List[torch.Tensor] = []

    def retrieve_with_context(self, query: str, context: Optional[str] = None,
                              k: int = 5) -> RetrievalResult:
        """Retrieve up to ``k`` items for ``query``, taking stored context into account.

        A non-empty ``context`` is recorded first. The (context-augmented) query
        is then sent to the base retriever. Its results are re-ranked whenever
        any context encoding is stored.
        """
        if context:
            self._update_context(context)

        result = self.base_retriever.retrieve(self._enhance_query_with_context(query), k)
        return self._rerank_with_context(result) if self.context_embeddings else result

    def _update_context(self, context: str):
        """Record ``context`` and its encoding, keeping only the newest 10 entries."""
        window = 10
        self.context_history.append(context)
        self.context_embeddings.append(self.base_retriever.encoder.encode(context).vector)

        if len(self.context_history) > window:
            self.context_history = self.context_history[-window:]
            self.context_embeddings = self.context_embeddings[-window:]

    def _enhance_query_with_context(self, query: str) -> str:
        """Append the three most recent context strings to ``query``, if there are any."""
        if not self.context_history:
            return query
        joined = " ".join(self.context_history[-3:])
        return f"{query} [Context: {joined}]"

    def _rerank_with_context(self, result: RetrievalResult) -> RetrievalResult:
        """Blend each score with similarity to the mean context encoding and re-sort.

        New score = 0.7 * original similarity + 0.3 * cosine(item, context mean).
        ``result`` is updated in place and returned.
        """
        if not (self.context_embeddings and result.items):
            return result

        centroid = torch.stack(self.context_embeddings).mean(dim=0)

        rescored = []
        for idx, item in enumerate(result.items):
            base_sim = result.similarities[idx]
            context_sim = torch.cosine_similarity(
                item.vector.unsqueeze(0),
                centroid.unsqueeze(0)
            ).item()
            rescored.append(0.7 * base_sim + 0.3 * context_sim)

        # sorted() is stable, so equal scores keep their original relative order.
        order = sorted(range(len(rescored)), key=rescored.__getitem__, reverse=True)
        result.items = [result.items[idx] for idx in order]
        result.similarities = [rescored[idx] for idx in order]
        result.stability_scores = [result.stability_scores[idx] for idx in order]
        return result
class ExplainableRetriever:
    """Retriever with enhanced interpretability features.

    Wraps a MultiScaleRetriever and pairs its result with human-readable
    explanations of the routing path, similarity scores and stability scores,
    plus a few simple counterfactual statements.
    """

    def __init__(self, base_retriever: MultiScaleRetriever):
        self.base_retriever = base_retriever

    def retrieve_with_explanation(self, query: str, k: int = 5) -> Dict[str, Any]:
        """Retrieve up to ``k`` items for ``query`` with a detailed explanation.

        Returns:
            Dict with keys ``query``, ``results`` (the RetrievalResult),
            ``routing_explanation``, ``similarity_explanation``,
            ``stability_explanation`` and ``counterfactuals``.
        """
        result = self.base_retriever.retrieve(query, k, with_trace=True)

        explanation = {
            'query': query,
            'results': result,
            'routing_explanation': self._explain_routing(result.trace),
            'similarity_explanation': self._explain_similarities(result),
            'stability_explanation': self._explain_stability(result),
            'counterfactuals': self._generate_counterfactuals(query, result)
        }
        return explanation

    def _explain_routing(self, trace: List[str]) -> Dict[str, Any]:
        """Explain the routing path taken.

        MultiScaleRetriever starts routing at the tile ``"root"`` and records
        only the tiles it descends into, so ``trace`` normally omits the root.
        The path is anchored at ``"root"`` unless ``trace`` already starts
        there. This way the first step out of the root is explained too, and
        ``depth`` equals the number of routing steps (0 for an empty trace).

        Returns:
            Dict with ``path`` (root-anchored tile ids), ``depth`` (number of
            routing steps) and ``reasoning`` (one sentence per path entry).
        """
        path = list(trace)
        if not path or path[0] != "root":
            path.insert(0, "root")

        reasoning = ["Started at root tile"]
        reasoning.extend(
            f"Routed from {parent_id} to {tile_id} based on spatial and semantic potential"
            for parent_id, tile_id in zip(path, path[1:])
        )

        return {
            'path': path,
            'depth': len(path) - 1,
            'reasoning': reasoning
        }

    def _explain_similarities(self, result: RetrievalResult) -> List[Dict[str, Any]]:
        """Describe each retrieved item: rank, score, content preview and score factors.

        ``recency`` and ``importance`` come from the item metadata (default 0.5).
        ``access_frequency`` is ``access_count / 10`` capped at 1.0.
        """
        explanations = []
        for rank, (item, similarity) in enumerate(zip(result.items, result.similarities), start=1):
            preview = (item.content[:100] + "...") if len(item.content) > 100 else item.content
            explanations.append({
                'rank': rank,
                'similarity': similarity,
                'content_preview': preview,
                'factors': {
                    'semantic_match': similarity,
                    'recency': item.metadata.get('recency_weight', 0.5),
                    'importance': item.metadata.get('importance', 0.5),
                    'access_frequency': min(item.access_count / 10.0, 1.0)
                }
            })
        return explanations

    def _explain_stability(self, result: RetrievalResult) -> Dict[str, Any]:
        """Summarise stability scores.

        Returns the average (as a plain float), a high/medium/low count
        distribution (>0.7, 0.3-0.7, <0.3) and a textual interpretation.
        """
        if not result.stability_scores:
            return {'message': 'No stability information available'}

        avg_stability = float(np.mean(result.stability_scores))
        stability_distribution = {
            'high_stability': sum(1 for s in result.stability_scores if s > 0.7),
            'medium_stability': sum(1 for s in result.stability_scores if 0.3 <= s <= 0.7),
            'low_stability': sum(1 for s in result.stability_scores if s < 0.3)
        }

        return {
            'average_stability': avg_stability,
            'distribution': stability_distribution,
            'interpretation': self._interpret_stability(avg_stability)
        }

    def _interpret_stability(self, avg_stability: float) -> str:
        """Map an average stability score to a short human-readable description."""
        if avg_stability > 0.8:
            return "Very stable memories - likely to persist long-term"
        elif avg_stability > 0.6:
            return "Moderately stable memories - established but may evolve"
        elif avg_stability > 0.4:
            return "Somewhat unstable memories - in plastic boundary band"
        else:
            return "Low stability memories - may be forgotten soon"

    def _generate_counterfactuals(self, query: str, result: RetrievalResult) -> List[str]:
        """Generate counterfactual statements from the result's confidence, hops and stability.

        ``query`` is currently unused and kept for interface compatibility.
        """
        counterfactuals = []

        if result.confidence < 0.5:
            counterfactuals.append("If the query were more specific, retrieval confidence would be higher")

        # NOTE(review): result.hops has one label per hop candidate before
        # truncation to k, so this may overstate how many final results came
        # from hops.
        if len(result.hops) > 0:
            counterfactuals.append(f"Without Julia-neighbor hops, {len(result.hops)} additional results would not have been found")

        if result.stability_scores and np.mean(result.stability_scores) < 0.5:
            counterfactuals.append("If persistence thresholds were lower, more stable memories would be available")

        return counterfactuals