""" | |
Embedding management system using sentence-transformers with caching and optimization. | |
""" | |
import os | |
import pickle | |
import hashlib | |
from pathlib import Path | |
from typing import Dict, List, Optional, Tuple, Any, Union | |
import numpy as np | |
import torch | |
from sentence_transformers import SentenceTransformer | |
from tqdm import tqdm | |
import time | |
from .error_handler import EmbeddingError, ResourceError | |
from .cache_manager import CacheManager | |
class EmbeddingManager:
    """Manages document embeddings with caching and batch processing."""

    def __init__(self, config: Dict[str, Any], cache_manager: Optional[CacheManager] = None):
        self.config = config
        self.model_config = config.get("models", {}).get("embedding", {})
        self.cache_config = config.get("cache", {})
        self.model_name = self.model_config.get("name", "sentence-transformers/all-MiniLM-L6-v2")
        self.max_seq_length = self.model_config.get("max_seq_length", 256)
        self.batch_size = self.model_config.get("batch_size", 32)
        self.device = self._get_device()
        self.model: Optional[SentenceTransformer] = None
        self.cache_manager = cache_manager
        self._model_loaded = False
        # Performance tracking
        self.stats = {
            "embeddings_generated": 0,
            "cache_hits": 0,
            "total_time": 0,
            "batch_count": 0
        }

    def _get_device(self) -> str:
        """Determine the best device for computation."""
        device_config = self.model_config.get("device", "auto")
        if device_config == "auto":
            if torch.cuda.is_available():
                return "cuda"
            elif torch.backends.mps.is_available():  # Apple Silicon
                return "mps"
            else:
                return "cpu"
        else:
            return device_config

    def _load_model(self) -> None:
        """Lazy load the sentence transformer model."""
        if self._model_loaded and self.model is not None:
            return
        try:
            print(f"Loading embedding model: {self.model_name}")
            start_time = time.time()
            self.model = SentenceTransformer(self.model_name)
            # Move the model to the selected device
            if self.device != "cpu":
                self.model = self.model.to(self.device)
            # Set max sequence length
            if hasattr(self.model, 'max_seq_length'):
                self.model.max_seq_length = self.max_seq_length
            load_time = time.time() - start_time
            print(f"Model loaded in {load_time:.2f}s on device: {self.device}")
            self._model_loaded = True
        except Exception as e:
            raise EmbeddingError(f"Failed to load embedding model: {str(e)}") from e

    def generate_embeddings(
        self,
        texts: List[str],
        show_progress: bool = True
    ) -> np.ndarray:
        """
        Generate embeddings for a list of texts with caching.

        Args:
            texts: List of text strings to embed
            show_progress: Whether to show a progress bar

        Returns:
            Array of embeddings with shape (len(texts), embedding_dim)
        """
        if not texts:
            return np.array([])
        start_time = time.time()
        # Check cache for existing embeddings
        cached_embeddings, missing_indices, missing_texts = self._check_cache(texts)
        # Generate embeddings for missing texts
        if missing_texts:
            self._load_model()
            new_embeddings = self._generate_batch_embeddings(missing_texts, show_progress)
            # Cache new embeddings
            self._cache_embeddings(missing_texts, new_embeddings)
        else:
            new_embeddings = np.array([])
        # Combine cached and new embeddings
        all_embeddings = self._combine_embeddings(texts, cached_embeddings, missing_indices, new_embeddings)
        # Update stats: count only newly generated embeddings so the
        # cache-hit rate and batch statistics in get_stats() stay consistent
        generation_time = time.time() - start_time
        self.stats["total_time"] += generation_time
        self.stats["embeddings_generated"] += len(missing_texts)
        return all_embeddings

    def _check_cache(self, texts: List[str]) -> Tuple[Dict[str, np.ndarray], List[int], List[str]]:
        """Check cache for existing embeddings."""
        cached_embeddings = {}
        missing_indices = []
        missing_texts = []
        if not self.cache_manager:
            return cached_embeddings, list(range(len(texts))), texts
        for i, text in enumerate(texts):
            cache_key = self._get_cache_key(text)
            cached_embedding = self.cache_manager.get(f"embedding_{cache_key}")
            if cached_embedding is not None:
                cached_embeddings[text] = cached_embedding
                self.stats["cache_hits"] += 1
            else:
                missing_indices.append(i)
                missing_texts.append(text)
        return cached_embeddings, missing_indices, missing_texts

    def _generate_batch_embeddings(self, texts: List[str], show_progress: bool) -> np.ndarray:
        """Generate embeddings in batches."""
        try:
            embeddings = []
            # Process in batches
            batches = [texts[i:i + self.batch_size] for i in range(0, len(texts), self.batch_size)]
            if show_progress and len(batches) > 1:
                batches = tqdm(batches, desc="Generating embeddings")
            for batch in batches:
                try:
                    # Generate embeddings for batch
                    batch_embeddings = self.model.encode(
                        batch,
                        convert_to_numpy=True,
                        show_progress_bar=False,
                        batch_size=len(batch)
                    )
                    embeddings.append(batch_embeddings)
                    self.stats["batch_count"] += 1
                except torch.cuda.OutOfMemoryError:
                    # Re-raise so the outer handler can translate OOM into a ResourceError
                    raise
                except Exception as e:
                    raise EmbeddingError(f"Failed to generate embeddings for batch: {str(e)}") from e
            if not embeddings:
                return np.array([])
            # Combine all batch embeddings
            all_embeddings = np.vstack(embeddings)
            return all_embeddings
        except torch.cuda.OutOfMemoryError as e:
            raise ResourceError(
                "GPU memory insufficient for embedding generation. "
                "Try reducing batch_size or using CPU."
            ) from e
        except Exception as e:
            raise EmbeddingError(f"Failed to generate embeddings: {str(e)}") from e

    def _cache_embeddings(self, texts: List[str], embeddings: np.ndarray) -> None:
        """Cache generated embeddings."""
        if not self.cache_manager or embeddings.size == 0:
            return
        for text, embedding in zip(texts, embeddings):
            cache_key = self._get_cache_key(text)
            self.cache_manager.set(f"embedding_{cache_key}", embedding)

    def _combine_embeddings(
        self,
        original_texts: List[str],
        cached_embeddings: Dict[str, np.ndarray],
        missing_indices: List[int],
        new_embeddings: np.ndarray
    ) -> np.ndarray:
        """Combine cached and newly generated embeddings."""
        if not original_texts:
            return np.array([])
        # Get embedding dimension
        if new_embeddings.size > 0:
            embedding_dim = new_embeddings.shape[1]
        elif cached_embeddings:
            embedding_dim = next(iter(cached_embeddings.values())).shape[0]
        else:
            # Fallback - load model to get dimension
            self._load_model()
            sample_embedding = self.model.encode(["sample"], convert_to_numpy=True)
            embedding_dim = sample_embedding.shape[1]
        # Initialize result array
        result = np.zeros((len(original_texts), embedding_dim))
        # Fill in cached embeddings
        for i, text in enumerate(original_texts):
            if text in cached_embeddings:
                result[i] = cached_embeddings[text]
        # Fill in new embeddings
        if new_embeddings.size > 0:
            for i, original_idx in enumerate(missing_indices):
                result[original_idx] = new_embeddings[i]
        return result

    def _get_cache_key(self, text: str) -> str:
        """Generate cache key for text."""
        # Include model name and config in hash for cache invalidation
        cache_input = f"{self.model_name}_{self.max_seq_length}_{text}"
        return hashlib.md5(cache_input.encode()).hexdigest()

    def get_embedding_dimension(self) -> int:
        """Get the dimension of embeddings."""
        self._load_model()
        # Generate a sample embedding to get dimensions
        sample_embedding = self.model.encode(["sample"], convert_to_numpy=True)
        return sample_embedding.shape[1]

    def compute_similarity(self, query_embedding: np.ndarray, doc_embeddings: np.ndarray) -> np.ndarray:
        """Compute cosine similarity between query and document embeddings."""
        if query_embedding.ndim == 1:
            query_embedding = query_embedding.reshape(1, -1)
        # Normalize vectors
        query_norm = query_embedding / np.linalg.norm(query_embedding, axis=1, keepdims=True)
        doc_norm = doc_embeddings / np.linalg.norm(doc_embeddings, axis=1, keepdims=True)
        # Compute cosine similarity
        similarities = np.dot(query_norm, doc_norm.T).flatten()
        return similarities

    def clear_cache(self) -> None:
        """Clear embedding cache."""
        if self.cache_manager:
            # Collect all embedding entries, then delete them
            keys_to_remove = []
            for key in self.cache_manager._memory_cache.keys():
                if key.startswith("embedding_"):
                    keys_to_remove.append(key)
            for key in keys_to_remove:
                self.cache_manager.delete(key)

    def get_stats(self) -> Dict[str, Any]:
        """Get performance statistics."""
        stats = self.stats.copy()
        if stats["embeddings_generated"] > 0:
            stats["avg_time_per_embedding"] = stats["total_time"] / stats["embeddings_generated"]
        else:
            stats["avg_time_per_embedding"] = 0
        if stats["batch_count"] > 0:
            stats["avg_batch_size"] = stats["embeddings_generated"] / stats["batch_count"]
        else:
            stats["avg_batch_size"] = 0
        stats["cache_hit_rate"] = (
            stats["cache_hits"] / (stats["cache_hits"] + stats["embeddings_generated"])
            if (stats["cache_hits"] + stats["embeddings_generated"]) > 0 else 0
        )
        stats["model_loaded"] = self._model_loaded
        stats["device"] = self.device
        return stats

    def warmup(self) -> None:
        """Warm up the model with a sample embedding."""
        self._load_model()
        # Generate a sample embedding to warm up the model
        sample_texts = ["This is a sample text for model warmup."]
        self.model.encode(sample_texts, convert_to_numpy=True, show_progress_bar=False)
        print("Embedding model warmed up successfully")

    def unload_model(self) -> None:
        """Unload the model to free memory."""
        if self.model is not None:
            del self.model
            self.model = None
            self._model_loaded = False
            # Clear GPU cache if using CUDA
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            print("Embedding model unloaded")