|
import logging |
|
import asyncio |
|
from typing import List, Optional, Dict, Any |
|
import numpy as np |
|
from sentence_transformers import SentenceTransformer |
|
import torch |
|
import config |
|
|
|
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
|
|
|
class EmbeddingService:
    """Wrap a SentenceTransformer model for embedding and similarity helpers.

    The model named by ``config.config.EMBEDDING_MODEL`` is loaded eagerly in
    the constructor; if that fails, ``FALLBACK_MODEL_NAME`` is tried instead.
    The device is ``'cuda'`` when ``torch.cuda.is_available()`` is true and
    ``'cpu'`` otherwise.
    """

    # Model tried when the configured model cannot be loaded.
    FALLBACK_MODEL_NAME = "all-MiniLM-L6-v2"

    def __init__(self):
        """Read the configured model name, pick a device and load the model.

        Raises:
            Exception: Whatever the fallback load raised, when both the
                configured model and the fallback model fail to load.
        """
        self.config = config.config
        self.model_name = self.config.EMBEDDING_MODEL
        self.model = None
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        self._load_model()

    @staticmethod
    def _is_embeddable(text: Any) -> bool:
        """Return True when *text* is a string with non-whitespace content."""
        return isinstance(text, str) and bool(text.strip())

    def _load_model(self):
        """Load the configured embedding model, falling back on failure.

        On failure of the configured model, ``self.model_name`` is switched to
        ``FALLBACK_MODEL_NAME`` before the second attempt, so it always names
        the model that was last tried.

        Raises:
            Exception: Re-raised from the fallback load when it also fails.
        """
        try:
            logger.info("Loading embedding model: %s", self.model_name)
            self.model = SentenceTransformer(self.model_name, device=self.device)
            logger.info("Embedding model loaded successfully on %s", self.device)
        except Exception as e:
            logger.error("Failed to load embedding model: %s", e)

            try:
                self.model_name = self.FALLBACK_MODEL_NAME
                self.model = SentenceTransformer(self.model_name, device=self.device)
                logger.info("Loaded fallback embedding model: %s", self.model_name)
            except Exception as fallback_error:
                logger.error("Failed to load fallback model: %s", fallback_error)
                raise

    async def generate_embeddings(self, texts: List[str], batch_size: int = 32) -> List[List[float]]:
        """Generate embeddings for a list of texts.

        Entries that are not strings, are empty, or contain only whitespace
        are skipped. The result therefore holds one embedding per *embeddable*
        entry, in order, and may be shorter than ``texts``. Callers that need
        positional alignment must filter their input the same way first.

        Args:
            texts: Texts to embed.
            batch_size: Number of texts handed to the model per executor call.

        Returns:
            One embedding (list of floats) per embeddable text; ``[]`` when
            ``texts`` is empty or contains nothing embeddable.

        Raises:
            RuntimeError: If no model is loaded.
            ValueError: If ``batch_size`` is smaller than 1.
            Exception: Any error raised while encoding is logged and re-raised.
        """
        if not texts:
            return []

        if self.model is None:
            raise RuntimeError("Embedding model not loaded")

        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer")

        try:
            non_empty_texts = [text for text in texts if self._is_embeddable(text)]
            if not non_empty_texts:
                logger.warning("No non-empty texts provided for embedding")
                return []

            logger.info("Generating embeddings for %d texts", len(non_empty_texts))

            # Encoding is blocking work, so each batch runs in the default
            # executor to keep the event loop responsive.
            loop = asyncio.get_running_loop()
            all_embeddings: List[List[float]] = []
            for start in range(0, len(non_empty_texts), batch_size):
                batch = non_empty_texts[start:start + batch_size]
                batch_embeddings = await loop.run_in_executor(
                    None,
                    self._generate_batch_embeddings,
                    batch,
                )
                all_embeddings.extend(batch_embeddings)

            logger.info("Generated %d embeddings", len(all_embeddings))
            return all_embeddings

        except Exception as e:
            logger.error("Error generating embeddings: %s", e)
            raise

    def _generate_batch_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Generate embeddings for a batch of texts (synchronous).

        The whole batch is passed to ``self.model.encode`` in one call
        (``batch_size=len(texts)``), requesting normalized NumPy output, and
        the result is converted to nested Python lists.

        Raises:
            Exception: Any error from the model is logged and re-raised.
        """
        try:
            embeddings = self.model.encode(
                texts,
                convert_to_numpy=True,
                normalize_embeddings=True,
                batch_size=len(texts),
            )
            return embeddings.tolist()
        except Exception as e:
            logger.error("Error in batch embedding generation: %s", e)
            raise

    async def generate_single_embedding(self, text: str) -> Optional[List[float]]:
        """Generate the embedding for a single text.

        Returns:
            The embedding, or ``None`` when ``text`` is empty/whitespace-only
            or when embedding fails (the error is logged, not raised).
        """
        if not text or not text.strip():
            return None

        try:
            embeddings = await self.generate_embeddings([text])
            return embeddings[0] if embeddings else None
        except Exception as e:
            logger.error("Error generating single embedding: %s", e)
            return None

    def get_embedding_dimension(self) -> int:
        """Return the embedding dimension reported by the loaded model.

        Raises:
            RuntimeError: If no model is loaded.
        """
        if self.model is None:
            raise RuntimeError("Embedding model not loaded")

        return self.model.get_sentence_embedding_dimension()

    def compute_similarity(self, embedding1: List[float], embedding2: List[float]) -> float:
        """Compute the cosine similarity between two embeddings.

        Returns:
            The cosine similarity, or ``0.0`` when either vector has zero norm
            (the similarity is undefined there) or when the computation fails,
            e.g. on mismatched lengths (the error is logged).
        """
        try:
            vec1 = np.asarray(embedding1, dtype=float)
            vec2 = np.asarray(embedding2, dtype=float)

            denominator = np.linalg.norm(vec1) * np.linalg.norm(vec2)
            if denominator == 0:
                # Dividing would yield NaN with only a NumPy warning, which
                # the except-branch below would never see.
                return 0.0

            return float(np.dot(vec1, vec2) / denominator)
        except Exception as e:
            logger.error("Error computing similarity: %s", e)
            return 0.0

    def compute_similarities(self, query_embedding: List[float], embeddings: List[List[float]]) -> List[float]:
        """Compute cosine similarities between a query and many embeddings.

        Returns:
            One similarity per entry of ``embeddings``, in order. Entries with
            zero norm (or every entry, when the query has zero norm) score
            ``0.0``. On failure the error is logged and a list of ``0.0`` of
            the same length is returned.
        """
        if len(embeddings) == 0:
            return []

        try:
            query_vec = np.asarray(query_embedding, dtype=float)
            matrix = np.asarray(embeddings, dtype=float)

            dots = np.dot(matrix, query_vec)
            denominators = np.linalg.norm(matrix, axis=1) * np.linalg.norm(query_vec)

            # Divide only where the denominator is non-zero; the remaining
            # slots keep their initial 0.0 instead of becoming NaN/inf.
            similarities = np.zeros_like(dots, dtype=float)
            np.divide(dots, denominators, out=similarities, where=denominators > 0)

            return similarities.tolist()
        except Exception as e:
            logger.error("Error computing similarities: %s", e)
            return [0.0] * len(embeddings)

    async def embed_chunks(self, chunks: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Embed a list of chunks and attach the embeddings to copies of them.

        Each chunk's ``'content'`` value is embedded. Chunks whose content is
        missing, not a string, or blank are returned without an
        ``'embedding'`` key and a warning is logged for them. Embeddings are
        matched to chunks by original index, so a blank chunk never shifts the
        embeddings of the chunks after it.

        Args:
            chunks: Dictionaries that may carry a ``'content'`` string.

        Returns:
            Shallow copies of the input chunks, in order, with an
            ``'embedding'`` key added where an embedding was produced.

        Raises:
            Exception: Any embedding error is logged and re-raised.
        """
        if not chunks:
            return []

        try:
            # Remember which chunk each submitted text came from, because
            # generate_embeddings only returns vectors for embeddable texts.
            embeddable = [
                (index, chunk.get('content'))
                for index, chunk in enumerate(chunks)
                if self._is_embeddable(chunk.get('content'))
            ]

            embeddings = await self.generate_embeddings([text for _, text in embeddable])
            embedding_by_index = {
                index: embedding
                for (index, _), embedding in zip(embeddable, embeddings)
            }

            embedded_chunks = []
            for index, chunk in enumerate(chunks):
                chunk_copy = chunk.copy()
                if index in embedding_by_index:
                    chunk_copy['embedding'] = embedding_by_index[index]
                else:
                    logger.warning("No embedding generated for chunk %s", index)
                embedded_chunks.append(chunk_copy)

            return embedded_chunks
        except Exception as e:
            logger.error("Error embedding chunks: %s", e)
            raise

    def validate_embedding(self, embedding: List[float]) -> bool:
        """Validate that an embedding is properly formatted.

        An embedding is valid when it is a non-empty ``list`` whose length
        equals the model's embedding dimension and whose values are all
        finite (no NaN or infinity). Any error during the checks, including a
        missing model, yields ``False``.
        """
        try:
            if not isinstance(embedding, list) or not embedding:
                return False

            if len(embedding) != self.get_embedding_dimension():
                return False

            # isfinite rejects both NaN and +/-inf; non-numeric content makes
            # it raise, which the except-branch turns into False.
            return bool(np.isfinite(np.array(embedding)).all())
        except Exception:
            return False

    async def get_model_info(self) -> Dict[str, Any]:
        """Get information about the loaded model.

        Returns:
            A dict with ``model_name``, ``device``, ``embedding_dimension``,
            ``max_sequence_length`` (``'unknown'`` when the model has no
            ``max_seq_length`` attribute) and ``model_loaded``. If gathering
            the data fails (for example because no model is loaded), the error
            is logged and ``{"error": <message>}`` is returned instead.
        """
        try:
            return {
                "model_name": self.model_name,
                "device": self.device,
                "embedding_dimension": self.get_embedding_dimension(),
                "max_sequence_length": getattr(self.model, 'max_seq_length', 'unknown'),
                "model_loaded": self.model is not None,
            }
        except Exception as e:
            logger.error("Error getting model info: %s", e)
            return {"error": str(e)}