"""
Embedding management system using sentence-transformers with caching and optimization.
"""

import hashlib
import os
import pickle
import time
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from sentence_transformers import SentenceTransformer
from tqdm import tqdm

from .error_handler import EmbeddingError, ResourceError
from .cache_manager import CacheManager


class EmbeddingManager:
    """Manages document embeddings with caching and batch processing."""

    def __init__(self, config: Dict[str, Any], cache_manager: Optional[CacheManager] = None):
        self.config = config
        self.model_config = config.get("models", {}).get("embedding", {})
        self.cache_config = config.get("cache", {})

        self.model_name = self.model_config.get("name", "sentence-transformers/all-MiniLM-L6-v2")
        self.max_seq_length = self.model_config.get("max_seq_length", 256)
        self.batch_size = self.model_config.get("batch_size", 32)
        self.device = self._get_device()

        self.model: Optional[SentenceTransformer] = None
        self.cache_manager = cache_manager
        self._model_loaded = False

        self.stats = {
            "embeddings_generated": 0,
            "cache_hits": 0,
            "total_time": 0,
            "batch_count": 0
        }
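
    # Illustrative config shape consumed above; every key is optional and falls back
    # to the defaults shown in __init__ (a sketch, not a required schema):
    #
    #   {
    #       "models": {
    #           "embedding": {
    #               "name": "sentence-transformers/all-MiniLM-L6-v2",
    #               "max_seq_length": 256,
    #               "batch_size": 32,
    #               "device": "auto",  # "auto" | "cpu" | "cuda" | "mps"
    #           }
    #       },
    #       "cache": {},
    #   }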

    def _get_device(self) -> str:
        """Determine the best device for computation."""
        device_config = self.model_config.get("device", "auto")

        if device_config == "auto":
            if torch.cuda.is_available():
                return "cuda"
            elif torch.backends.mps.is_available():
                return "mps"
            else:
                return "cpu"
        else:
            return device_config

    def _load_model(self) -> None:
        """Lazily load the sentence-transformer model."""
        if self._model_loaded and self.model is not None:
            return

        try:
            print(f"Loading embedding model: {self.model_name}")
            start_time = time.time()

            self.model = SentenceTransformer(self.model_name)

            # Move the model to the selected accelerator (CUDA or MPS).
            if self.device != "cpu":
                self.model = self.model.to(self.device)

            # Respect the configured maximum sequence length.
            if hasattr(self.model, 'max_seq_length'):
                self.model.max_seq_length = self.max_seq_length

            load_time = time.time() - start_time
            print(f"Model loaded in {load_time:.2f}s on device: {self.device}")

            self._model_loaded = True

        except Exception as e:
            raise EmbeddingError(f"Failed to load embedding model: {str(e)}") from e

    def generate_embeddings(
        self,
        texts: List[str],
        show_progress: bool = True
    ) -> np.ndarray:
        """
        Generate embeddings for a list of texts with caching.

        Args:
            texts: List of text strings to embed
            show_progress: Whether to show a progress bar

        Returns:
            Array of embeddings with shape (len(texts), embedding_dim)
        """
        if not texts:
            return np.array([])

        start_time = time.time()

        # Reuse cached embeddings where possible and embed only the cache misses.
        cached_embeddings, missing_indices, missing_texts = self._check_cache(texts)

        if missing_texts:
            self._load_model()
            new_embeddings = self._generate_batch_embeddings(missing_texts, show_progress)

            # Persist the fresh embeddings for later requests.
            self._cache_embeddings(missing_texts, new_embeddings)
        else:
            new_embeddings = np.array([])

        # Merge cached and freshly generated embeddings back into input order.
        all_embeddings = self._combine_embeddings(texts, cached_embeddings, missing_indices, new_embeddings)

        generation_time = time.time() - start_time
        self.stats["total_time"] += generation_time
        # Count only newly generated embeddings here; cache hits are tracked in
        # _check_cache, which keeps the cache_hit_rate in get_stats() consistent.
        self.stats["embeddings_generated"] += len(missing_texts)

        return all_embeddings
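
    # Illustrative call pattern (the `manager` and `chunks` names are placeholders):
    #
    #   vectors = manager.generate_embeddings(chunks)  # embeds misses, caches them
    #   vectors = manager.generate_embeddings(chunks)  # now served from the cache
    #   vectors.shape                                  # -> (len(chunks), embedding_dim)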

    def _check_cache(self, texts: List[str]) -> Tuple[Dict[str, np.ndarray], List[int], List[str]]:
        """Check cache for existing embeddings."""
        cached_embeddings = {}
        missing_indices = []
        missing_texts = []

        if not self.cache_manager:
            return cached_embeddings, list(range(len(texts))), texts

        for i, text in enumerate(texts):
            cache_key = self._get_cache_key(text)
            cached_embedding = self.cache_manager.get(f"embedding_{cache_key}")

            if cached_embedding is not None:
                cached_embeddings[text] = cached_embedding
                self.stats["cache_hits"] += 1
            else:
                missing_indices.append(i)
                missing_texts.append(text)

        return cached_embeddings, missing_indices, missing_texts

    def _generate_batch_embeddings(self, texts: List[str], show_progress: bool) -> np.ndarray:
        """Generate embeddings in batches."""
        try:
            embeddings = []

            # Split the texts into fixed-size batches.
            batches = [texts[i:i + self.batch_size] for i in range(0, len(texts), self.batch_size)]

            if show_progress and len(batches) > 1:
                batches = tqdm(batches, desc="Generating embeddings")

            for batch in batches:
                try:
                    batch_embeddings = self.model.encode(
                        batch,
                        convert_to_numpy=True,
                        show_progress_bar=False,
                        batch_size=len(batch)
                    )

                    embeddings.append(batch_embeddings)
                    self.stats["batch_count"] += 1

                except torch.cuda.OutOfMemoryError:
                    # Re-raise so the outer handler can translate OOM into a
                    # ResourceError instead of it being swallowed by the generic
                    # wrapper below.
                    raise
                except Exception as e:
                    raise EmbeddingError(f"Failed to generate embeddings for batch: {str(e)}") from e

            if not embeddings:
                return np.array([])

            # Stack the per-batch arrays into a single (n_texts, embedding_dim) array.
            all_embeddings = np.vstack(embeddings)
            return all_embeddings

        except torch.cuda.OutOfMemoryError as e:
            raise ResourceError(
                "GPU memory insufficient for embedding generation. "
                "Try reducing batch_size or using CPU."
            ) from e
        except EmbeddingError:
            # Batch-level failures are already wrapped; re-raise them unchanged.
            raise
        except Exception as e:
            raise EmbeddingError(f"Failed to generate embeddings: {str(e)}") from e

    def _cache_embeddings(self, texts: List[str], embeddings: np.ndarray) -> None:
        """Cache generated embeddings."""
        if not self.cache_manager or embeddings.size == 0:
            return

        for text, embedding in zip(texts, embeddings):
            cache_key = self._get_cache_key(text)
            self.cache_manager.set(f"embedding_{cache_key}", embedding)

    def _combine_embeddings(
        self,
        original_texts: List[str],
        cached_embeddings: Dict[str, np.ndarray],
        missing_indices: List[int],
        new_embeddings: np.ndarray
    ) -> np.ndarray:
        """Combine cached and newly generated embeddings in input order."""
        if not original_texts:
            return np.array([])

        # Determine the embedding dimension from whichever source is available.
        if new_embeddings.size > 0:
            embedding_dim = new_embeddings.shape[1]
        elif cached_embeddings:
            embedding_dim = next(iter(cached_embeddings.values())).shape[0]
        else:
            # No cached or new embeddings: probe the model for its output width.
            self._load_model()
            sample_embedding = self.model.encode(["sample"], convert_to_numpy=True)
            embedding_dim = sample_embedding.shape[1]

        result = np.zeros((len(original_texts), embedding_dim))

        # Place cached embeddings, matched by text value.
        for i, text in enumerate(original_texts):
            if text in cached_embeddings:
                result[i] = cached_embeddings[text]

        # Place newly generated embeddings at their original positions.
        if new_embeddings.size > 0:
            for i, original_idx in enumerate(missing_indices):
                result[original_idx] = new_embeddings[i]

        return result

    def _get_cache_key(self, text: str) -> str:
        """Generate a cache key for the given text."""
        # Key on model name and sequence length so different configurations do not
        # collide; MD5 is used only for cache keying, not for security.
        cache_input = f"{self.model_name}_{self.max_seq_length}_{text}"
        return hashlib.md5(cache_input.encode()).hexdigest()

    def get_embedding_dimension(self) -> int:
        """Get the dimension of embeddings."""
        self._load_model()

        sample_embedding = self.model.encode(["sample"], convert_to_numpy=True)
        return sample_embedding.shape[1]

    def compute_similarity(self, query_embedding: np.ndarray, doc_embeddings: np.ndarray) -> np.ndarray:
        """Compute cosine similarity between query and document embeddings."""
        if query_embedding.ndim == 1:
            query_embedding = query_embedding.reshape(1, -1)

        # Normalize both sides; clamp the norms to avoid division by zero for
        # all-zero vectors.
        query_norm = query_embedding / np.maximum(
            np.linalg.norm(query_embedding, axis=1, keepdims=True), 1e-12
        )
        doc_norm = doc_embeddings / np.maximum(
            np.linalg.norm(doc_embeddings, axis=1, keepdims=True), 1e-12
        )

        similarities = np.dot(query_norm, doc_norm.T).flatten()
        return similarities
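
    # For reference, compute_similarity above implements cosine similarity,
    #   sim(q, d) = (q . d) / (||q|| * ||d||),
    # returning one score per row of doc_embeddings, in the same order.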

    def clear_cache(self) -> None:
        """Clear embedding cache."""
        if self.cache_manager:
            # Enumerate keys first (this relies on CacheManager's internal
            # _memory_cache mapping), then delete outside the iteration.
            keys_to_remove = [
                key for key in self.cache_manager._memory_cache.keys()
                if key.startswith("embedding_")
            ]

            for key in keys_to_remove:
                self.cache_manager.delete(key)

    def get_stats(self) -> Dict[str, Any]:
        """Get performance statistics."""
        stats = self.stats.copy()

        if stats["embeddings_generated"] > 0:
            stats["avg_time_per_embedding"] = stats["total_time"] / stats["embeddings_generated"]
        else:
            stats["avg_time_per_embedding"] = 0

        if stats["batch_count"] > 0:
            stats["avg_batch_size"] = stats["embeddings_generated"] / stats["batch_count"]
        else:
            stats["avg_batch_size"] = 0

        stats["cache_hit_rate"] = (
            stats["cache_hits"] / (stats["cache_hits"] + stats["embeddings_generated"])
            if (stats["cache_hits"] + stats["embeddings_generated"]) > 0 else 0
        )

        stats["model_loaded"] = self._model_loaded
        stats["device"] = self.device

        return stats
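
    # Sketch of the dict returned by get_stats (values are illustrative only):
    #
    #   {"embeddings_generated": 120, "cache_hits": 80, "total_time": 3.2,
    #    "batch_count": 4, "avg_time_per_embedding": 0.027, "avg_batch_size": 30.0,
    #    "cache_hit_rate": 0.4, "model_loaded": True, "device": "cuda"}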

    def warmup(self) -> None:
        """Warm up the model with a sample embedding."""
        self._load_model()

        sample_texts = ["This is a sample text for model warmup."]
        self.model.encode(sample_texts, convert_to_numpy=True, show_progress_bar=False)

        print("Embedding model warmed up successfully")

    def unload_model(self) -> None:
        """Unload the model to free memory."""
        if self.model is not None:
            del self.model
            self.model = None
            self._model_loaded = False

            # Release cached GPU memory held by PyTorch, if a GPU is present.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            print("Embedding model unloaded")