# mandelmem/evaluation.py
"""
Evaluation metrics and testing framework for MandelMem system.
Implements the evaluation criteria from the blueprint.
"""
import numpy as np
import time
from typing import Dict, Any, List, Tuple, Optional
from dataclasses import dataclass
import matplotlib.pyplot as plt
from collections import defaultdict
from .core import MandelMem
from .training import SyntheticDataGenerator
@dataclass
class EvaluationResult:
"""Result of evaluation run."""
metric_name: str
score: float
details: Dict[str, Any]
timestamp: float
class PersistenceFidelityEvaluator:
"""Evaluates survival curve of 'must-keep' items after N writes."""
def __init__(self, memory_system: MandelMem):
self.memory_system = memory_system
    def evaluate(self, must_keep_items: List[str], n_writes: int = 1000,
                 interference_items: Optional[List[str]] = None) -> EvaluationResult:
"""Test persistence fidelity."""
start_time = time.time()
# Write must-keep items with high importance
must_keep_ids = []
for item in must_keep_items:
result = self.memory_system.write(
item, {'importance': 0.9, 'source': 'critical'}
)
must_keep_ids.append((item, result.tile_id))
# Generate interference writes
if interference_items is None:
generator = SyntheticDataGenerator()
interference_batch = generator.generate_batch(n_writes)
interference_items = interference_batch.targets
        # Perform interference writes (capped by the number of items available)
        interference_writes = interference_items[:n_writes]
        for item in interference_writes:
            self.memory_system.write(item, {'importance': 0.3, 'source': 'noise'})
# Check survival of must-keep items
survival_count = 0
survival_details = []
for original_item, tile_id in must_keep_ids:
# Try to retrieve the item
read_result = self.memory_system.read(original_item, k=10)
            # Check whether the original item appears among the retrieved results
            found = False
            best_similarity = 0.0
            for idx, retrieved_item in enumerate(read_result.items):
                if original_item.lower() in retrieved_item.content.lower():
                    found = True
                    best_similarity = read_result.similarities[idx]
                    break
if found:
survival_count += 1
survival_details.append({
'item': original_item,
'survived': found,
'best_similarity': best_similarity,
'tile_id': tile_id
})
survival_rate = survival_count / len(must_keep_items) if must_keep_items else 0.0
return EvaluationResult(
metric_name='persistence_fidelity',
score=survival_rate,
details={
'survival_count': survival_count,
'total_items': len(must_keep_items),
                'interference_writes': len(interference_writes),
'survival_details': survival_details,
'evaluation_time': time.time() - start_time
},
timestamp=time.time()
)
class PlasticityEvaluator:
"""Evaluates edit success rate in boundary band without harming neighbors."""
def __init__(self, memory_system: MandelMem):
self.memory_system = memory_system
def evaluate(self, test_items: List[str], edit_fraction: float = 0.3) -> EvaluationResult:
"""Test plasticity in boundary band."""
start_time = time.time()
# Write items with medium importance (likely to be in plastic band)
original_items = []
for item in test_items:
result = self.memory_system.write(
item, {'importance': 0.5, 'source': 'test'}
)
original_items.append((item, result))
# Wait briefly to let items settle
time.sleep(0.1)
# Attempt edits on subset of items
edit_count = int(len(test_items) * edit_fraction)
items_to_edit = test_items[:edit_count]
neighbor_items = test_items[edit_count:]
edit_success_count = 0
neighbor_harm_count = 0
        for original_item in items_to_edit:
# Create edited version
edited_item = f"{original_item} [EDITED]"
# Write edited version
edit_result = self.memory_system.write(
edited_item, {'importance': 0.5, 'source': 'edit'}
)
# Check if edit was successful (item persisted)
if edit_result.persisted and edit_result.band in ['stable', 'plastic']:
edit_success_count += 1
# Check if neighbor items were harmed
for neighbor_item in neighbor_items:
read_result = self.memory_system.read(neighbor_item, k=5)
# Check if neighbor is still retrievable
found = any(neighbor_item.lower() in item.content.lower()
for item in read_result.items)
if not found:
neighbor_harm_count += 1
edit_success_rate = edit_success_count / len(items_to_edit) if items_to_edit else 0.0
neighbor_preservation_rate = 1.0 - (neighbor_harm_count / len(neighbor_items)) if neighbor_items else 1.0
        # Combined plasticity score: edit success weighted 0.7, neighbor preservation 0.3
plasticity_score = 0.7 * edit_success_rate + 0.3 * neighbor_preservation_rate
return EvaluationResult(
metric_name='plasticity',
score=plasticity_score,
details={
'edit_success_rate': edit_success_rate,
'neighbor_preservation_rate': neighbor_preservation_rate,
'edits_attempted': len(items_to_edit),
'edits_successful': edit_success_count,
'neighbors_harmed': neighbor_harm_count,
'evaluation_time': time.time() - start_time
},
timestamp=time.time()
)
class LatencyEvaluator:
"""Evaluates hops per read and P95 read time."""
def __init__(self, memory_system: MandelMem):
self.memory_system = memory_system
def evaluate(self, queries: List[str], n_runs: int = 100) -> EvaluationResult:
"""Test retrieval latency."""
start_time = time.time()
read_times = []
hop_counts = []
trace_lengths = []
for _ in range(n_runs):
for query in queries:
query_start = time.time()
result = self.memory_system.read(query, k=5, with_trace=True)
query_time = time.time() - query_start
read_times.append(query_time)
# Count hops (Julia-neighbor hops)
                hop_count = max(0, len(result.trace) - 1)  # subtract the root tile
hop_counts.append(hop_count)
trace_lengths.append(len(result.trace))
# Calculate statistics
avg_read_time = np.mean(read_times)
p95_read_time = np.percentile(read_times, 95)
avg_hops = np.mean(hop_counts)
max_hops = np.max(hop_counts) if hop_counts else 0
        # Latency score: P95 read time normalized against a 1-second baseline
        # (lower latency yields a higher score, clamped to [0, 1])
        latency_score = max(0.0, 1.0 - (p95_read_time / 1.0))
return EvaluationResult(
metric_name='latency',
score=latency_score,
details={
'avg_read_time': avg_read_time,
'p95_read_time': p95_read_time,
'avg_hops_per_read': avg_hops,
'max_hops': max_hops,
'total_queries': len(queries) * n_runs,
'evaluation_time': time.time() - start_time
},
timestamp=time.time()
)
class InterferenceEvaluator:
"""Evaluates Δ-recall on old tasks after new task writes."""
def __init__(self, memory_system: MandelMem):
self.memory_system = memory_system
def evaluate(self, old_task_items: List[str], new_task_items: List[str],
test_queries: List[str]) -> EvaluationResult:
"""Test interference resistance."""
start_time = time.time()
# Phase 1: Write old task items
for item in old_task_items:
self.memory_system.write(item, {'importance': 0.7, 'source': 'old_task'})
# Measure baseline recall
baseline_recall = self._measure_recall(test_queries, old_task_items)
# Phase 2: Write new task items (interference)
for item in new_task_items:
self.memory_system.write(item, {'importance': 0.7, 'source': 'new_task'})
# Measure post-interference recall
post_recall = self._measure_recall(test_queries, old_task_items)
# Calculate interference (Δ-recall)
delta_recall = baseline_recall - post_recall
        # Clamp to [0, 1] so post-interference recall above baseline cannot exceed the scale
        interference_resistance = max(0.0, min(1.0, 1.0 - delta_recall))
return EvaluationResult(
metric_name='interference_resistance',
score=interference_resistance,
details={
'baseline_recall': baseline_recall,
'post_interference_recall': post_recall,
'delta_recall': delta_recall,
'old_task_items': len(old_task_items),
'new_task_items': len(new_task_items),
'test_queries': len(test_queries),
'evaluation_time': time.time() - start_time
},
timestamp=time.time()
)
def _measure_recall(self, queries: List[str], target_items: List[str]) -> float:
"""Measure recall for given queries and targets."""
total_found = 0
total_targets = 0
for query in queries:
result = self.memory_system.read(query, k=10)
# Count how many target items were found
for target in target_items:
if any(target.lower() in item.content.lower() for item in result.items):
total_found += 1
total_targets += 1
return total_found / total_targets if total_targets > 0 else 0.0
class CompressionQualityEvaluator:
"""Evaluates retrieval@k before vs. after renormalization."""
def __init__(self, memory_system: MandelMem):
self.memory_system = memory_system
def evaluate(self, test_items: List[str], queries: List[str], k: int = 5) -> EvaluationResult:
"""Test compression quality preservation."""
start_time = time.time()
# Write test items
for item in test_items:
self.memory_system.write(item, {'importance': 0.6, 'source': 'test'})
# Measure pre-renormalization retrieval quality
pre_renorm_scores = []
for query in queries:
result = self.memory_system.read(query, k=k)
# Simple quality metric: average similarity of top-k results
avg_similarity = np.mean(result.similarities) if result.similarities else 0.0
pre_renorm_scores.append(avg_similarity)
pre_renorm_quality = np.mean(pre_renorm_scores)
# Force renormalization on all tiles
renorm_results = []
for tile in self.memory_system.quadtree.get_leaf_tiles():
if len(tile.get_all_items()) > 0:
renorm_result = self.memory_system.renorm_engine.renormalize_tile(tile)
renorm_results.append(renorm_result)
# Measure post-renormalization retrieval quality
post_renorm_scores = []
for query in queries:
result = self.memory_system.read(query, k=k)
avg_similarity = np.mean(result.similarities) if result.similarities else 0.0
post_renorm_scores.append(avg_similarity)
post_renorm_quality = np.mean(post_renorm_scores)
# Quality preservation score
quality_preservation = post_renorm_quality / pre_renorm_quality if pre_renorm_quality > 0 else 1.0
quality_preservation = min(1.0, quality_preservation) # Cap at 1.0
return EvaluationResult(
metric_name='compression_quality',
score=quality_preservation,
details={
'pre_renorm_quality': pre_renorm_quality,
'post_renorm_quality': post_renorm_quality,
'quality_preservation_ratio': quality_preservation,
'tiles_renormalized': len(renorm_results),
'avg_compression_ratio': np.mean([r.compression_ratio for r in renorm_results]) if renorm_results else 1.0,
'evaluation_time': time.time() - start_time
},
timestamp=time.time()
)
class ExplainabilityEvaluator:
"""Evaluates % reads with coherent route + user-rated path plausibility."""
def __init__(self, memory_system: MandelMem):
self.memory_system = memory_system
def evaluate(self, queries: List[str], coherence_threshold: float = 0.7) -> EvaluationResult:
"""Test explainability of retrieval paths."""
start_time = time.time()
coherent_routes = 0
total_routes = 0
route_details = []
for query in queries:
result = self.memory_system.read(query, k=5, with_trace=True, with_explanation=True)
if result.trace:
# Analyze route coherence
coherence_score = self._analyze_route_coherence(result.trace, query)
if coherence_score >= coherence_threshold:
coherent_routes += 1
route_details.append({
'query': query,
'trace': result.trace,
'coherence_score': coherence_score,
'explanation_available': result.explanation is not None
})
total_routes += 1
coherence_rate = coherent_routes / total_routes if total_routes > 0 else 0.0
return EvaluationResult(
metric_name='explainability',
score=coherence_rate,
details={
'coherent_routes': coherent_routes,
'total_routes': total_routes,
'coherence_rate': coherence_rate,
'route_details': route_details,
'evaluation_time': time.time() - start_time
},
timestamp=time.time()
)
def _analyze_route_coherence(self, trace: List[str], query: str) -> float:
"""Analyze coherence of routing trace."""
if len(trace) < 2:
return 1.0 # Single tile is coherent
# Simple coherence metric: consistent depth progression
depth_progression = []
for tile_id in trace:
if tile_id == "root":
depth = 0
else:
                depth = len(tile_id) - len("root")  # depth = characters appended after the "root" prefix
depth_progression.append(depth)
# Check if depth increases monotonically
monotonic = all(depth_progression[i] <= depth_progression[i+1]
for i in range(len(depth_progression)-1))
# Check reasonable depth (not too shallow or deep)
final_depth = depth_progression[-1]
reasonable_depth = 3 <= final_depth <= 6
# Combine factors
coherence_score = 0.5 * (1.0 if monotonic else 0.0) + 0.5 * (1.0 if reasonable_depth else 0.0)
return coherence_score
class ComprehensiveEvaluator:
"""Runs all evaluation metrics and provides comprehensive assessment."""
def __init__(self, memory_system: MandelMem):
self.memory_system = memory_system
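        # One evaluator per blueprint metric. All evaluators share the same live
        # memory system, so running them mutates its state and later metrics see
        # the writes made by earlier ones.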
self.evaluators = {
'persistence_fidelity': PersistenceFidelityEvaluator(memory_system),
'plasticity': PlasticityEvaluator(memory_system),
'latency': LatencyEvaluator(memory_system),
'interference_resistance': InterferenceEvaluator(memory_system),
'compression_quality': CompressionQualityEvaluator(memory_system),
'explainability': ExplainabilityEvaluator(memory_system)
}
def run_full_evaluation(self, test_config: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
"""Run comprehensive evaluation suite."""
if test_config is None:
test_config = self._get_default_test_config()
results = {}
overall_start = time.time()
# Generate test data
generator = SyntheticDataGenerator()
test_batch = generator.generate_batch(100)
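        # Note: the slices below reuse overlapping portions of the same synthetic
        # batch (e.g. the old-task items overlap the must-keep items), so the
        # individual metrics are not fully independent of one another.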
# Run each evaluation
print("Running comprehensive MandelMem evaluation...")
# 1. Persistence Fidelity
print(" - Testing persistence fidelity...")
must_keep = test_batch.targets[:20]
interference = test_batch.targets[20:80]
        results['persistence_fidelity'] = self.evaluators['persistence_fidelity'].evaluate(
            must_keep, n_writes=test_config.get('persistence_writes', 500),
            interference_items=interference
        )
# 2. Plasticity
print(" - Testing plasticity...")
plasticity_items = test_batch.targets[80:100]
        results['plasticity'] = self.evaluators['plasticity'].evaluate(
            plasticity_items, edit_fraction=test_config.get('plasticity_edit_fraction', 0.3)
        )
# 3. Latency
print(" - Testing latency...")
queries = test_batch.queries[:10]
        results['latency'] = self.evaluators['latency'].evaluate(
            queries, n_runs=test_config.get('latency_runs', 20)
        )
# 4. Interference Resistance
print(" - Testing interference resistance...")
old_task = test_batch.targets[:30]
new_task = test_batch.targets[30:60]
test_queries = test_batch.queries[:15]
results['interference_resistance'] = self.evaluators['interference_resistance'].evaluate(
old_task, new_task, test_queries
)
# 5. Compression Quality
print(" - Testing compression quality...")
compression_items = test_batch.targets[60:90]
compression_queries = test_batch.queries[15:25]
        results['compression_quality'] = self.evaluators['compression_quality'].evaluate(
            compression_items, compression_queries, k=test_config.get('compression_k', 5)
        )
# 6. Explainability
print(" - Testing explainability...")
explain_queries = test_batch.queries[25:35]
        results['explainability'] = self.evaluators['explainability'].evaluate(
            explain_queries, coherence_threshold=test_config.get('coherence_threshold', 0.7)
        )
# Calculate overall score
scores = [result.score for result in results.values()]
overall_score = np.mean(scores)
evaluation_summary = {
'overall_score': overall_score,
'individual_results': results,
'evaluation_time': time.time() - overall_start,
'test_config': test_config,
'system_stats': self.memory_system.get_statistics()
}
print(f"Evaluation completed. Overall score: {overall_score:.3f}")
return evaluation_summary
def _get_default_test_config(self) -> Dict[str, Any]:
"""Get default test configuration."""
return {
'persistence_writes': 500,
'plasticity_edit_fraction': 0.3,
'latency_runs': 20,
'compression_k': 5,
'coherence_threshold': 0.7
}
def plot_evaluation_results(self, results: Dict[str, Any], save_path: Optional[str] = None):
"""Plot evaluation results."""
individual_results = results['individual_results']
# Extract scores and names
metric_names = list(individual_results.keys())
scores = [individual_results[name].score for name in metric_names]
# Create bar plot
plt.figure(figsize=(12, 6))
bars = plt.bar(metric_names, scores, color='skyblue', alpha=0.7)
# Add value labels on bars
for bar, score in zip(bars, scores):
plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
f'{score:.3f}', ha='center', va='bottom')
plt.title('MandelMem Evaluation Results')
plt.ylabel('Score (0-1)')
plt.ylim(0, 1.1)
plt.xticks(rotation=45, ha='right')
plt.grid(axis='y', alpha=0.3)
# Add overall score line
overall_score = results['overall_score']
plt.axhline(y=overall_score, color='red', linestyle='--',
label=f'Overall Score: {overall_score:.3f}')
plt.legend()
plt.tight_layout()
if save_path:
plt.savefig(save_path, dpi=300, bbox_inches='tight')
plt.show()
def run_quick_evaluation(memory_system: MandelMem) -> Dict[str, float]:
"""Run quick evaluation for development/debugging."""
    # Minimal smoke test: one write/read round-trip plus system statistics,
    # rather than the full evaluator suite.
    quick_results = {}
# Test basic functionality
try:
# Write and read test
write_result = memory_system.write("test memory item", {'importance': 0.8})
read_result = memory_system.read("test memory", k=3)
quick_results['basic_functionality'] = 1.0 if read_result.items else 0.0
quick_results['write_persistence'] = 1.0 if write_result.persisted else 0.0
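        # Inverse-latency heuristic: reads completing within ~1 s score 1.0,
        # slower reads score roughly 1/seconds (denominator floored at 1 ms).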
quick_results['read_latency'] = min(1.0, 1.0 / max(read_result.total_time, 0.001))
# System stats
stats = memory_system.get_statistics()
quick_results['memory_utilization'] = min(1.0, stats['quadtree']['total_items'] / 100)
except Exception as e:
print(f"Quick evaluation error: {e}")
quick_results['error'] = 0.0
return quick_results
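

if __name__ == "__main__":
    # Minimal usage sketch (illustration only), assuming MandelMem can be constructed
    # with default arguments; adjust the constructor call to match your configuration.
    # Because this module uses relative imports, run it as
    # `python -m mandelmem.evaluation` rather than executing the file directly.
    memory = MandelMem()

    # Fast smoke test for development loops.
    print("Quick evaluation:", run_quick_evaluation(memory))

    # Full blueprint metrics; pass a test_config dict to override the defaults
    # returned by _get_default_test_config().
    evaluator = ComprehensiveEvaluator(memory)
    summary = evaluator.run_full_evaluation()
    evaluator.plot_evaluation_results(summary, save_path="mandelmem_evaluation.png")  # hypothetical output path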