""" Embedding management system using sentence-transformers with caching and optimization. """ import os import pickle import hashlib from pathlib import Path from typing import Dict, List, Optional, Tuple, Any, Union import numpy as np import torch from sentence_transformers import SentenceTransformer from tqdm import tqdm import time from .error_handler import EmbeddingError, ResourceError from .cache_manager import CacheManager class EmbeddingManager: """Manages document embeddings with caching and batch processing.""" def __init__(self, config: Dict[str, Any], cache_manager: Optional[CacheManager] = None): self.config = config self.model_config = config.get("models", {}).get("embedding", {}) self.cache_config = config.get("cache", {}) self.model_name = self.model_config.get("name", "sentence-transformers/all-MiniLM-L6-v2") self.max_seq_length = self.model_config.get("max_seq_length", 256) self.batch_size = self.model_config.get("batch_size", 32) self.device = self._get_device() self.model: Optional[SentenceTransformer] = None self.cache_manager = cache_manager self._model_loaded = False # Performance tracking self.stats = { "embeddings_generated": 0, "cache_hits": 0, "total_time": 0, "batch_count": 0 } def _get_device(self) -> str: """Determine the best device for computation.""" device_config = self.model_config.get("device", "auto") if device_config == "auto": if torch.cuda.is_available(): return "cuda" elif torch.backends.mps.is_available(): # Apple Silicon return "mps" else: return "cpu" else: return device_config def _load_model(self) -> None: """Lazy load the sentence transformer model.""" if self._model_loaded and self.model is not None: return try: print(f"Loading embedding model: {self.model_name}") start_time = time.time() self.model = SentenceTransformer(self.model_name) # Set device if self.device != "cpu": self.model = self.model.to(self.device) # Set max sequence length if hasattr(self.model, 'max_seq_length'): self.model.max_seq_length = self.max_seq_length load_time = time.time() - start_time print(f"Model loaded in {load_time:.2f}s on device: {self.device}") self._model_loaded = True except Exception as e: raise EmbeddingError(f"Failed to load embedding model: {str(e)}") from e def generate_embeddings( self, texts: List[str], show_progress: bool = True ) -> np.ndarray: """ Generate embeddings for a list of texts with caching. 
    def generate_embeddings(
        self,
        texts: List[str],
        show_progress: bool = True
    ) -> np.ndarray:
        """
        Generate embeddings for a list of texts with caching.

        Args:
            texts: List of text strings to embed
            show_progress: Whether to show progress bar

        Returns:
            Array of embeddings with shape (len(texts), embedding_dim)
        """
        if not texts:
            return np.array([])

        start_time = time.time()

        # Check cache for existing embeddings
        cached_embeddings, missing_indices, missing_texts = self._check_cache(texts)

        # Generate embeddings for missing texts
        if missing_texts:
            self._load_model()
            new_embeddings = self._generate_batch_embeddings(missing_texts, show_progress)

            # Cache new embeddings
            self._cache_embeddings(missing_texts, new_embeddings)
        else:
            new_embeddings = np.array([])

        # Combine cached and new embeddings
        all_embeddings = self._combine_embeddings(texts, cached_embeddings, missing_indices, new_embeddings)

        # Update stats
        generation_time = time.time() - start_time
        self.stats["total_time"] += generation_time
        self.stats["embeddings_generated"] += len(texts)

        return all_embeddings

    def _check_cache(self, texts: List[str]) -> Tuple[Dict[str, np.ndarray], List[int], List[str]]:
        """Check cache for existing embeddings."""
        cached_embeddings = {}
        missing_indices = []
        missing_texts = []

        if not self.cache_manager:
            return cached_embeddings, list(range(len(texts))), texts

        for i, text in enumerate(texts):
            cache_key = self._get_cache_key(text)
            cached_embedding = self.cache_manager.get(f"embedding_{cache_key}")

            if cached_embedding is not None:
                cached_embeddings[text] = cached_embedding
                self.stats["cache_hits"] += 1
            else:
                missing_indices.append(i)
                missing_texts.append(text)

        return cached_embeddings, missing_indices, missing_texts

    def _generate_batch_embeddings(self, texts: List[str], show_progress: bool) -> np.ndarray:
        """Generate embeddings in batches."""
        try:
            embeddings = []

            # Process in batches
            batches = [texts[i:i + self.batch_size] for i in range(0, len(texts), self.batch_size)]

            if show_progress and len(batches) > 1:
                batches = tqdm(batches, desc="Generating embeddings")

            for batch in batches:
                try:
                    # Generate embeddings for batch
                    batch_embeddings = self.model.encode(
                        batch,
                        convert_to_numpy=True,
                        show_progress_bar=False,
                        batch_size=len(batch)
                    )
                    embeddings.append(batch_embeddings)
                    self.stats["batch_count"] += 1

                except torch.cuda.OutOfMemoryError:
                    # Let OOM propagate to the outer handler so it is reported as a ResourceError
                    raise
                except Exception as e:
                    raise EmbeddingError(f"Failed to generate embeddings for batch: {str(e)}") from e

            if not embeddings:
                return np.array([])

            # Combine all batch embeddings
            all_embeddings = np.vstack(embeddings)
            return all_embeddings

        except torch.cuda.OutOfMemoryError as e:
            raise ResourceError(
                "GPU memory insufficient for embedding generation. "
                "Try reducing batch_size or using CPU."
            ) from e
        except Exception as e:
            raise EmbeddingError(f"Failed to generate embeddings: {str(e)}") from e
    def _cache_embeddings(self, texts: List[str], embeddings: np.ndarray) -> None:
        """Cache generated embeddings."""
        if not self.cache_manager or embeddings.size == 0:
            return

        for text, embedding in zip(texts, embeddings):
            cache_key = self._get_cache_key(text)
            self.cache_manager.set(f"embedding_{cache_key}", embedding)

    def _combine_embeddings(
        self,
        original_texts: List[str],
        cached_embeddings: Dict[str, np.ndarray],
        missing_indices: List[int],
        new_embeddings: np.ndarray
    ) -> np.ndarray:
        """Combine cached and newly generated embeddings."""
        if not original_texts:
            return np.array([])

        # Get embedding dimension
        if new_embeddings.size > 0:
            embedding_dim = new_embeddings.shape[1]
        elif cached_embeddings:
            embedding_dim = next(iter(cached_embeddings.values())).shape[0]
        else:
            # Fallback - load model to get dimension
            self._load_model()
            sample_embedding = self.model.encode(["sample"], convert_to_numpy=True)
            embedding_dim = sample_embedding.shape[1]

        # Initialize result array
        result = np.zeros((len(original_texts), embedding_dim))

        # Fill in cached embeddings
        for i, text in enumerate(original_texts):
            if text in cached_embeddings:
                result[i] = cached_embeddings[text]

        # Fill in new embeddings
        if new_embeddings.size > 0:
            for i, original_idx in enumerate(missing_indices):
                result[original_idx] = new_embeddings[i]

        return result

    def _get_cache_key(self, text: str) -> str:
        """Generate cache key for text."""
        # Include model name and config in hash for cache invalidation
        cache_input = f"{self.model_name}_{self.max_seq_length}_{text}"
        return hashlib.md5(cache_input.encode()).hexdigest()

    def get_embedding_dimension(self) -> int:
        """Get the dimension of embeddings."""
        self._load_model()

        # Generate a sample embedding to get dimensions
        sample_embedding = self.model.encode(["sample"], convert_to_numpy=True)
        return sample_embedding.shape[1]

    def compute_similarity(self, query_embedding: np.ndarray, doc_embeddings: np.ndarray) -> np.ndarray:
        """Compute cosine similarity between query and document embeddings."""
        if query_embedding.ndim == 1:
            query_embedding = query_embedding.reshape(1, -1)

        # Normalize vectors
        query_norm = query_embedding / np.linalg.norm(query_embedding, axis=1, keepdims=True)
        doc_norm = doc_embeddings / np.linalg.norm(doc_embeddings, axis=1, keepdims=True)

        # Compute cosine similarity
        similarities = np.dot(query_norm, doc_norm.T).flatten()
        return similarities

    def clear_cache(self) -> None:
        """Clear embedding cache."""
        if self.cache_manager:
            # Clear all embedding entries
            keys_to_remove = []
            for key in self.cache_manager._memory_cache.keys():
                if key.startswith("embedding_"):
                    keys_to_remove.append(key)

            for key in keys_to_remove:
                self.cache_manager.delete(key)

    def get_stats(self) -> Dict[str, Any]:
        """Get performance statistics."""
        stats = self.stats.copy()

        if stats["embeddings_generated"] > 0:
            stats["avg_time_per_embedding"] = stats["total_time"] / stats["embeddings_generated"]
        else:
            stats["avg_time_per_embedding"] = 0

        if stats["batch_count"] > 0:
            stats["avg_batch_size"] = stats["embeddings_generated"] / stats["batch_count"]
        else:
            stats["avg_batch_size"] = 0

        stats["cache_hit_rate"] = (
            stats["cache_hits"] / (stats["cache_hits"] + stats["embeddings_generated"])
            if (stats["cache_hits"] + stats["embeddings_generated"]) > 0
            else 0
        )

        stats["model_loaded"] = self._model_loaded
        stats["device"] = self.device

        return stats
    def warmup(self) -> None:
        """Warm up the model with a sample embedding."""
        self._load_model()

        # Generate a sample embedding to warm up the model
        sample_texts = ["This is a sample text for model warmup."]
        self.model.encode(sample_texts, convert_to_numpy=True, show_progress_bar=False)

        print("Embedding model warmed up successfully")

    def unload_model(self) -> None:
        """Unload the model to free memory."""
        if self.model is not None:
            del self.model
            self.model = None
            self._model_loaded = False

            # Clear GPU cache if using CUDA
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            print("Embedding model unloaded")
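
# Illustrative usage sketch, not part of the class. The config keys below
# mirror the ones read in __init__; the model name, batch size, and sample
# texts are assumptions for the example only. Because this module uses
# relative imports, run it as part of its package, e.g.
# `python -m <package>.embedding_manager`.
if __name__ == "__main__":
    example_config = {
        "models": {
            "embedding": {
                "name": "sentence-transformers/all-MiniLM-L6-v2",  # assumed model
                "max_seq_length": 256,
                "batch_size": 32,
                "device": "auto",
            }
        },
        "cache": {},
    }

    # Without a CacheManager, every call generates embeddings from scratch.
    manager = EmbeddingManager(example_config)
    manager.warmup()

    documents = [
        "Sentence-transformers maps text to dense vectors.",
        "Embeddings can be cached to avoid recomputation.",
    ]
    doc_embeddings = manager.generate_embeddings(documents, show_progress=False)
    query_embedding = manager.generate_embeddings(
        ["How are embeddings cached?"], show_progress=False
    )[0]

    scores = manager.compute_similarity(query_embedding, doc_embeddings)
    print(f"Embedding dimension: {manager.get_embedding_dimension()}")
    print(f"Similarity scores: {scores}")
    print(f"Stats: {manager.get_stats()}")

    manager.unload_model()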