# RAG/src/embedding_manager.py
"""
Embedding management system using sentence-transformers with caching and optimization.
"""
import hashlib
import time
from typing import Dict, List, Optional, Tuple, Any

import numpy as np
import torch
from sentence_transformers import SentenceTransformer
from tqdm import tqdm

from .error_handler import EmbeddingError, ResourceError
from .cache_manager import CacheManager


class EmbeddingManager:
    """Manages document embeddings with caching and batch processing."""

    def __init__(self, config: Dict[str, Any], cache_manager: Optional[CacheManager] = None):
        self.config = config
        self.model_config = config.get("models", {}).get("embedding", {})
        self.cache_config = config.get("cache", {})
        self.model_name = self.model_config.get("name", "sentence-transformers/all-MiniLM-L6-v2")
        self.max_seq_length = self.model_config.get("max_seq_length", 256)
        self.batch_size = self.model_config.get("batch_size", 32)
        self.device = self._get_device()
        self.model: Optional[SentenceTransformer] = None
        self.cache_manager = cache_manager
        self._model_loaded = False

        # Performance tracking
        self.stats = {
            "embeddings_generated": 0,
            "cache_hits": 0,
            "total_time": 0,
            "batch_count": 0
        }
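        # Illustrative shape of the config read above (a sketch, not the project's
        # authoritative schema; the values shown are simply the defaults used by
        # the .get() calls):
        # {
        #     "models": {
        #         "embedding": {
        #             "name": "sentence-transformers/all-MiniLM-L6-v2",
        #             "max_seq_length": 256,
        #             "batch_size": 32,
        #             "device": "auto",
        #         }
        #     },
        #     "cache": {...}
        # }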

    def _get_device(self) -> str:
        """Determine the best device for computation."""
        device_config = self.model_config.get("device", "auto")
        if device_config == "auto":
            if torch.cuda.is_available():
                return "cuda"
            elif torch.backends.mps.is_available():  # Apple Silicon
                return "mps"
            else:
                return "cpu"
        else:
            return device_config

    def _load_model(self) -> None:
        """Lazy load the sentence transformer model."""
        if self._model_loaded and self.model is not None:
            return

        try:
            print(f"Loading embedding model: {self.model_name}")
            start_time = time.time()

            # Pass the resolved device to the constructor so that an explicit
            # "cpu" setting is honored even when a GPU is available.
            self.model = SentenceTransformer(self.model_name, device=self.device)

            # Set max sequence length
            if hasattr(self.model, 'max_seq_length'):
                self.model.max_seq_length = self.max_seq_length

            load_time = time.time() - start_time
            print(f"Model loaded in {load_time:.2f}s on device: {self.device}")
            self._model_loaded = True

        except Exception as e:
            raise EmbeddingError(f"Failed to load embedding model: {str(e)}") from e

    def generate_embeddings(
        self,
        texts: List[str],
        show_progress: bool = True
    ) -> np.ndarray:
        """
        Generate embeddings for a list of texts with caching.

        Args:
            texts: List of text strings to embed
            show_progress: Whether to show progress bar

        Returns:
            Array of embeddings with shape (len(texts), embedding_dim)
        """
        if not texts:
            return np.array([])

        start_time = time.time()

        # Check cache for existing embeddings
        cached_embeddings, missing_indices, missing_texts = self._check_cache(texts)

        # Generate embeddings for missing texts
        if missing_texts:
            self._load_model()
            new_embeddings = self._generate_batch_embeddings(missing_texts, show_progress)
            # Cache new embeddings
            self._cache_embeddings(missing_texts, new_embeddings)
        else:
            new_embeddings = np.array([])

        # Combine cached and new embeddings
        all_embeddings = self._combine_embeddings(texts, cached_embeddings, missing_indices, new_embeddings)

        # Update stats: count only newly generated embeddings so the cache hit
        # rate reported by get_stats() reflects hits / total requests.
        generation_time = time.time() - start_time
        self.stats["total_time"] += generation_time
        self.stats["embeddings_generated"] += len(missing_texts)

        return all_embeddings

    def _check_cache(self, texts: List[str]) -> Tuple[Dict[str, np.ndarray], List[int], List[str]]:
        """Check cache for existing embeddings."""
        cached_embeddings = {}
        missing_indices = []
        missing_texts = []

        if not self.cache_manager:
            return cached_embeddings, list(range(len(texts))), texts

        for i, text in enumerate(texts):
            cache_key = self._get_cache_key(text)
            cached_embedding = self.cache_manager.get(f"embedding_{cache_key}")
            if cached_embedding is not None:
                cached_embeddings[text] = cached_embedding
                self.stats["cache_hits"] += 1
            else:
                missing_indices.append(i)
                missing_texts.append(text)

        return cached_embeddings, missing_indices, missing_texts

    def _generate_batch_embeddings(self, texts: List[str], show_progress: bool) -> np.ndarray:
        """Generate embeddings in batches."""
        try:
            embeddings = []

            # Process in batches
            batches = [texts[i:i + self.batch_size] for i in range(0, len(texts), self.batch_size)]
            if show_progress and len(batches) > 1:
                batches = tqdm(batches, desc="Generating embeddings")

            for batch in batches:
                try:
                    # Generate embeddings for batch
                    batch_embeddings = self.model.encode(
                        batch,
                        convert_to_numpy=True,
                        show_progress_bar=False,
                        batch_size=len(batch)
                    )
                    embeddings.append(batch_embeddings)
                    self.stats["batch_count"] += 1
                except torch.cuda.OutOfMemoryError:
                    # Let CUDA OOM propagate so the outer handler can map it to a ResourceError
                    raise
                except Exception as e:
                    raise EmbeddingError(f"Failed to generate embeddings for batch: {str(e)}") from e

            if not embeddings:
                return np.array([])

            # Combine all batch embeddings
            all_embeddings = np.vstack(embeddings)
            return all_embeddings

        except torch.cuda.OutOfMemoryError as e:
            raise ResourceError(
                "GPU memory insufficient for embedding generation. "
                "Try reducing batch_size or using CPU."
            ) from e
        except EmbeddingError:
            # Already wrapped above; avoid double-wrapping
            raise
        except Exception as e:
            raise EmbeddingError(f"Failed to generate embeddings: {str(e)}") from e

    def _cache_embeddings(self, texts: List[str], embeddings: np.ndarray) -> None:
        """Cache generated embeddings."""
        if not self.cache_manager or embeddings.size == 0:
            return

        for text, embedding in zip(texts, embeddings):
            cache_key = self._get_cache_key(text)
            self.cache_manager.set(f"embedding_{cache_key}", embedding)

    def _combine_embeddings(
        self,
        original_texts: List[str],
        cached_embeddings: Dict[str, np.ndarray],
        missing_indices: List[int],
        new_embeddings: np.ndarray
    ) -> np.ndarray:
        """Combine cached and newly generated embeddings."""
        if not original_texts:
            return np.array([])

        # Get embedding dimension
        if new_embeddings.size > 0:
            embedding_dim = new_embeddings.shape[1]
        elif cached_embeddings:
            embedding_dim = next(iter(cached_embeddings.values())).shape[0]
        else:
            # Fallback - load model to get dimension
            self._load_model()
            sample_embedding = self.model.encode(["sample"], convert_to_numpy=True)
            embedding_dim = sample_embedding.shape[1]

        # Initialize result array, matching the dtype of the generated embeddings
        # (float32 for sentence-transformers) rather than defaulting to float64
        dtype = new_embeddings.dtype if new_embeddings.size > 0 else np.float32
        result = np.zeros((len(original_texts), embedding_dim), dtype=dtype)

        # Fill in cached embeddings
        for i, text in enumerate(original_texts):
            if text in cached_embeddings:
                result[i] = cached_embeddings[text]

        # Fill in new embeddings
        if new_embeddings.size > 0:
            for i, original_idx in enumerate(missing_indices):
                result[original_idx] = new_embeddings[i]

        return result

    def _get_cache_key(self, text: str) -> str:
        """Generate cache key for text."""
        # Include model name and config in hash for cache invalidation
        cache_input = f"{self.model_name}_{self.max_seq_length}_{text}"
        return hashlib.md5(cache_input.encode()).hexdigest()

    def get_embedding_dimension(self) -> int:
        """Get the dimension of embeddings."""
        self._load_model()
        # Generate a sample embedding to get dimensions
        sample_embedding = self.model.encode(["sample"], convert_to_numpy=True)
        return sample_embedding.shape[1]

    def compute_similarity(self, query_embedding: np.ndarray, doc_embeddings: np.ndarray) -> np.ndarray:
        """Compute cosine similarity between a query embedding and document embeddings."""
        if query_embedding.ndim == 1:
            query_embedding = query_embedding.reshape(1, -1)

        # Normalize vectors (small epsilon guards against division by zero for all-zero vectors)
        eps = 1e-12
        query_norm = query_embedding / (np.linalg.norm(query_embedding, axis=1, keepdims=True) + eps)
        doc_norm = doc_embeddings / (np.linalg.norm(doc_embeddings, axis=1, keepdims=True) + eps)

        # Compute cosine similarity: result has shape (num_docs,)
        similarities = np.dot(query_norm, doc_norm.T).flatten()
        return similarities

    def clear_cache(self) -> None:
        """Clear embedding cache."""
        if self.cache_manager:
            # Clear all embedding entries from the cache manager's in-memory store
            keys_to_remove = [
                key for key in self.cache_manager._memory_cache
                if key.startswith("embedding_")
            ]
            for key in keys_to_remove:
                self.cache_manager.delete(key)

    def get_stats(self) -> Dict[str, Any]:
        """Get performance statistics."""
        stats = self.stats.copy()

        if stats["embeddings_generated"] > 0:
            stats["avg_time_per_embedding"] = stats["total_time"] / stats["embeddings_generated"]
        else:
            stats["avg_time_per_embedding"] = 0

        if stats["batch_count"] > 0:
            stats["avg_batch_size"] = stats["embeddings_generated"] / stats["batch_count"]
        else:
            stats["avg_batch_size"] = 0

        stats["cache_hit_rate"] = (
            stats["cache_hits"] / (stats["cache_hits"] + stats["embeddings_generated"])
            if (stats["cache_hits"] + stats["embeddings_generated"]) > 0 else 0
        )

        stats["model_loaded"] = self._model_loaded
        stats["device"] = self.device
        return stats

    def warmup(self) -> None:
        """Warm up the model with a sample embedding."""
        self._load_model()
        # Generate a sample embedding to warm up the model
        sample_texts = ["This is a sample text for model warmup."]
        self.model.encode(sample_texts, convert_to_numpy=True, show_progress_bar=False)
        print("Embedding model warmed up successfully")

    def unload_model(self) -> None:
        """Unload the model to free memory."""
        if self.model is not None:
            del self.model
            self.model = None
            self._model_loaded = False

            # Clear GPU cache if using CUDA
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            print("Embedding model unloaded")
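

# Minimal usage sketch (assumptions: no CacheManager is wired in, the default
# model can be downloaded, and the config values below are illustrative rather
# than the project's real settings).
if __name__ == "__main__":
    demo_config = {
        "models": {
            "embedding": {
                "name": "sentence-transformers/all-MiniLM-L6-v2",
                "max_seq_length": 256,
                "batch_size": 32,
                "device": "auto",
            }
        },
        "cache": {},
    }
    manager = EmbeddingManager(demo_config)
    docs = [
        "Retrieval-augmented generation combines search with LLMs.",
        "Sentence embeddings map text to dense vectors.",
    ]
    doc_vectors = manager.generate_embeddings(docs, show_progress=False)
    query_vector = manager.generate_embeddings(["What is RAG?"], show_progress=False)[0]
    print("similarities:", manager.compute_similarity(query_vector, doc_vectors))
    print("stats:", manager.get_stats())
    manager.unload_model()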