| """ | |
| Multi-scale retrieval system for MandelMem. | |
| Implements zoom and hop functionality for hierarchical memory access. | |
| """ | |
| import torch | |
| import numpy as np | |
| from typing import List, Dict, Any, Tuple, Optional | |
| from dataclasses import dataclass | |
| import time | |
| from .quadtree import QuadTree, Tile, MemoryItem | |
| from .encoders import MultiModalEncoder | |

@dataclass
class RetrievalResult:
    """Result of a memory retrieval operation."""
    items: List[MemoryItem]
    similarities: List[float]
    trace: List[str]                 # Tile path taken from root to leaf
    stability_scores: List[float]
    hops: List[str]                  # Julia-neighbor hops performed
    confidence: float
    total_time: float

@dataclass
class RoutingPotential:
    """Routing potential for tile selection."""
    tile_id: str
    spatial_distance: float
    semantic_similarity: float
    combined_potential: float

class MultiScaleRetriever:
    """Implements multi-scale retrieval with zoom and hop functionality."""

    def __init__(self, quadtree: QuadTree, encoder: MultiModalEncoder,
                 alpha: float = 1.0, beta: float = 1.0):
        self.quadtree = quadtree
        self.encoder = encoder
        self.alpha = alpha  # Spatial weight
        self.beta = beta    # Semantic weight

    def retrieve(self, query: str, k: int = 5, with_trace: bool = True,
                 confidence_threshold: float = 0.3, max_hops: int = 3) -> RetrievalResult:
        """Main retrieval interface."""
        start_time = time.time()

        # Encode the query into a semantic vector and a complex-plane coordinate
        encoding = self.encoder.encode(query)
        query_vector = encoding.vector
        query_coord = encoding.complex_coord

        # Route from the root down to a leaf tile. The trace includes the root so it
        # records the full path and the leaf lookup degrades gracefully when the root
        # has no children (assumes the root tile is registered under "root").
        routing_path = self._route_to_leaf(query_coord, query_vector)
        trace = ["root"] + [step.tile_id for step in routing_path]

        # Retrieve from the leaf tile
        leaf_tile = self.quadtree.tiles[trace[-1]]
        local_results = leaf_tile.search_local(query_vector, k * 2)

        # Calculate confidence
        confidence = self._calculate_confidence(local_results)

        # Perform Julia-neighbor hops if confidence is low
        hops = []
        if confidence < confidence_threshold and max_hops > 0:
            hop_results = self._perform_hops(leaf_tile, query_vector, k, max_hops)
            local_results.extend(hop_results)
            hops = [f"hop_{i}" for i in range(len(hop_results))]

        # Sort by similarity and keep the top k
        local_results.sort(key=lambda x: x[1], reverse=True)
        final_results = local_results[:k]

        # Extract components
        items = [item for item, _ in final_results]
        similarities = [sim for _, sim in final_results]
        stability_scores = [item.stability_score for item in items]

        total_time = time.time() - start_time

        return RetrievalResult(
            items=items,
            similarities=similarities,
            trace=trace,
            stability_scores=stability_scores,
            hops=hops,
            confidence=confidence,
            total_time=total_time,
        )

    def _route_to_leaf(self, query_coord: complex, query_vector: torch.Tensor) -> List[RoutingPotential]:
        """Route from root to leaf using the routing potential."""
        path = []
        current_tile_id = "root"

        while current_tile_id in self.quadtree.tile_hierarchy:
            children = self.quadtree.tile_hierarchy[current_tile_id]
            if not children:
                break

            # Calculate the routing potential for each child
            potentials = []
            for child_id in children:
                child_tile = self.quadtree.tiles[child_id]
                potential = self._calculate_routing_potential(
                    child_tile, query_coord, query_vector
                )
                potentials.append(potential)

            # Descend into the child with the minimum potential
            best_potential = min(potentials, key=lambda p: p.combined_potential)
            path.append(best_potential)
            current_tile_id = best_potential.tile_id

        return path

    def _calculate_routing_potential(self, tile: Tile, query_coord: complex,
                                     query_vector: torch.Tensor) -> RoutingPotential:
        """Calculate the routing potential ρ(t; u_q) = α·||u_q - μ_t|| + β·attn(v_q, A_t).

        The attention term attn(v_q, A_t) is implemented as 1 - cosine similarity
        between the query vector v_q and the tile attractor A_t, so lower is better.
        """
        # Spatial component: distance in the complex plane to the tile center
        tile_center = tile.get_center()
        spatial_distance = abs(query_coord - tile_center)

        # Semantic component: cosine similarity with the tile attractor
        semantic_similarity = torch.cosine_similarity(
            query_vector.unsqueeze(0),
            tile.attractor.unsqueeze(0)
        ).item()

        # Combined potential (lower is better)
        combined_potential = (self.alpha * spatial_distance +
                              self.beta * (1.0 - semantic_similarity))

        return RoutingPotential(
            tile_id=tile.tile_id,
            spatial_distance=spatial_distance,
            semantic_similarity=semantic_similarity,
            combined_potential=combined_potential
        )

    def _calculate_confidence(self, results: List[Tuple[MemoryItem, float]]) -> float:
        """Calculate retrieval confidence based on similarity scores."""
        if not results:
            return 0.0

        similarities = [sim for _, sim in results]

        # Confidence based on top similarity and score distribution
        max_sim = max(similarities)
        mean_sim = np.mean(similarities)
        std_sim = np.std(similarities) if len(similarities) > 1 else 0.0

        # Higher confidence for high max similarity and low variance
        confidence = max_sim * (1.0 - std_sim / (mean_sim + 1e-6))
        return min(1.0, max(0.0, confidence))
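
    # Worked example for _calculate_confidence (hypothetical numbers, for illustration only):
    #   similarities = [0.9, 0.8, 0.7] -> max = 0.9, mean = 0.8, std ≈ 0.082
    #   confidence ≈ 0.9 * (1 - 0.082 / 0.8) ≈ 0.81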

    def _perform_hops(self, leaf_tile: Tile, query_vector: torch.Tensor,
                      k: int, max_hops: int) -> List[Tuple[MemoryItem, float]]:
        """Perform Julia-neighbor hops for cross-cluster recall."""
        hop_results = []
        visited_tiles = {leaf_tile.tile_id}

        for _hop in range(max_hops):
            # Gather unvisited neighbors of all tiles reached so far
            neighbors = []
            for tile_id in visited_tiles:
                tile_neighbors = self.quadtree.get_neighbors(tile_id)
                for neighbor in tile_neighbors:
                    if neighbor.tile_id not in visited_tiles:
                        neighbors.append(neighbor)

            if not neighbors:
                break

            # Calculate cross-tile affinity and select the best neighbors
            neighbor_affinities = []
            for neighbor in neighbors:
                affinity = self._calculate_cross_tile_affinity(
                    leaf_tile, neighbor, query_vector
                )
                neighbor_affinities.append((neighbor, affinity))

            # Sort by affinity and take the top neighbors
            neighbor_affinities.sort(key=lambda x: x[1], reverse=True)
            top_neighbors = neighbor_affinities[:2]  # Limit hops per iteration

            # Search within the top neighbors
            for neighbor, affinity in top_neighbors:
                neighbor_results = neighbor.search_local(query_vector, k // 2)
                # Weight results by cross-tile affinity
                weighted_results = [
                    (item, sim * affinity) for item, sim in neighbor_results
                ]
                hop_results.extend(weighted_results)
                visited_tiles.add(neighbor.tile_id)

        return hop_results

    def _calculate_cross_tile_affinity(self, source_tile: Tile, target_tile: Tile,
                                       query_vector: torch.Tensor) -> float:
        """Calculate the affinity between two tiles for hop decisions."""
        # Attractor similarity between the source and target tiles
        attractor_sim = torch.cosine_similarity(
            source_tile.attractor.unsqueeze(0),
            target_tile.attractor.unsqueeze(0)
        ).item()

        # Query relevance to the target tile's attractor
        query_relevance = torch.cosine_similarity(
            query_vector.unsqueeze(0),
            target_tile.attractor.unsqueeze(0)
        ).item()

        # Spatial proximity in the complex plane
        spatial_distance = abs(source_tile.get_center() - target_tile.get_center())
        spatial_weight = 1.0 / (1.0 + spatial_distance)

        # Combined affinity, clipped to be non-negative
        affinity = 0.4 * attractor_sim + 0.4 * query_relevance + 0.2 * spatial_weight
        return max(0.0, affinity)
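
# Illustrative sketch (not part of the MandelMem API): a worked example of the
# routing-potential arithmetic used by MultiScaleRetriever._calculate_routing_potential.
# All vectors and coordinates below are made up for illustration; only torch is assumed.
def _routing_potential_example(alpha: float = 1.0, beta: float = 1.0) -> float:
    """Return the combined potential for one hypothetical query/tile pair."""
    # Hypothetical query and tile-attractor vectors (unit-normalised for clarity)
    query_vector = torch.nn.functional.normalize(torch.tensor([1.0, 0.0, 0.0]), dim=0)
    tile_attractor = torch.nn.functional.normalize(torch.tensor([0.8, 0.6, 0.0]), dim=0)

    # Hypothetical coordinates in the complex plane
    query_coord = complex(-0.5, 0.3)
    tile_center = complex(-0.4, 0.1)

    spatial_distance = abs(query_coord - tile_center)    # ≈ 0.224
    semantic_similarity = torch.cosine_similarity(        # = 0.8
        query_vector.unsqueeze(0), tile_attractor.unsqueeze(0)
    ).item()

    # Lower is better: ≈ 1.0 * 0.224 + 1.0 * (1 - 0.8) ≈ 0.424
    return alpha * spatial_distance + beta * (1.0 - semantic_similarity)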

class ContextualRetriever:
    """Enhanced retriever with contextual understanding."""

    def __init__(self, base_retriever: MultiScaleRetriever):
        self.base_retriever = base_retriever
        self.context_history: List[str] = []
        self.context_embeddings: List[torch.Tensor] = []

    def retrieve_with_context(self, query: str, context: Optional[str] = None,
                              k: int = 5) -> RetrievalResult:
        """Retrieve with contextual awareness."""
        # Update context if provided
        if context:
            self._update_context(context)

        # Enhance the query with recent context
        enhanced_query = self._enhance_query_with_context(query)

        # Perform base retrieval
        result = self.base_retriever.retrieve(enhanced_query, k)

        # Re-rank results based on context
        if self.context_embeddings:
            result = self._rerank_with_context(result)

        return result

    def _update_context(self, context: str):
        """Update the context history and its embeddings."""
        self.context_history.append(context)

        # Encode the new context
        encoding = self.base_retriever.encoder.encode(context)
        self.context_embeddings.append(encoding.vector)

        # Limit the context window
        max_context = 10
        if len(self.context_history) > max_context:
            self.context_history = self.context_history[-max_context:]
            self.context_embeddings = self.context_embeddings[-max_context:]

    def _enhance_query_with_context(self, query: str) -> str:
        """Enhance the query with recent context."""
        if not self.context_history:
            return query

        # Simple context enhancement (could be made more sophisticated)
        recent_context = " ".join(self.context_history[-3:])
        return f"{query} [Context: {recent_context}]"

    def _rerank_with_context(self, result: RetrievalResult) -> RetrievalResult:
        """Re-rank results based on context similarity."""
        if not self.context_embeddings or not result.items:
            return result

        # Average the context embeddings into a single context vector
        context_vector = torch.mean(torch.stack(self.context_embeddings), dim=0)

        new_similarities = []
        for i, item in enumerate(result.items):
            base_sim = result.similarities[i]
            context_sim = torch.cosine_similarity(
                item.vector.unsqueeze(0),
                context_vector.unsqueeze(0)
            ).item()

            # Combine base and context similarity
            enhanced_sim = 0.7 * base_sim + 0.3 * context_sim
            new_similarities.append(enhanced_sim)

        # Re-sort by the enhanced similarity
        sorted_indices = sorted(range(len(new_similarities)),
                                key=lambda i: new_similarities[i], reverse=True)

        result.items = [result.items[i] for i in sorted_indices]
        result.similarities = [new_similarities[i] for i in sorted_indices]
        result.stability_scores = [result.stability_scores[i] for i in sorted_indices]

        return result
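
# Illustrative sketch (not part of the MandelMem API): how ContextualRetriever blends
# base and context similarity with the 0.7 / 0.3 weights used above. The scores are
# hypothetical; in the real path the context similarity comes from the mean of the
# stored context embeddings.
def _context_rerank_example() -> List[float]:
    """Blend hypothetical base and context similarities and return them re-ranked."""
    base_sims = [0.90, 0.85, 0.60]      # similarity to the query alone
    context_sims = [0.10, 0.70, 0.95]   # similarity to the accumulated context
    blended = [0.7 * b + 0.3 * c for b, c in zip(base_sims, context_sims)]
    # [0.66, 0.805, 0.705] -> the second item now ranks first
    return sorted(blended, reverse=True)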

class ExplainableRetriever:
    """Retriever with enhanced interpretability features."""

    def __init__(self, base_retriever: MultiScaleRetriever):
        self.base_retriever = base_retriever

    def retrieve_with_explanation(self, query: str, k: int = 5) -> Dict[str, Any]:
        """Retrieve with a detailed explanation of the process."""
        result = self.base_retriever.retrieve(query, k, with_trace=True)

        explanation = {
            'query': query,
            'results': result,
            'routing_explanation': self._explain_routing(result.trace),
            'similarity_explanation': self._explain_similarities(result),
            'stability_explanation': self._explain_stability(result),
            'counterfactuals': self._generate_counterfactuals(query, result)
        }
        return explanation

    def _explain_routing(self, trace: List[str]) -> Dict[str, Any]:
        """Explain the routing path taken."""
        explanation = {
            'path': trace,
            'depth': len(trace) - 1,
            'reasoning': []
        }

        for i, tile_id in enumerate(trace):
            if i == 0:
                explanation['reasoning'].append("Started at root tile")
            else:
                parent_id = trace[i - 1]
                explanation['reasoning'].append(
                    f"Routed from {parent_id} to {tile_id} based on spatial and semantic potential"
                )
        return explanation

    def _explain_similarities(self, result: RetrievalResult) -> List[Dict[str, Any]]:
        """Explain the similarity scores of the retrieved items."""
        explanations = []
        for i, (item, similarity) in enumerate(zip(result.items, result.similarities)):
            explanation = {
                'rank': i + 1,
                'similarity': similarity,
                'content_preview': item.content[:100] + "..." if len(item.content) > 100 else item.content,
                'factors': {
                    'semantic_match': similarity,
                    'recency': item.metadata.get('recency_weight', 0.5),
                    'importance': item.metadata.get('importance', 0.5),
                    'access_frequency': min(item.access_count / 10.0, 1.0)
                }
            }
            explanations.append(explanation)
        return explanations

    def _explain_stability(self, result: RetrievalResult) -> Dict[str, Any]:
        """Explain stability scores and persistence."""
        if not result.stability_scores:
            return {'message': 'No stability information available'}

        avg_stability = np.mean(result.stability_scores)
        stability_distribution = {
            'high_stability': sum(1 for s in result.stability_scores if s > 0.7),
            'medium_stability': sum(1 for s in result.stability_scores if 0.3 <= s <= 0.7),
            'low_stability': sum(1 for s in result.stability_scores if s < 0.3)
        }

        return {
            'average_stability': avg_stability,
            'distribution': stability_distribution,
            'interpretation': self._interpret_stability(avg_stability)
        }

    def _interpret_stability(self, avg_stability: float) -> str:
        """Interpret an average stability score."""
        if avg_stability > 0.8:
            return "Very stable memories - likely to persist long-term"
        elif avg_stability > 0.6:
            return "Moderately stable memories - established but may evolve"
        elif avg_stability > 0.4:
            return "Somewhat unstable memories - in plastic boundary band"
        else:
            return "Low stability memories - may be forgotten soon"

    def _generate_counterfactuals(self, query: str, result: RetrievalResult) -> List[str]:
        """Generate counterfactual explanations."""
        counterfactuals = []

        if result.confidence < 0.5:
            counterfactuals.append("If the query were more specific, retrieval confidence would be higher")

        if len(result.hops) > 0:
            counterfactuals.append(
                f"Without Julia-neighbor hops, {len(result.hops)} additional results would not have been found"
            )

        if result.stability_scores and np.mean(result.stability_scores) < 0.5:
            counterfactuals.append("If persistence thresholds were lower, more stable memories would be available")

        return counterfactuals
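
# Minimal usage sketch, guarded so it never runs on import. The QuadTree() and
# MultiModalEncoder() default constructors are assumptions (adjust to the real classes),
# the quadtree is assumed to already hold memories, and the query strings are invented;
# only the retriever wiring follows this module. Run as `python -m <package>.<module>`
# so the package-relative imports resolve.
if __name__ == "__main__":
    quadtree = QuadTree()              # hypothetical default constructor
    encoder = MultiModalEncoder()      # hypothetical default constructor

    base = MultiScaleRetriever(quadtree, encoder, alpha=1.0, beta=1.0)
    contextual = ContextualRetriever(base)
    explainable = ExplainableRetriever(base)

    # Plain multi-scale retrieval with routing trace and Julia-neighbor hops
    result = base.retrieve("when did we switch the deployment region?", k=5)
    print(result.trace, result.confidence, result.hops)

    # Context-aware retrieval: prior turns bias the ranking
    ctx_result = contextual.retrieve_with_context(
        "what was the outcome?", context="we discussed the deployment migration", k=5
    )
    print(ctx_result.similarities)

    # Explained retrieval: routing, similarity, stability and counterfactuals
    explanation = explainable.retrieve_with_explanation("deployment migration", k=3)
    print(explanation['routing_explanation']['path'])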