Spaces:
Sleeping
Sleeping
File size: 16,409 Bytes
c05fcc5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 |
"""
Multi-scale retrieval system for MandelMem.
Implements zoom and hop functionality for hierarchical memory access.
"""
import torch
import numpy as np
from typing import List, Dict, Any, Tuple, Optional
from dataclasses import dataclass
import time
from .quadtree import QuadTree, Tile, MemoryItem
from .encoders import MultiModalEncoder
@dataclass
class RetrievalResult:
    """Everything one retrieval call produces, ranked best-first.

    ``items``, ``similarities`` and ``stability_scores`` are index-aligned:
    entry ``i`` of each list describes the same memory item.
    """

    items: List[MemoryItem]        # Retrieved memory items
    similarities: List[float]      # Score of each entry in ``items``
    trace: List[str]               # Tile IDs visited while routing toward a leaf
    stability_scores: List[float]  # ``stability_score`` of each entry in ``items``
    hops: List[str]                # Labels for results gathered via Julia-neighbor hops
    confidence: float              # Confidence of the leaf-tile retrieval, in [0, 1]
    total_time: float              # Elapsed seconds spent in the retrieval call
@dataclass
class RoutingPotential:
    """Score of one candidate child tile during root-to-leaf routing.

    The router picks the candidate with the smallest ``combined_potential``.
    """

    tile_id: str                # Candidate tile identifier
    spatial_distance: float     # |query coordinate - tile center|
    semantic_similarity: float  # Cosine similarity between query vector and tile attractor
    combined_potential: float   # Weighted spatial + semantic cost; lower is better
