"""
Embedding management system using sentence-transformers with caching and optimization.
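
Example (illustrative only; the config keys mirror what EmbeddingManager reads,
but the values shown are placeholders):

    config = {
        "models": {
            "embedding": {
                "name": "sentence-transformers/all-MiniLM-L6-v2",
                "max_seq_length": 256,
                "batch_size": 32,
                "device": "auto",
            }
        },
        "cache": {},
    }
    manager = EmbeddingManager(config)
    vectors = manager.generate_embeddings(["hello world", "goodbye world"])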
"""

import hashlib
import time
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch
from sentence_transformers import SentenceTransformer
from tqdm import tqdm

from .error_handler import EmbeddingError, ResourceError
from .cache_manager import CacheManager


class EmbeddingManager:
    """Manages document embeddings with caching and batch processing."""
    
    def __init__(self, config: Dict[str, Any], cache_manager: Optional[CacheManager] = None):
        self.config = config
        self.model_config = config.get("models", {}).get("embedding", {})
        self.cache_config = config.get("cache", {})
        
        self.model_name = self.model_config.get("name", "sentence-transformers/all-MiniLM-L6-v2")
        self.max_seq_length = self.model_config.get("max_seq_length", 256)
        self.batch_size = self.model_config.get("batch_size", 32)
        self.device = self._get_device()
        
        self.model: Optional[SentenceTransformer] = None
        self.cache_manager = cache_manager
        self._model_loaded = False
        
        # Performance tracking
        self.stats = {
            "embeddings_generated": 0,
            "cache_hits": 0,
            "total_time": 0,
            "batch_count": 0
        }
    
    def _get_device(self) -> str:
        """Determine the best device for computation."""
        device_config = self.model_config.get("device", "auto")
        
        if device_config == "auto":
            if torch.cuda.is_available():
                return "cuda"
            elif torch.backends.mps.is_available():  # Apple Silicon
                return "mps"
            else:
                return "cpu"
        else:
            return device_config
    
    def _load_model(self) -> None:
        """Lazy load the sentence transformer model."""
        if self._model_loaded and self.model is not None:
            return
        
        try:
            print(f"Loading embedding model: {self.model_name}")
            start_time = time.time()
            
            self.model = SentenceTransformer(self.model_name)
            
            # Set device
            if self.device != "cpu":
                self.model = self.model.to(self.device)
            
            # Set max sequence length
            if hasattr(self.model, 'max_seq_length'):
                self.model.max_seq_length = self.max_seq_length
            
            load_time = time.time() - start_time
            print(f"Model loaded in {load_time:.2f}s on device: {self.device}")
            
            self._model_loaded = True
            
        except Exception as e:
            raise EmbeddingError(f"Failed to load embedding model: {str(e)}") from e
    
    def generate_embeddings(
        self, 
        texts: List[str], 
        show_progress: bool = True
    ) -> np.ndarray:
        """
        Generate embeddings for a list of texts with caching.
        
        Args:
            texts: List of text strings to embed
            show_progress: Whether to show progress bar
            
        Returns:
            Array of embeddings with shape (len(texts), embedding_dim)
        """
        if not texts:
            return np.array([])
        
        start_time = time.time()
        
        # Check cache for existing embeddings
        cached_embeddings, missing_indices, missing_texts = self._check_cache(texts)
        
        # Generate embeddings for missing texts
        if missing_texts:
            self._load_model()
            new_embeddings = self._generate_batch_embeddings(missing_texts, show_progress)
            
            # Cache new embeddings
            self._cache_embeddings(missing_texts, new_embeddings)
        else:
            new_embeddings = np.array([])
        
        # Combine cached and new embeddings
        all_embeddings = self._combine_embeddings(texts, cached_embeddings, missing_indices, new_embeddings)
        
        # Update stats (count only newly generated embeddings so the
        # cache-hit rate reported by get_stats() stays consistent)
        generation_time = time.time() - start_time
        self.stats["total_time"] += generation_time
        self.stats["embeddings_generated"] += len(missing_texts)
        
        return all_embeddings
    
    def _check_cache(self, texts: List[str]) -> Tuple[Dict[str, np.ndarray], List[int], List[str]]:
        """Check cache for existing embeddings."""
        cached_embeddings = {}
        missing_indices = []
        missing_texts = []
        
        if not self.cache_manager:
            return cached_embeddings, list(range(len(texts))), texts
        
        for i, text in enumerate(texts):
            cache_key = self._get_cache_key(text)
            cached_embedding = self.cache_manager.get(f"embedding_{cache_key}")
            
            if cached_embedding is not None:
                cached_embeddings[text] = cached_embedding
                self.stats["cache_hits"] += 1
            else:
                missing_indices.append(i)
                missing_texts.append(text)
        
        return cached_embeddings, missing_indices, missing_texts
    
    def _generate_batch_embeddings(self, texts: List[str], show_progress: bool) -> np.ndarray:
        """Generate embeddings in batches."""
        embeddings = []

        # Process in batches
        batches = [texts[i:i + self.batch_size] for i in range(0, len(texts), self.batch_size)]

        if show_progress and len(batches) > 1:
            batches = tqdm(batches, desc="Generating embeddings")

        for batch in batches:
            try:
                # Generate embeddings for batch
                batch_embeddings = self.model.encode(
                    batch,
                    convert_to_numpy=True,
                    show_progress_bar=False,
                    batch_size=len(batch)
                )

                embeddings.append(batch_embeddings)
                self.stats["batch_count"] += 1

            except torch.cuda.OutOfMemoryError as e:
                # Handle OOM explicitly so it is not swallowed by the generic
                # handler below and mis-reported as an EmbeddingError
                raise ResourceError(
                    "GPU memory insufficient for embedding generation. "
                    "Try reducing batch_size or using CPU."
                ) from e
            except Exception as e:
                raise EmbeddingError(f"Failed to generate embeddings for batch: {str(e)}") from e

        if not embeddings:
            return np.array([])

        # Combine all batch embeddings
        return np.vstack(embeddings)
    
    def _cache_embeddings(self, texts: List[str], embeddings: np.ndarray) -> None:
        """Cache generated embeddings."""
        if not self.cache_manager or embeddings.size == 0:
            return
        
        for text, embedding in zip(texts, embeddings):
            cache_key = self._get_cache_key(text)
            self.cache_manager.set(f"embedding_{cache_key}", embedding)
    
    def _combine_embeddings(
        self, 
        original_texts: List[str],
        cached_embeddings: Dict[str, np.ndarray],
        missing_indices: List[int],
        new_embeddings: np.ndarray
    ) -> np.ndarray:
        """Combine cached and newly generated embeddings."""
        if not original_texts:
            return np.array([])
        
        # Get embedding dimension
        if new_embeddings.size > 0:
            embedding_dim = new_embeddings.shape[1]
        elif cached_embeddings:
            embedding_dim = next(iter(cached_embeddings.values())).shape[0]
        else:
            # Fallback - load model to get dimension
            self._load_model()
            sample_embedding = self.model.encode(["sample"], convert_to_numpy=True)
            embedding_dim = sample_embedding.shape[1]
        
        # Initialize result array (float32 to match sentence-transformers output)
        result = np.zeros((len(original_texts), embedding_dim), dtype=np.float32)
        
        # Fill in cached embeddings
        for i, text in enumerate(original_texts):
            if text in cached_embeddings:
                result[i] = cached_embeddings[text]
        
        # Fill in new embeddings
        if new_embeddings.size > 0:
            for i, original_idx in enumerate(missing_indices):
                result[original_idx] = new_embeddings[i]
        
        return result
    
    def _get_cache_key(self, text: str) -> str:
        """Generate cache key for text."""
        # Include model name and config in hash for cache invalidation
        cache_input = f"{self.model_name}_{self.max_seq_length}_{text}"
        return hashlib.md5(cache_input.encode()).hexdigest()
    
    def get_embedding_dimension(self) -> int:
        """Get the dimension of embeddings."""
        self._load_model()
        
        # Generate a sample embedding to get dimensions
        sample_embedding = self.model.encode(["sample"], convert_to_numpy=True)
        return sample_embedding.shape[1]
    
    def compute_similarity(self, query_embedding: np.ndarray, doc_embeddings: np.ndarray) -> np.ndarray:
        """Compute cosine similarity between query and document embeddings."""
        if query_embedding.ndim == 1:
            query_embedding = query_embedding.reshape(1, -1)

        # Normalize vectors (clamp norms to avoid division by zero for all-zero vectors)
        query_norm = query_embedding / np.maximum(
            np.linalg.norm(query_embedding, axis=1, keepdims=True), 1e-12
        )
        doc_norm = doc_embeddings / np.maximum(
            np.linalg.norm(doc_embeddings, axis=1, keepdims=True), 1e-12
        )

        # Compute cosine similarity
        similarities = np.dot(query_norm, doc_norm.T).flatten()
        return similarities
    
    def clear_cache(self) -> None:
        """Clear embedding cache."""
        if self.cache_manager:
            # Clear all embedding entries
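            # Note: this iterates CacheManager's private _memory_cache, so it
            # assumes all embedding entries live in the in-memory cache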
            keys_to_remove = []
            for key in self.cache_manager._memory_cache.keys():
                if key.startswith("embedding_"):
                    keys_to_remove.append(key)
            
            for key in keys_to_remove:
                self.cache_manager.delete(key)
    
    def get_stats(self) -> Dict[str, Any]:
        """Get performance statistics."""
        stats = self.stats.copy()
        
        if stats["embeddings_generated"] > 0:
            stats["avg_time_per_embedding"] = stats["total_time"] / stats["embeddings_generated"]
        else:
            stats["avg_time_per_embedding"] = 0
        
        if stats["batch_count"] > 0:
            stats["avg_batch_size"] = stats["embeddings_generated"] / stats["batch_count"]
        else:
            stats["avg_batch_size"] = 0
        
        stats["cache_hit_rate"] = (
            stats["cache_hits"] / (stats["cache_hits"] + stats["embeddings_generated"]) 
            if (stats["cache_hits"] + stats["embeddings_generated"]) > 0 else 0
        )
        
        stats["model_loaded"] = self._model_loaded
        stats["device"] = self.device
        
        return stats
    
    def warmup(self) -> None:
        """Warm up the model with a sample embedding."""
        self._load_model()
        
        # Generate a sample embedding to warm up the model
        sample_texts = ["This is a sample text for model warmup."]
        self.model.encode(sample_texts, convert_to_numpy=True, show_progress_bar=False)
        
        print("Embedding model warmed up successfully")
    
    def unload_model(self) -> None:
        """Unload the model to free memory."""
        if self.model is not None:
            del self.model
            self.model = None
            self._model_loaded = False
            
            # Clear GPU cache if using CUDA
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            
            print("Embedding model unloaded")