import asyncio
import logging
from typing import Any, Dict, List, Optional

import numpy as np
import torch
from sentence_transformers import SentenceTransformer

import config

logger = logging.getLogger(__name__)


class EmbeddingService:
    def __init__(self):
        self.config = config.config
        self.model_name = self.config.EMBEDDING_MODEL
        self.model = None
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        # Load the model eagerly at construction time
        self._load_model()

    def _load_model(self):
        """Load the embedding model, falling back to a smaller model on failure"""
        try:
            logger.info(f"Loading embedding model: {self.model_name}")
            self.model = SentenceTransformer(self.model_name, device=self.device)
            logger.info(f"Embedding model loaded successfully on {self.device}")
        except Exception as e:
            logger.error(f"Failed to load embedding model: {str(e)}")
            # Fall back to a smaller model
            try:
                self.model_name = "all-MiniLM-L6-v2"
                self.model = SentenceTransformer(self.model_name, device=self.device)
                logger.info(f"Loaded fallback embedding model: {self.model_name}")
            except Exception as fallback_error:
                logger.error(f"Failed to load fallback model: {str(fallback_error)}")
                raise

    async def generate_embeddings(self, texts: List[str], batch_size: int = 32) -> List[List[float]]:
        """Generate embeddings for a list of texts"""
        if not texts:
            return []
        if self.model is None:
            raise RuntimeError("Embedding model not loaded")
        try:
            # Filter out empty texts
            non_empty_texts = [text for text in texts if text and text.strip()]
            if not non_empty_texts:
                logger.warning("No non-empty texts provided for embedding")
                return []

            logger.info(f"Generating embeddings for {len(non_empty_texts)} texts")

            # Process in batches to manage memory
            all_embeddings = []
            for i in range(0, len(non_empty_texts), batch_size):
                batch = non_empty_texts[i:i + batch_size]

                # Run embedding generation in a thread pool to avoid blocking the event loop
                loop = asyncio.get_running_loop()
                batch_embeddings = await loop.run_in_executor(
                    None, self._generate_batch_embeddings, batch
                )
                all_embeddings.extend(batch_embeddings)

            logger.info(f"Generated {len(all_embeddings)} embeddings")
            return all_embeddings
        except Exception as e:
            logger.error(f"Error generating embeddings: {str(e)}")
            raise

    def _generate_batch_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Generate embeddings for a batch of texts (synchronous)"""
        try:
            embeddings = self.model.encode(
                texts,
                convert_to_numpy=True,
                normalize_embeddings=True,
                batch_size=len(texts)
            )
            # Convert the numpy matrix to a list of lists
            return embeddings.tolist()
        except Exception as e:
            logger.error(f"Error in batch embedding generation: {str(e)}")
            raise

    async def generate_single_embedding(self, text: str) -> Optional[List[float]]:
        """Generate an embedding for a single text"""
        if not text or not text.strip():
            return None
        try:
            embeddings = await self.generate_embeddings([text])
            return embeddings[0] if embeddings else None
        except Exception as e:
            logger.error(f"Error generating single embedding: {str(e)}")
            return None

    def get_embedding_dimension(self) -> int:
        """Get the dimension of embeddings produced by the model"""
        if self.model is None:
            raise RuntimeError("Embedding model not loaded")
        return self.model.get_sentence_embedding_dimension()

    def compute_similarity(self, embedding1: List[float], embedding2: List[float]) -> float:
        """Compute cosine similarity between two embeddings"""
        try:
            emb1 = np.array(embedding1)
            emb2 = np.array(embedding2)
            similarity = np.dot(emb1, emb2) / (np.linalg.norm(emb1) * np.linalg.norm(emb2))
            return float(similarity)
        except Exception as e:
            logger.error(f"Error computing similarity: {str(e)}")
            return 0.0

    def compute_similarities(self, query_embedding: List[float], embeddings: List[List[float]]) -> List[float]:
        """Compute cosine similarities between a query embedding and multiple embeddings"""
        if not embeddings:
            return []
        try:
            query_emb = np.array(query_embedding)
            emb_matrix = np.array(embeddings)
            similarities = np.dot(emb_matrix, query_emb) / (
                np.linalg.norm(emb_matrix, axis=1) * np.linalg.norm(query_emb)
            )
            return similarities.tolist()
        except Exception as e:
            logger.error(f"Error computing similarities: {str(e)}")
            return [0.0] * len(embeddings)

    async def embed_chunks(self, chunks: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Embed a list of chunks and attach the embeddings to them"""
        if not chunks:
            return []
        try:
            # generate_embeddings() drops empty texts, so embed only non-empty
            # chunks and map results back by original index to keep alignment
            non_empty = [
                (i, chunk.get('content', ''))
                for i, chunk in enumerate(chunks)
                if chunk.get('content', '') and chunk.get('content', '').strip()
            ]
            embeddings = await self.generate_embeddings([text for _, text in non_empty])
            embedding_by_index = {
                i: emb for (i, _), emb in zip(non_empty, embeddings)
            }

            embedded_chunks = []
            for i, chunk in enumerate(chunks):
                if i in embedding_by_index:
                    chunk_copy = chunk.copy()
                    chunk_copy['embedding'] = embedding_by_index[i]
                    embedded_chunks.append(chunk_copy)
                else:
                    logger.warning(f"No embedding generated for chunk {i}")
                    embedded_chunks.append(chunk)
            return embedded_chunks
        except Exception as e:
            logger.error(f"Error embedding chunks: {str(e)}")
            raise

    def validate_embedding(self, embedding: List[float]) -> bool:
        """Validate that an embedding is properly formatted"""
        try:
            if not embedding:
                return False
            if not isinstance(embedding, list):
                return False
            if len(embedding) != self.get_embedding_dimension():
                return False
            # Reject embeddings containing NaN or infinite values
            emb_array = np.array(embedding)
            if np.isnan(emb_array).any() or np.isinf(emb_array).any():
                return False
            return True
        except Exception:
            return False

    async def get_model_info(self) -> Dict[str, Any]:
        """Get information about the loaded model"""
        try:
            return {
                "model_name": self.model_name,
                "device": self.device,
                "embedding_dimension": self.get_embedding_dimension(),
                "max_sequence_length": getattr(self.model, 'max_seq_length', 'unknown'),
                "model_loaded": self.model is not None
            }
        except Exception as e:
            logger.error(f"Error getting model info: {str(e)}")
            return {"error": str(e)}
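

# A minimal usage sketch, assuming a `config` module that exposes a `config`
# object with an EMBEDDING_MODEL attribute (e.g. "all-MiniLM-L6-v2"). The
# demo function name and sample texts are illustrative, not part of the
# service itself.
if __name__ == "__main__":
    async def _demo():
        service = EmbeddingService()
        texts = ["The cat sat on the mat.", "A feline rested on a rug."]

        # Embed both texts, then compare them with cosine similarity;
        # embeddings are normalized, so values fall roughly in [-1, 1]
        embeddings = await service.generate_embeddings(texts)
        print(f"embedding dimension: {service.get_embedding_dimension()}")
        print(f"similarity: {service.compute_similarity(embeddings[0], embeddings[1]):.3f}")

    asyncio.run(_demo())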