class MultiScaleRetriever:
    """Implements multi-scale retrieval with zoom and hop functionality.

    "Zoom" greedily descends the quadtree from ``"root"`` to a leaf tile using a
    routing potential that mixes spatial distance and semantic similarity.
    "Hop" widens the search to neighbouring tiles when the leaf's results look
    unreliable (low confidence).
    """

    def __init__(self, quadtree: QuadTree, encoder: MultiModalEncoder,
                 alpha: float = 1.0, beta: float = 1.0):
        """
        Args:
            quadtree: Tile hierarchy to search.
            encoder: Encoder whose ``encode`` result exposes ``vector`` and
                ``complex_coord``.
            alpha: Weight of the spatial term in the routing potential.
            beta: Weight of the semantic term in the routing potential.
        """
        self.quadtree = quadtree
        self.encoder = encoder
        self.alpha = alpha  # Spatial weight
        self.beta = beta  # Semantic weight

    def retrieve(self, query: str, k: int = 5, with_trace: bool = True,
                 confidence_threshold: float = 0.3, max_hops: int = 3) -> RetrievalResult:
        """Retrieve up to ``k`` memory items for ``query``.

        Args:
            query: Free-text query, encoded with ``self.encoder``.
            k: Maximum number of items returned.
            with_trace: When False, the result's ``trace`` is left empty
                (routing still runs to locate the leaf tile).
            confidence_threshold: Julia-neighbor hops are attempted when the
                leaf-tile confidence is below this value.
            max_hops: Maximum number of hop rounds; 0 disables hopping.

        Returns:
            RetrievalResult with items sorted by descending similarity.
        """
        start_time = time.perf_counter()

        encoding = self.encoder.encode(query)
        query_vector = encoding.vector
        query_coord = encoding.complex_coord

        # Zoom: descend from the root toward a leaf by routing potential.
        routing_path = self._route_to_leaf(query_coord, query_vector)
        path_ids = [step.tile_id for step in routing_path]
        # An empty path means "root" has no children, i.e. the root itself is the leaf.
        leaf_id = path_ids[-1] if path_ids else "root"
        leaf_tile = self.quadtree.tiles[leaf_id]

        # Fetch 2k local candidates (presumably headroom for merging with hop results).
        local_results = leaf_tile.search_local(query_vector, k * 2)
        confidence = self._calculate_confidence(local_results)

        # Work on a copy so extending/sorting never mutates what search_local returned.
        candidates = list(local_results)
        hops: List[str] = []
        if confidence < confidence_threshold and max_hops > 0:
            hop_results = self._perform_hops(leaf_tile, query_vector, k, max_hops)
            candidates.extend(hop_results)
            hops = [f"hop_{i}" for i in range(len(hop_results))]

        candidates.sort(key=lambda pair: pair[1], reverse=True)
        final_results = candidates[:k]

        items = [item for item, _ in final_results]
        similarities = [sim for _, sim in final_results]
        stability_scores = [item.stability_score for item in items]

        return RetrievalResult(
            items=items,
            similarities=similarities,
            trace=path_ids if with_trace else [],
            stability_scores=stability_scores,
            hops=hops,
            confidence=confidence,
            total_time=time.perf_counter() - start_time,
        )

    def _route_to_leaf(self, query_coord: complex, query_vector: torch.Tensor) -> List[RoutingPotential]:
        """Greedily descend from ``"root"``, choosing the child with the lowest potential.

        Returns the potential of the child chosen at each level. The root itself
        is not included, so the list is empty when ``"root"`` has no children.
        """
        path: List[RoutingPotential] = []
        current_tile_id = "root"
        while current_tile_id in self.quadtree.tile_hierarchy:
            children = self.quadtree.tile_hierarchy[current_tile_id]
            if not children:
                break
            potentials = [
                self._calculate_routing_potential(
                    self.quadtree.tiles[child_id], query_coord, query_vector
                )
                for child_id in children
            ]
            # Lower potential = spatially closer and semantically more similar.
            best_potential = min(potentials, key=lambda p: p.combined_potential)
            path.append(best_potential)
            current_tile_id = best_potential.tile_id
        return path

    def _calculate_routing_potential(self, tile: Tile, query_coord: complex,
                                     query_vector: torch.Tensor) -> RoutingPotential:
        """Compute ρ(t; u_q) = α·|u_q − μ_t| + β·(1 − cos(v_q, A_t)).

        ``μ_t`` is the tile center and ``A_t`` its attractor vector; lower is better.
        """
        # Spatial component: distance in the complex plane.
        tile_center = tile.get_center()
        spatial_distance = abs(query_coord - tile_center)

        # Semantic component: cosine similarity with the tile attractor.
        semantic_similarity = torch.cosine_similarity(
            query_vector.unsqueeze(0),
            tile.attractor.unsqueeze(0)
        ).item()

        # Similarity is turned into a cost (1 - sim) so both terms are minimised.
        combined_potential = (self.alpha * spatial_distance +
                              self.beta * (1.0 - semantic_similarity))

        return RoutingPotential(
            tile_id=tile.tile_id,
            spatial_distance=spatial_distance,
            semantic_similarity=semantic_similarity,
            combined_potential=combined_potential
        )

    def _calculate_confidence(self, results: List[Tuple[MemoryItem, float]]) -> float:
        """Score retrieval confidence in [0, 1] from (item, similarity) pairs.

        confidence = max_sim · (1 − std / |mean|), clamped to [0, 1]: a high best
        score raises confidence, widely spread scores lower it. Returns 0.0 when
        ``results`` is empty.
        """
        if not results:
            return 0.0

        similarities = [sim for _, sim in results]
        max_sim = max(similarities)
        mean_sim = float(np.mean(similarities))
        std_sim = float(np.std(similarities)) if len(similarities) > 1 else 0.0

        # Coefficient of variation. abs() stops a negative mean from flipping the
        # spread penalty into a bonus; the epsilon avoids division by zero.
        dispersion = std_sim / (abs(mean_sim) + 1e-6)
        confidence = max_sim * (1.0 - dispersion)
        return float(min(1.0, max(0.0, confidence)))

    def _perform_hops(self, leaf_tile: Tile, query_vector: torch.Tensor,
                      k: int, max_hops: int) -> List[Tuple[MemoryItem, float]]:
        """Perform Julia-neighbor hops for cross-cluster recall.

        Each round collects the unvisited neighbours of all visited tiles, ranks
        them by cross-tile affinity to ``leaf_tile``, searches the best two, and
        scales their similarities by that affinity. Stops after ``max_hops``
        rounds or when no unvisited neighbour remains.
        """
        hop_results: List[Tuple[MemoryItem, float]] = []
        visited_tiles = {leaf_tile.tile_id}
        # At least one item per neighbour; k // 2 alone is 0 when k == 1.
        per_neighbor_k = max(1, k // 2)

        for _ in range(max_hops):
            # Keyed by tile_id: a tile adjacent to several visited tiles would
            # otherwise be scored, searched and counted more than once.
            frontier: Dict[str, Tile] = {}
            for tile_id in visited_tiles:
                for neighbor in self.quadtree.get_neighbors(tile_id):
                    if neighbor.tile_id not in visited_tiles:
                        frontier.setdefault(neighbor.tile_id, neighbor)

            if not frontier:
                break

            neighbor_affinities = [
                (neighbor, self._calculate_cross_tile_affinity(leaf_tile, neighbor, query_vector))
                for neighbor in frontier.values()
            ]
            neighbor_affinities.sort(key=lambda pair: pair[1], reverse=True)

            for neighbor, affinity in neighbor_affinities[:2]:  # Limit hops per iteration
                neighbor_results = neighbor.search_local(query_vector, per_neighbor_k)
                hop_results.extend((item, sim * affinity) for item, sim in neighbor_results)
                visited_tiles.add(neighbor.tile_id)

        return hop_results

    def _calculate_cross_tile_affinity(self, source_tile: Tile, target_tile: Tile,
                                       query_vector: torch.Tensor) -> float:
        """Affinity used to rank hop targets, clamped to be non-negative.

        affinity = 0.4·cos(A_src, A_tgt) + 0.4·cos(v_q, A_tgt) + 0.2 / (1 + |μ_src − μ_tgt|)
        """
        # Attractor similarity between the two tiles.
        attractor_sim = torch.cosine_similarity(
            source_tile.attractor.unsqueeze(0),
            target_tile.attractor.unsqueeze(0)
        ).item()

        # Query relevance to the target tile.
        query_relevance = torch.cosine_similarity(
            query_vector.unsqueeze(0),
            target_tile.attractor.unsqueeze(0)
        ).item()

        # Spatial proximity, mapped into (0, 1].
        spatial_distance = abs(source_tile.get_center() - target_tile.get_center())
        spatial_weight = 1.0 / (1.0 + spatial_distance)

        affinity = 0.4 * attractor_sim + 0.4 * query_relevance + 0.2 * spatial_weight
        return max(0.0, affinity)
class ContextualRetriever:
    """Enhanced retriever with contextual understanding.

    Keeps a rolling window of recent context strings. The last three are
    appended to the query text, and their mean embedding is blended into the
    final scores to re-rank the base retriever's results.
    """

    def __init__(self, base_retriever: MultiScaleRetriever, max_context: int = 10,
                 context_weight: float = 0.3):
        """
        Args:
            base_retriever: Retriever that performs the actual search.
            max_context: Number of most recent context entries to keep.
            context_weight: Share of the re-ranking score taken from context
                similarity; the base similarity gets ``1 - context_weight``.
        """
        self.base_retriever = base_retriever
        self.max_context = max_context
        self.context_weight = context_weight
        self.context_history: List[str] = []
        self.context_embeddings: List[torch.Tensor] = []

    def retrieve_with_context(self, query: str, context: Optional[str] = None,
                              k: int = 5) -> RetrievalResult:
        """Retrieve ``k`` items for ``query``, optionally recording new ``context`` first."""
        if context:
            self._update_context(context)

        enhanced_query = self._enhance_query_with_context(query)
        result = self.base_retriever.retrieve(enhanced_query, k)

        if self.context_embeddings:
            result = self._rerank_with_context(result)

        return result

    def _update_context(self, context: str):
        """Append ``context`` and its embedding, trimming to the last ``max_context`` entries."""
        # Encode before touching any state so a failing encoder cannot leave
        # context_history and context_embeddings with different lengths.
        embedding = self.base_retriever.encoder.encode(context).vector
        self.context_history.append(context)
        self.context_embeddings.append(embedding)

        # Computing the overflow explicitly avoids the `[-0:]` trap, which
        # would keep the whole list when max_context is 0.
        overflow = len(self.context_history) - self.max_context
        if overflow > 0:
            self.context_history = self.context_history[overflow:]
            self.context_embeddings = self.context_embeddings[overflow:]

    def _enhance_query_with_context(self, query: str) -> str:
        """Append the three most recent context strings to ``query``."""
        if not self.context_history:
            return query

        # Simple context enhancement (can be made more sophisticated)
        recent_context = " ".join(self.context_history[-3:])
        return f"{query} [Context: {recent_context}]"

    def _rerank_with_context(self, result: RetrievalResult) -> RetrievalResult:
        """Re-rank ``result`` in place by blending base and context similarity.

        Each score becomes ``(1 - w)·base + w·cos(item, mean context)``, where
        ``w`` is ``context_weight``.
        """
        if not self.context_embeddings or not result.items:
            return result

        context_vector = torch.mean(torch.stack(self.context_embeddings), dim=0)
        base_weight = 1.0 - self.context_weight

        new_similarities = []
        for item, base_sim in zip(result.items, result.similarities):
            context_sim = torch.cosine_similarity(
                item.vector.unsqueeze(0),
                context_vector.unsqueeze(0)
            ).item()
            new_similarities.append(base_weight * base_sim + self.context_weight * context_sim)

        # Stable descending sort: ties keep their original relative order.
        order = sorted(range(len(new_similarities)),
                       key=lambda i: new_similarities[i], reverse=True)
        result.items = [result.items[i] for i in order]
        result.similarities = [new_similarities[i] for i in order]
        result.stability_scores = [result.stability_scores[i] for i in order]

        return result
class ExplainableRetriever:
    """Retriever with enhanced interpretability features.

    Wraps a MultiScaleRetriever and returns its result together with routing,
    similarity, stability and counterfactual explanations.
    """

    def __init__(self, base_retriever: MultiScaleRetriever):
        self.base_retriever = base_retriever

    def retrieve_with_explanation(self, query: str, k: int = 5) -> Dict[str, Any]:
        """Retrieve ``k`` items for ``query`` and bundle them with explanations.

        Returns:
            Dict with keys ``query``, ``results`` (the RetrievalResult),
            ``routing_explanation``, ``similarity_explanation``,
            ``stability_explanation`` and ``counterfactuals``.
        """
        result = self.base_retriever.retrieve(query, k, with_trace=True)

        explanation = {
            'query': query,
            'results': result,
            'routing_explanation': self._explain_routing(result.trace),
            'similarity_explanation': self._explain_similarities(result),
            'stability_explanation': self._explain_stability(result),
            'counterfactuals': self._generate_counterfactuals(query, result)
        }

        return explanation

    def _explain_routing(self, trace: List[str]) -> Dict[str, Any]:
        """Explain the routing path taken.

        MultiScaleRetriever's trace lists only the tiles chosen below the root,
        so routing is described as starting from an implicit ``"root"``. A trace
        that already begins with ``"root"`` is also accepted. ``depth`` is the
        number of routing steps below the root (0 for an empty trace).
        """
        steps = trace[1:] if trace and trace[0] == "root" else trace

        reasoning = ["Started at root tile"]
        parent_id = "root"
        for tile_id in steps:
            reasoning.append(
                f"Routed from {parent_id} to {tile_id} based on spatial and semantic potential"
            )
            parent_id = tile_id

        return {
            'path': trace,
            'depth': len(steps),
            'reasoning': reasoning
        }

    def _explain_similarities(self, result: RetrievalResult) -> List[Dict[str, Any]]:
        """Explain similarity scores for retrieved items, one dict per rank."""
        explanations = []

        for i, (item, similarity) in enumerate(zip(result.items, result.similarities)):
            # Truncate long content to a 100-character preview.
            if len(item.content) > 100:
                preview = item.content[:100] + "..."
            else:
                preview = item.content
            explanations.append({
                'rank': i + 1,
                'similarity': similarity,
                'content_preview': preview,
                'factors': {
                    'semantic_match': similarity,
                    'recency': item.metadata.get('recency_weight', 0.5),
                    'importance': item.metadata.get('importance', 0.5),
                    # Saturates at 1.0 once an item has been accessed 10 times.
                    'access_frequency': min(item.access_count / 10.0, 1.0)
                }
            })

        return explanations

    def _explain_stability(self, result: RetrievalResult) -> Dict[str, Any]:
        """Summarise stability scores: mean, high/medium/low buckets, and a verdict."""
        if not result.stability_scores:
            return {'message': 'No stability information available'}

        # Cast so callers get a plain float rather than a numpy scalar.
        avg_stability = float(np.mean(result.stability_scores))
        stability_distribution = {
            'high_stability': sum(1 for s in result.stability_scores if s > 0.7),
            'medium_stability': sum(1 for s in result.stability_scores if 0.3 <= s <= 0.7),
            'low_stability': sum(1 for s in result.stability_scores if s < 0.3)
        }

        return {
            'average_stability': avg_stability,
            'distribution': stability_distribution,
            'interpretation': self._interpret_stability(avg_stability)
        }

    def _interpret_stability(self, avg_stability: float) -> str:
        """Map an average stability score to a human-readable verdict."""
        if avg_stability > 0.8:
            return "Very stable memories - likely to persist long-term"
        elif avg_stability > 0.6:
            return "Moderately stable memories - established but may evolve"
        elif avg_stability > 0.4:
            return "Somewhat unstable memories - in plastic boundary band"
        else:
            return "Low stability memories - may be forgotten soon"

    def _generate_counterfactuals(self, query: str, result: RetrievalResult) -> List[str]:
        """Generate counterfactual explanations from confidence, hops and stability.

        ``query`` is currently unused.
        """
        counterfactuals = []

        if result.confidence < 0.5:
            counterfactuals.append("If the query were more specific, retrieval confidence would be higher")

        if len(result.hops) > 0:
            counterfactuals.append(f"Without Julia-neighbor hops, {len(result.hops)} additional results would not have been found")

        if result.stability_scores and np.mean(result.stability_scores) < 0.5:
            counterfactuals.append("If persistence thresholds were lower, more stable memories would be available")

        return counterfactuals
|