| """ | |
| Evaluation metrics and testing framework for MandelMem system. | |
| Implements the evaluation criteria from the blueprint. | |
| """ | |
| import torch | |
| import numpy as np | |
| import time | |
| from typing import Dict, Any, List, Tuple, Optional | |
| from dataclasses import dataclass | |
| import matplotlib.pyplot as plt | |
| from collections import defaultdict | |
| from .core import MandelMem | |
| from .quadtree import MemoryItem | |
| from .training import SyntheticDataGenerator | |


@dataclass
class EvaluationResult:
    """Result of an evaluation run."""
    metric_name: str
    score: float
    details: Dict[str, Any]
    timestamp: float


class PersistenceFidelityEvaluator:
    """Evaluates survival curve of 'must-keep' items after N writes."""

    def __init__(self, memory_system: MandelMem):
        self.memory_system = memory_system

    def evaluate(self, must_keep_items: List[str], n_writes: int = 1000,
                 interference_items: Optional[List[str]] = None) -> EvaluationResult:
        """Test persistence fidelity."""
        start_time = time.time()

        # Write must-keep items with high importance
        must_keep_ids = []
        for item in must_keep_items:
            result = self.memory_system.write(
                item, {'importance': 0.9, 'source': 'critical'}
            )
            must_keep_ids.append((item, result.tile_id))

        # Generate interference writes if none were supplied
        if interference_items is None:
            generator = SyntheticDataGenerator()
            interference_batch = generator.generate_batch(n_writes)
            interference_items = interference_batch.targets

        # Perform interference writes
        for item in interference_items[:n_writes]:
            self.memory_system.write(item, {'importance': 0.3, 'source': 'noise'})

        # Check survival of must-keep items
        survival_count = 0
        survival_details = []
        for original_item, tile_id in must_keep_ids:
            # Try to retrieve the item
            read_result = self.memory_system.read(original_item, k=10)

            # Check if the original item is in the results
            found = False
            best_similarity = 0.0
            for idx, retrieved_item in enumerate(read_result.items):
                if original_item.lower() in retrieved_item.content.lower():
                    found = True
                    # Get similarity score of the matching result
                    best_similarity = read_result.similarities[idx]
                    break

            if found:
                survival_count += 1

            survival_details.append({
                'item': original_item,
                'survived': found,
                'best_similarity': best_similarity,
                'tile_id': tile_id
            })

        survival_rate = survival_count / len(must_keep_items) if must_keep_items else 0.0

        return EvaluationResult(
            metric_name='persistence_fidelity',
            score=survival_rate,
            details={
                'survival_count': survival_count,
                'total_items': len(must_keep_items),
                'interference_writes': n_writes,
                'survival_details': survival_details,
                'evaluation_time': time.time() - start_time
            },
            timestamp=time.time()
        )
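

# Example (illustrative sketch only): standalone use of PersistenceFidelityEvaluator.
# This assumes MandelMem can be constructed with default arguments, which this module
# does not guarantee; adapt the construction to the actual .core API.
#
#   memory = MandelMem()
#   evaluator = PersistenceFidelityEvaluator(memory)
#   result = evaluator.evaluate(["the capital of France is Paris"], n_writes=200)
#   print(result.score, result.details['survival_count'])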


class PlasticityEvaluator:
    """Evaluates edit success rate in the boundary band without harming neighbors."""

    def __init__(self, memory_system: MandelMem):
        self.memory_system = memory_system

    def evaluate(self, test_items: List[str], edit_fraction: float = 0.3) -> EvaluationResult:
        """Test plasticity in the boundary band."""
        start_time = time.time()

        # Write items with medium importance (likely to land in the plastic band)
        original_items = []
        for item in test_items:
            result = self.memory_system.write(
                item, {'importance': 0.5, 'source': 'test'}
            )
            original_items.append((item, result))

        # Wait briefly to let items settle
        time.sleep(0.1)

        # Attempt edits on a subset of items
        edit_count = int(len(test_items) * edit_fraction)
        items_to_edit = test_items[:edit_count]
        neighbor_items = test_items[edit_count:]

        edit_success_count = 0
        neighbor_harm_count = 0

        for original_item in items_to_edit:
            # Create edited version
            edited_item = f"{original_item} [EDITED]"

            # Write edited version
            edit_result = self.memory_system.write(
                edited_item, {'importance': 0.5, 'source': 'edit'}
            )

            # Check if the edit was successful (item persisted in a writable band)
            if edit_result.persisted and edit_result.band in ['stable', 'plastic']:
                edit_success_count += 1

        # Check if neighbor items were harmed by the edits
        for neighbor_item in neighbor_items:
            read_result = self.memory_system.read(neighbor_item, k=5)

            # Check if the neighbor is still retrievable
            found = any(neighbor_item.lower() in item.content.lower()
                        for item in read_result.items)
            if not found:
                neighbor_harm_count += 1

        edit_success_rate = edit_success_count / len(items_to_edit) if items_to_edit else 0.0
        neighbor_preservation_rate = 1.0 - (neighbor_harm_count / len(neighbor_items)) if neighbor_items else 1.0

        # Combined plasticity score
        plasticity_score = 0.7 * edit_success_rate + 0.3 * neighbor_preservation_rate

        return EvaluationResult(
            metric_name='plasticity',
            score=plasticity_score,
            details={
                'edit_success_rate': edit_success_rate,
                'neighbor_preservation_rate': neighbor_preservation_rate,
                'edits_attempted': len(items_to_edit),
                'edits_successful': edit_success_count,
                'neighbors_harmed': neighbor_harm_count,
                'evaluation_time': time.time() - start_time
            },
            timestamp=time.time()
        )


class LatencyEvaluator:
    """Evaluates hops per read and P95 read time."""

    def __init__(self, memory_system: MandelMem):
        self.memory_system = memory_system

    def evaluate(self, queries: List[str], n_runs: int = 100) -> EvaluationResult:
        """Test retrieval latency."""
        start_time = time.time()

        read_times = []
        hop_counts = []
        trace_lengths = []

        for _ in range(n_runs):
            for query in queries:
                query_start = time.time()
                result = self.memory_system.read(query, k=5, with_trace=True)
                query_time = time.time() - query_start

                read_times.append(query_time)

                # Count hops (Julia-neighbor hops)
                hop_count = len(result.trace) - 1  # Subtract root
                hop_counts.append(hop_count)
                trace_lengths.append(len(result.trace))

        # Calculate statistics
        avg_read_time = np.mean(read_times)
        p95_read_time = np.percentile(read_times, 95)
        avg_hops = np.mean(hop_counts)
        max_hops = np.max(hop_counts) if hop_counts else 0

        # Latency score (lower is better, normalized to 0-1)
        latency_score = max(0.0, 1.0 - (p95_read_time / 1.0))  # 1-second baseline

        return EvaluationResult(
            metric_name='latency',
            score=latency_score,
            details={
                'avg_read_time': avg_read_time,
                'p95_read_time': p95_read_time,
                'avg_hops_per_read': avg_hops,
                'max_hops': max_hops,
                'total_queries': len(queries) * n_runs,
                'evaluation_time': time.time() - start_time
            },
            timestamp=time.time()
        )


class InterferenceEvaluator:
    """Evaluates Δ-recall on old tasks after new-task writes."""

    def __init__(self, memory_system: MandelMem):
        self.memory_system = memory_system

    def evaluate(self, old_task_items: List[str], new_task_items: List[str],
                 test_queries: List[str]) -> EvaluationResult:
        """Test interference resistance."""
        start_time = time.time()

        # Phase 1: Write old-task items
        for item in old_task_items:
            self.memory_system.write(item, {'importance': 0.7, 'source': 'old_task'})

        # Measure baseline recall
        baseline_recall = self._measure_recall(test_queries, old_task_items)

        # Phase 2: Write new-task items (interference)
        for item in new_task_items:
            self.memory_system.write(item, {'importance': 0.7, 'source': 'new_task'})

        # Measure post-interference recall
        post_recall = self._measure_recall(test_queries, old_task_items)

        # Calculate interference (Δ-recall)
        delta_recall = baseline_recall - post_recall
        interference_resistance = max(0.0, 1.0 - delta_recall)

        return EvaluationResult(
            metric_name='interference_resistance',
            score=interference_resistance,
            details={
                'baseline_recall': baseline_recall,
                'post_interference_recall': post_recall,
                'delta_recall': delta_recall,
                'old_task_items': len(old_task_items),
                'new_task_items': len(new_task_items),
                'test_queries': len(test_queries),
                'evaluation_time': time.time() - start_time
            },
            timestamp=time.time()
        )

    def _measure_recall(self, queries: List[str], target_items: List[str]) -> float:
        """Measure recall for the given queries and targets."""
        total_found = 0
        total_targets = 0

        for query in queries:
            result = self.memory_system.read(query, k=10)

            # Count how many target items were found
            for target in target_items:
                if any(target.lower() in item.content.lower() for item in result.items):
                    total_found += 1
                total_targets += 1

        return total_found / total_targets if total_targets > 0 else 0.0
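

# Worked example of the Δ-recall arithmetic above (numbers are illustrative, not
# measured): if baseline recall is 0.90 and post-interference recall is 0.72, then
# delta_recall = 0.18 and interference_resistance = max(0.0, 1.0 - 0.18) = 0.82.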


class CompressionQualityEvaluator:
    """Evaluates retrieval@k before vs. after renormalization."""

    def __init__(self, memory_system: MandelMem):
        self.memory_system = memory_system

    def evaluate(self, test_items: List[str], queries: List[str], k: int = 5) -> EvaluationResult:
        """Test compression quality preservation."""
        start_time = time.time()

        # Write test items
        for item in test_items:
            self.memory_system.write(item, {'importance': 0.6, 'source': 'test'})

        # Measure pre-renormalization retrieval quality
        pre_renorm_scores = []
        for query in queries:
            result = self.memory_system.read(query, k=k)
            # Simple quality metric: average similarity of the top-k results
            avg_similarity = np.mean(result.similarities) if result.similarities else 0.0
            pre_renorm_scores.append(avg_similarity)
        pre_renorm_quality = np.mean(pre_renorm_scores)

        # Force renormalization on all non-empty leaf tiles
        renorm_results = []
        for tile in self.memory_system.quadtree.get_leaf_tiles():
            if len(tile.get_all_items()) > 0:
                renorm_result = self.memory_system.renorm_engine.renormalize_tile(tile)
                renorm_results.append(renorm_result)

        # Measure post-renormalization retrieval quality
        post_renorm_scores = []
        for query in queries:
            result = self.memory_system.read(query, k=k)
            avg_similarity = np.mean(result.similarities) if result.similarities else 0.0
            post_renorm_scores.append(avg_similarity)
        post_renorm_quality = np.mean(post_renorm_scores)

        # Quality preservation score, capped at 1.0
        quality_preservation = post_renorm_quality / pre_renorm_quality if pre_renorm_quality > 0 else 1.0
        quality_preservation = min(1.0, quality_preservation)

        return EvaluationResult(
            metric_name='compression_quality',
            score=quality_preservation,
            details={
                'pre_renorm_quality': pre_renorm_quality,
                'post_renorm_quality': post_renorm_quality,
                'quality_preservation_ratio': quality_preservation,
                'tiles_renormalized': len(renorm_results),
                'avg_compression_ratio': np.mean([r.compression_ratio for r in renorm_results]) if renorm_results else 1.0,
                'evaluation_time': time.time() - start_time
            },
            timestamp=time.time()
        )


class ExplainabilityEvaluator:
    """Evaluates % of reads with a coherent route + user-rated path plausibility."""

    def __init__(self, memory_system: MandelMem):
        self.memory_system = memory_system

    def evaluate(self, queries: List[str], coherence_threshold: float = 0.7) -> EvaluationResult:
        """Test explainability of retrieval paths."""
        start_time = time.time()

        coherent_routes = 0
        total_routes = 0
        route_details = []

        for query in queries:
            result = self.memory_system.read(query, k=5, with_trace=True, with_explanation=True)

            if result.trace:
                # Analyze route coherence
                coherence_score = self._analyze_route_coherence(result.trace, query)

                if coherence_score >= coherence_threshold:
                    coherent_routes += 1

                route_details.append({
                    'query': query,
                    'trace': result.trace,
                    'coherence_score': coherence_score,
                    'explanation_available': result.explanation is not None
                })

                total_routes += 1

        coherence_rate = coherent_routes / total_routes if total_routes > 0 else 0.0

        return EvaluationResult(
            metric_name='explainability',
            score=coherence_rate,
            details={
                'coherent_routes': coherent_routes,
                'total_routes': total_routes,
                'coherence_rate': coherence_rate,
                'route_details': route_details,
                'evaluation_time': time.time() - start_time
            },
            timestamp=time.time()
        )

    def _analyze_route_coherence(self, trace: List[str], query: str) -> float:
        """Analyze coherence of a routing trace."""
        if len(trace) < 2:
            return 1.0  # A single tile is trivially coherent

        # Simple coherence metric: consistent depth progression
        depth_progression = []
        for tile_id in trace:
            if tile_id == "root":
                depth = 0
            else:
                depth = len(tile_id) - 4  # Characters beyond the "root" prefix
            depth_progression.append(depth)

        # Check if depth increases monotonically
        monotonic = all(depth_progression[i] <= depth_progression[i + 1]
                        for i in range(len(depth_progression) - 1))

        # Check for a reasonable final depth (not too shallow or too deep)
        final_depth = depth_progression[-1]
        reasonable_depth = 3 <= final_depth <= 6

        # Combine both factors
        coherence_score = 0.5 * (1.0 if monotonic else 0.0) + 0.5 * (1.0 if reasonable_depth else 0.0)
        return coherence_score
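

# Worked example of the coherence heuristic above: for a hypothetical trace
# ["root", "root2", "root21", "root213"] the derived depths are [0, 1, 2, 3]
# (depth = len(tile_id) - 4). The progression is monotonic and the final depth 3
# falls inside the 3-6 band, so coherence = 0.5 * 1.0 + 0.5 * 1.0 = 1.0. The
# tile-id naming is an assumption inferred from the "root"-prefix handling above.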


class ComprehensiveEvaluator:
    """Runs all evaluation metrics and provides a comprehensive assessment."""

    def __init__(self, memory_system: MandelMem):
        self.memory_system = memory_system
        self.evaluators = {
            'persistence_fidelity': PersistenceFidelityEvaluator(memory_system),
            'plasticity': PlasticityEvaluator(memory_system),
            'latency': LatencyEvaluator(memory_system),
            'interference_resistance': InterferenceEvaluator(memory_system),
            'compression_quality': CompressionQualityEvaluator(memory_system),
            'explainability': ExplainabilityEvaluator(memory_system)
        }

    def run_full_evaluation(self, test_config: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """Run the comprehensive evaluation suite."""
        if test_config is None:
            test_config = self._get_default_test_config()

        results = {}
        overall_start = time.time()

        # Generate test data
        generator = SyntheticDataGenerator()
        test_batch = generator.generate_batch(100)

        # Run each evaluation
        print("Running comprehensive MandelMem evaluation...")

        # 1. Persistence fidelity
        print(" - Testing persistence fidelity...")
        must_keep = test_batch.targets[:20]
        interference = test_batch.targets[20:80]
        results['persistence_fidelity'] = self.evaluators['persistence_fidelity'].evaluate(
            must_keep, n_writes=500, interference_items=interference
        )

        # 2. Plasticity
        print(" - Testing plasticity...")
        plasticity_items = test_batch.targets[80:100]
        results['plasticity'] = self.evaluators['plasticity'].evaluate(plasticity_items)

        # 3. Latency
        print(" - Testing latency...")
        queries = test_batch.queries[:10]
        results['latency'] = self.evaluators['latency'].evaluate(queries, n_runs=20)

        # 4. Interference resistance
        print(" - Testing interference resistance...")
        old_task = test_batch.targets[:30]
        new_task = test_batch.targets[30:60]
        test_queries = test_batch.queries[:15]
        results['interference_resistance'] = self.evaluators['interference_resistance'].evaluate(
            old_task, new_task, test_queries
        )

        # 5. Compression quality
        print(" - Testing compression quality...")
        compression_items = test_batch.targets[60:90]
        compression_queries = test_batch.queries[15:25]
        results['compression_quality'] = self.evaluators['compression_quality'].evaluate(
            compression_items, compression_queries
        )

        # 6. Explainability
        print(" - Testing explainability...")
        explain_queries = test_batch.queries[25:35]
        results['explainability'] = self.evaluators['explainability'].evaluate(explain_queries)

        # Calculate the overall score
        scores = [result.score for result in results.values()]
        overall_score = np.mean(scores)

        evaluation_summary = {
            'overall_score': overall_score,
            'individual_results': results,
            'evaluation_time': time.time() - overall_start,
            'test_config': test_config,
            'system_stats': self.memory_system.get_statistics()
        }

        print(f"Evaluation completed. Overall score: {overall_score:.3f}")
        return evaluation_summary

    def _get_default_test_config(self) -> Dict[str, Any]:
        """Get the default test configuration."""
        return {
            'persistence_writes': 500,
            'plasticity_edit_fraction': 0.3,
            'latency_runs': 20,
            'compression_k': 5,
            'coherence_threshold': 0.7
        }

    def plot_evaluation_results(self, results: Dict[str, Any], save_path: Optional[str] = None):
        """Plot evaluation results as a bar chart."""
        individual_results = results['individual_results']

        # Extract scores and metric names
        metric_names = list(individual_results.keys())
        scores = [individual_results[name].score for name in metric_names]

        # Create bar plot
        plt.figure(figsize=(12, 6))
        bars = plt.bar(metric_names, scores, color='skyblue', alpha=0.7)

        # Add value labels on the bars
        for bar, score in zip(bars, scores):
            plt.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.01,
                     f'{score:.3f}', ha='center', va='bottom')

        plt.title('MandelMem Evaluation Results')
        plt.ylabel('Score (0-1)')
        plt.ylim(0, 1.1)
        plt.xticks(rotation=45, ha='right')
        plt.grid(axis='y', alpha=0.3)

        # Add overall score line
        overall_score = results['overall_score']
        plt.axhline(y=overall_score, color='red', linestyle='--',
                    label=f'Overall Score: {overall_score:.3f}')
        plt.legend()
        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.show()


def run_quick_evaluation(memory_system: MandelMem) -> Dict[str, float]:
    """Run a quick evaluation for development/debugging."""
    evaluator = ComprehensiveEvaluator(memory_system)

    # Quick test with minimal data
    generator = SyntheticDataGenerator()
    test_batch = generator.generate_batch(20)

    quick_results = {}

    # Test basic functionality
    try:
        # Write and read test
        write_result = memory_system.write("test memory item", {'importance': 0.8})
        read_result = memory_system.read("test memory", k=3)

        quick_results['basic_functionality'] = 1.0 if read_result.items else 0.0
        quick_results['write_persistence'] = 1.0 if write_result.persisted else 0.0
        quick_results['read_latency'] = min(1.0, 1.0 / max(read_result.total_time, 0.001))

        # System stats
        stats = memory_system.get_statistics()
        quick_results['memory_utilization'] = min(1.0, stats['quadtree']['total_items'] / 100)
    except Exception as e:
        print(f"Quick evaluation error: {e}")
        quick_results['error'] = 0.0

    return quick_results
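

if __name__ == "__main__":
    # Minimal usage sketch, not part of the evaluation API. It assumes MandelMem()
    # can be constructed with its default arguments, which .core may or may not allow;
    # adjust the construction (and the output path below) to your setup.
    memory = MandelMem()

    # Fast smoke test during development.
    print("Quick evaluation:", run_quick_evaluation(memory))

    # Full suite plus the summary bar chart.
    evaluator = ComprehensiveEvaluator(memory)
    summary = evaluator.run_full_evaluation()
    evaluator.plot_evaluation_results(summary, save_path="mandelmem_eval.png")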