# Author: Arthur Passuello — initial commit (5e1a30c)
"""
In-memory embedding cache implementation.
This module provides a direct implementation of the EmbeddingCache interface
using in-memory storage with LRU eviction, content-based hashing, and
comprehensive statistics tracking.
Features:
- LRU eviction policy
- Content-based cache keys
- Memory usage monitoring
- Cache statistics and hit rate tracking
- Pattern-based invalidation
- Thread-safe operations
"""
import hashlib
import time
import fnmatch
import threading
from collections import OrderedDict
from typing import Optional, Dict, Any, List, Tuple
import numpy as np
import logging
from pathlib import Path
import sys
# Add project root for imports
project_root = Path(__file__).parent.parent.parent.parent.parent
sys.path.append(str(project_root))
from ..base import EmbeddingCache, ConfigurableEmbedderComponent
logger = logging.getLogger(__name__)
class MemoryCache(EmbeddingCache, ConfigurableEmbedderComponent):
    """
    Direct implementation of an in-memory embedding cache.

    Embeddings are stored in an ``OrderedDict`` keyed by a content-based
    SHA-256 hash of the (optionally normalized) input text, giving O(1)
    lookup/insertion with the dict's insertion order doubling as the LRU
    order.  All public operations take an internal re-entrant lock, so a
    single instance may be shared across threads.

    Configuration:
    {
        "max_entries": 100000,
        "max_memory_mb": 1024,      # Maximum memory usage in MB
        "eviction_policy": "lru",   # Currently only LRU supported
        "ttl_seconds": null,        # Time-to-live (null = no expiration)
        "normalize_keys": true,     # Normalize text before creating cache keys
        "enable_statistics": true   # Track detailed statistics
    }

    Performance Features:
    - Fast O(1) lookup and insertion
    - Memory-efficient storage (bookkeeping is an estimate; see
      _estimate_memory_usage)
    - LRU eviction for optimal hit rates
    - Content-based deduplication
    - Thread-safe operations
    """

    def __init__(self, config: Dict[str, Any]) -> None:
        """
        Initialize memory cache.

        Args:
            config: Cache configuration dictionary (see class docstring).
        """
        # NOTE(review): _validate_config is presumably invoked by the
        # ConfigurableEmbedderComponent base during this call — confirm.
        super().__init__(config)

        # Configuration
        self.max_entries: int = config.get("max_entries", 100000)
        self.max_memory_mb: float = config.get("max_memory_mb", 1024)
        self.eviction_policy: str = config.get("eviction_policy", "lru")
        self.ttl_seconds: Optional[float] = config.get("ttl_seconds")
        self.normalize_keys: bool = config.get("normalize_keys", True)
        self.enable_statistics: bool = config.get("enable_statistics", True)

        # Cache storage: key -> (embedding, creation timestamp).
        # OrderedDict order is the LRU order, oldest entry first.
        self._cache: OrderedDict[str, Tuple[np.ndarray, float]] = OrderedDict()
        self._lock = threading.RLock()  # Re-entrant: public methods may nest

        # Statistics.  memory_bytes is an estimate maintained incrementally
        # on every insert/remove, not a measured value.
        self._stats: Dict[str, Any] = {
            "hits": 0,
            "misses": 0,
            "evictions": 0,
            "invalidations": 0,
            "memory_bytes": 0,
            "created_time": time.time()
        }

        logger.info(f"MemoryCache initialized: max_entries={self.max_entries}, max_memory={self.max_memory_mb}MB")

    def _validate_config(self) -> None:
        """
        Validate cache configuration.

        Raises:
            ValueError: If configuration is invalid
        """
        # Validate max_entries
        max_entries = self.config.get("max_entries", 100000)
        if not isinstance(max_entries, int) or max_entries < 1:
            raise ValueError("max_entries must be a positive integer")

        # Validate max_memory_mb
        max_memory = self.config.get("max_memory_mb", 1024)
        if not isinstance(max_memory, (int, float)) or max_memory <= 0:
            raise ValueError("max_memory_mb must be a positive number")

        # Validate eviction_policy
        eviction_policy = self.config.get("eviction_policy", "lru")
        if eviction_policy != "lru":
            raise ValueError("Only 'lru' eviction policy is currently supported")

        # Validate TTL
        ttl = self.config.get("ttl_seconds")
        if ttl is not None and (not isinstance(ttl, (int, float)) or ttl <= 0):
            raise ValueError("ttl_seconds must be a positive number or null")

    def _create_cache_key(self, text: str) -> str:
        """
        Create a cache key from text content.

        Args:
            text: Text to create key for

        Returns:
            Cache key string of the form "embedding:<16 hex chars>"
        """
        # Normalize text if configured so trivially-different inputs
        # ("Foo " vs "foo") share one cache entry.
        if self.normalize_keys:
            normalized_text = text.strip().lower()
        else:
            normalized_text = text

        # Content-based hash; truncated to 16 hex chars (64 bits) to keep
        # keys short — collision probability is negligible at cache scale.
        content_hash = hashlib.sha256(normalized_text.encode('utf-8')).hexdigest()
        return f"embedding:{content_hash[:16]}"

    def _is_expired(self, timestamp: float) -> bool:
        """
        Check if a cache entry is expired.

        Args:
            timestamp: Entry creation timestamp

        Returns:
            True if expired, False otherwise (always False when TTL disabled)
        """
        if self.ttl_seconds is None:
            return False
        return (time.time() - timestamp) > self.ttl_seconds

    def _estimate_memory_usage(self, embedding: np.ndarray) -> int:
        """
        Estimate memory usage of an embedding.

        Args:
            embedding: Embedding array

        Returns:
            Estimated memory usage in bytes (array payload + fixed overhead)
        """
        array_bytes = embedding.nbytes
        overhead_bytes = 64  # Approximate overhead for key, timestamp, etc.
        return array_bytes + overhead_bytes

    def _evict_if_necessary(self) -> None:
        """Evict LRU entries until both the entry-count and memory limits
        leave room for one more entry.  Caller must hold self._lock."""
        # Entry count limit: evict down to max_entries - 1 so the pending
        # insert does not overshoot.
        while len(self._cache) >= self.max_entries:
            self._evict_lru_entry()

        # Memory limit
        max_memory_bytes = self.max_memory_mb * 1024 * 1024
        while self._stats["memory_bytes"] > max_memory_bytes and self._cache:
            self._evict_lru_entry()

    def _evict_lru_entry(self) -> None:
        """Evict the least recently used entry.  Caller must hold self._lock."""
        if not self._cache:
            return

        # popitem(last=False) removes the oldest (LRU) entry.
        key, (embedding, _) = self._cache.popitem(last=False)

        memory_freed = self._estimate_memory_usage(embedding)
        self._stats["memory_bytes"] -= memory_freed
        self._stats["evictions"] += 1

        logger.debug(f"Evicted cache entry: {key}, memory freed: {memory_freed} bytes")

    def get(self, text: str) -> Optional[np.ndarray]:
        """
        Retrieve cached embedding for text.

        Args:
            text: Text string to look up

        Returns:
            Copy of the cached embedding array, or None if not found/expired

        Raises:
            RuntimeError: If cache retrieval fails
        """
        cache_key = self._create_cache_key(text)

        with self._lock:
            try:
                if cache_key in self._cache:
                    embedding, timestamp = self._cache[cache_key]

                    # Lazily drop expired entries on access; counts as a miss.
                    if self._is_expired(timestamp):
                        del self._cache[cache_key]
                        memory_freed = self._estimate_memory_usage(embedding)
                        self._stats["memory_bytes"] -= memory_freed
                        if self.enable_statistics:
                            self._stats["misses"] += 1
                        return None

                    # Move to end (most recently used)
                    self._cache.move_to_end(cache_key)
                    if self.enable_statistics:
                        self._stats["hits"] += 1
                    logger.debug(f"Cache hit for key: {cache_key}")
                    # Return a copy so callers cannot mutate the cached array.
                    return embedding.copy()
                else:
                    if self.enable_statistics:
                        self._stats["misses"] += 1
                    logger.debug(f"Cache miss for key: {cache_key}")
                    return None
            except Exception as e:
                raise RuntimeError(f"Cache retrieval failed for key '{cache_key}': {e}") from e

    def put(self, text: str, embedding: np.ndarray) -> None:
        """
        Store embedding in cache.

        Args:
            text: Text string as cache key
            embedding: Embedding array to store (a copy is cached)

        Raises:
            ValueError: If text or embedding is invalid
            RuntimeError: If cache storage fails
        """
        if not text or not text.strip():
            raise ValueError("Text key cannot be empty")
        if embedding is None or embedding.size == 0:
            raise ValueError("Embedding cannot be empty")

        cache_key = self._create_cache_key(text)
        current_time = time.time()

        with self._lock:
            try:
                memory_needed = self._estimate_memory_usage(embedding)

                # BUGFIX: remove any existing entry *before* running eviction.
                # Previously only the old entry's memory was subtracted while
                # the entry stayed in the cache; eviction could then remove
                # that same entry and subtract its memory a second time
                # (driving memory_bytes negative), and a full cache could
                # evict an unrelated entry on what is merely an update.
                if cache_key in self._cache:
                    old_embedding, _ = self._cache.pop(cache_key)
                    self._stats["memory_bytes"] -= self._estimate_memory_usage(old_embedding)

                # Evict if necessary before adding
                self._evict_if_necessary()

                # Store a copy so later caller-side mutation cannot corrupt
                # the cached value; new/updated entries land at the MRU end.
                self._cache[cache_key] = (embedding.copy(), current_time)
                self._stats["memory_bytes"] += memory_needed

                logger.debug(f"Cached embedding for key: {cache_key}, size: {memory_needed} bytes")
            except Exception as e:
                raise RuntimeError(f"Cache storage failed for key '{cache_key}': {e}") from e

    def invalidate(self, pattern: str) -> int:
        """
        Invalidate cache entries matching pattern.

        Note: the pattern is matched against the internal cache keys
        ("embedding:<hash>"), not the original texts.

        Args:
            pattern: fnmatch-style pattern (supports wildcards)

        Returns:
            Number of entries invalidated

        Raises:
            RuntimeError: If invalidation fails
        """
        with self._lock:
            try:
                # Snapshot matching keys first; never mutate while iterating.
                keys_to_remove = [key for key in self._cache if fnmatch.fnmatch(key, pattern)]

                for key in keys_to_remove:
                    embedding, _ = self._cache.pop(key)
                    self._stats["memory_bytes"] -= self._estimate_memory_usage(embedding)
                invalidated_count = len(keys_to_remove)

                if self.enable_statistics:
                    self._stats["invalidations"] += invalidated_count

                logger.info(f"Invalidated {invalidated_count} cache entries matching pattern: {pattern}")
                return invalidated_count
            except Exception as e:
                raise RuntimeError(f"Cache invalidation failed for pattern '{pattern}': {e}") from e

    def get_cache_stats(self) -> Dict[str, Any]:
        """
        Get cache performance statistics.

        Returns:
            Dictionary with hit rate, size, evictions, memory usage, etc.
        """
        with self._lock:
            total_requests = self._stats["hits"] + self._stats["misses"]
            hit_rate = self._stats["hits"] / total_requests if total_requests > 0 else 0.0

            return {
                "size": len(self._cache),
                "max_entries": self.max_entries,
                "hits": self._stats["hits"],
                "misses": self._stats["misses"],
                "hit_rate": hit_rate,
                "evictions": self._stats["evictions"],
                "invalidations": self._stats["invalidations"],
                "memory_bytes": self._stats["memory_bytes"],
                "memory_mb": self._stats["memory_bytes"] / (1024 * 1024),
                "max_memory_mb": self.max_memory_mb,
                "memory_usage_percent": (self._stats["memory_bytes"] / (self.max_memory_mb * 1024 * 1024)) * 100,
                "uptime_seconds": time.time() - self._stats["created_time"],
                "ttl_seconds": self.ttl_seconds,
                "eviction_policy": self.eviction_policy
            }

    def clear(self) -> None:
        """
        Clear all entries from the cache.

        Hit/miss/eviction counters are intentionally preserved; only the
        entries and the memory estimate are reset.

        Raises:
            RuntimeError: If cache clearing fails
        """
        with self._lock:
            try:
                self._cache.clear()
                self._stats["memory_bytes"] = 0
                logger.info("Cache cleared")
            except Exception as e:
                raise RuntimeError(f"Cache clearing failed: {e}") from e

    def get_cache_size(self) -> int:
        """
        Get current number of cached entries.

        Returns:
            Number of entries currently in cache
        """
        with self._lock:
            return len(self._cache)

    def get_memory_info(self) -> Dict[str, Any]:
        """
        Get detailed memory usage information.

        Returns:
            Dictionary with memory statistics (bytes, MB, per-entry average)
        """
        with self._lock:
            per_entry = self._stats["memory_bytes"] / len(self._cache) if self._cache else 0
            return {
                "current_memory_bytes": self._stats["memory_bytes"],
                "current_memory_mb": self._stats["memory_bytes"] / (1024 * 1024),
                "max_memory_mb": self.max_memory_mb,
                "memory_efficiency": per_entry,
                "average_embedding_size": per_entry
            }

    def cleanup_expired(self) -> int:
        """
        Remove all expired entries from cache.

        Returns:
            Number of expired entries removed (0 when TTL is disabled)
        """
        if self.ttl_seconds is None:
            return 0

        with self._lock:
            # Snapshot expired keys first; never mutate while iterating.
            expired_keys = [
                key for key, (_, timestamp) in self._cache.items()
                if self._is_expired(timestamp)
            ]

            for key in expired_keys:
                embedding, _ = self._cache.pop(key)
                self._stats["memory_bytes"] -= self._estimate_memory_usage(embedding)

            if expired_keys:
                logger.info(f"Cleaned up {len(expired_keys)} expired cache entries")
            return len(expired_keys)