"""
SentenceTransformer embedding model implementation.
This module provides a direct implementation of the EmbeddingModel interface
for SentenceTransformer models, extracted from the shared_utils functionality
to make the embedder component self-contained.
Features:
- MPS acceleration for Apple Silicon
- Efficient model caching and reuse
- Configurable device selection
- Embedding normalization support
- Memory-efficient processing
"""
import torch
import numpy as np
from typing import List, Dict, Any, Optional
from sentence_transformers import SentenceTransformer
import logging
import os
from pathlib import Path
import sys
# Add project root for imports (guard against duplicate sys.path entries on re-import)
project_root = Path(__file__).parent.parent.parent.parent.parent
if str(project_root) not in sys.path:
    sys.path.append(str(project_root))
from ..base import EmbeddingModel, ConfigurableEmbedderComponent
logger = logging.getLogger(__name__)
class SentenceTransformerModel(EmbeddingModel, ConfigurableEmbedderComponent):
"""
Direct implementation of SentenceTransformer embedding model.
This class provides a self-contained implementation of the EmbeddingModel
interface using SentenceTransformers, with MPS acceleration support and
efficient model management.
    Configuration:
        {
            "model_name": "sentence-transformers/all-MiniLM-L6-v2",
            "device": "auto",            # or "mps", "cuda", "cpu"
            "normalize_embeddings": True,
            "cache_folder": None,        # or path to cache directory
            "trust_remote_code": False
        }
Performance Features:
- Apple Silicon MPS acceleration
- Model caching to avoid reloading
- Memory-efficient inference mode
- Configurable device selection
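
    Example (illustrative sketch; assumes the default all-MiniLM-L6-v2 model can be
    downloaded or is already cached locally):
        model = SentenceTransformerModel({"model_name": "sentence-transformers/all-MiniLM-L6-v2"})
        vectors = model.encode(["What is retrieval-augmented generation?"])
        # vectors is a float32 numpy array of shape (1, embedding_dim)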
"""
# Class-level model cache to avoid reloading
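    # (shared by all instances in this process; entries are keyed by "model_name:device")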
_model_cache: Dict[str, SentenceTransformer] = {}
def __init__(self, config: Dict[str, Any]):
"""
Initialize SentenceTransformer model.
Args:
config: Model configuration dictionary
"""
super().__init__(config)
self.model_name = config.get("model_name", "sentence-transformers/all-MiniLM-L6-v2")
self.device = self._determine_device(config.get("device", "auto"))
self.normalize_embeddings = config.get("normalize_embeddings", True)
self.cache_folder = config.get("cache_folder")
self.trust_remote_code = config.get("trust_remote_code", False)
# Load model
self._model = self._load_model()
self._embedding_dim = None
self._max_seq_length = None
logger.info(f"SentenceTransformerModel initialized: {self.model_name} on {self.device}")
def _validate_config(self) -> None:
"""
Validate model configuration.
Raises:
ValueError: If configuration is invalid
"""
required_keys = ["model_name"]
for key in required_keys:
if key not in self.config:
raise ValueError(f"Missing required configuration key: {key}")
# Validate device
device = self.config.get("device", "auto")
valid_devices = ["auto", "cpu", "cuda", "mps"]
if device not in valid_devices:
raise ValueError(f"Invalid device '{device}'. Must be one of: {valid_devices}")
# Validate model name
model_name = self.config["model_name"]
if not isinstance(model_name, str) or not model_name.strip():
raise ValueError("model_name must be a non-empty string")
def _determine_device(self, device_config: str) -> str:
"""
Determine the best device to use based on configuration and availability.
Args:
device_config: Device configuration ("auto", "mps", "cuda", "cpu")
Returns:
Device string to use
"""
if device_config == "auto":
# Auto-detect best available device
if torch.backends.mps.is_available() and torch.backends.mps.is_built():
return "mps"
elif torch.cuda.is_available():
return "cuda"
else:
return "cpu"
else:
# Use specified device (validation happens in _validate_config)
return device_config
def _load_model(self) -> SentenceTransformer:
"""
Load SentenceTransformer model with caching.
Returns:
Loaded SentenceTransformer model
Raises:
RuntimeError: If model loading fails
"""
cache_key = f"{self.model_name}:{self.device}"
# Check cache first
if cache_key in self._model_cache:
logger.debug(f"Using cached model: {cache_key}")
return self._model_cache[cache_key]
try:
# Load model with custom cache folder if specified
if self.cache_folder:
cache_dir = Path(self.cache_folder)
cache_dir.mkdir(parents=True, exist_ok=True)
model = SentenceTransformer(
self.model_name,
cache_folder=str(cache_dir),
trust_remote_code=self.trust_remote_code
)
else:
                # Use default cache behavior with multiple fallbacks for cloud environments
                cache_attempts = [
                    # Try default cache first
                    None,
                    # HuggingFace Spaces compatible path, if configured
                    os.environ.get('SENTENCE_TRANSFORMERS_HOME'),
                    '/tmp/.cache/sentence-transformers',
                    '/app/.cache/sentence-transformers',  # Common in containerized environments
                    './models/cache',  # Local fallback
                    '/tmp/models'  # Final fallback
                ]
                # Drop unset/duplicate entries while preserving attempt order
                cache_attempts = list(dict.fromkeys(cache_attempts))
model = None
last_error = None
for cache_dir in cache_attempts:
try:
if cache_dir:
# Ensure cache directory exists and is writable
os.makedirs(cache_dir, exist_ok=True)
# Test if directory is writable
test_file = os.path.join(cache_dir, '.write_test')
with open(test_file, 'w') as f:
f.write('test')
os.remove(test_file)
model = SentenceTransformer(
self.model_name,
cache_folder=cache_dir,
trust_remote_code=self.trust_remote_code
)
else:
model = SentenceTransformer(
self.model_name,
trust_remote_code=self.trust_remote_code
)
break # Success - exit loop
                    except Exception as e:
last_error = e
logger.warning(f"Cache attempt failed for {cache_dir}: {e}")
continue
if model is None:
raise RuntimeError(f"Failed to load model with any cache configuration. Last error: {last_error}")
# Move to device and set to eval mode
model = model.to(self.device)
model.eval()
# Cache the model
self._model_cache[cache_key] = model
logger.info(f"Loaded model {self.model_name} on device {self.device}")
return model
except Exception as e:
raise RuntimeError(f"Failed to load SentenceTransformer model '{self.model_name}': {e}") from e
def encode(self, texts: List[str]) -> np.ndarray:
"""
Encode texts to embeddings.
Args:
texts: List of text strings to embed
Returns:
numpy array of shape (len(texts), embedding_dim)
Raises:
ValueError: If texts list is empty
RuntimeError: If encoding fails
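
        Example (illustrative):
            vectors = model.encode(["first text", "second text"])
            # vectors.shape == (2, embedding_dim), dtype float32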
"""
if not texts:
raise ValueError("Cannot encode empty text list")
try:
with torch.no_grad():
embeddings = self._model.encode(
texts,
convert_to_numpy=True,
normalize_embeddings=self.normalize_embeddings,
batch_size=32, # Default batch size, will be overridden by BatchProcessor
show_progress_bar=False
).astype(np.float32)
# Cache embedding dimension on first use
if self._embedding_dim is None:
self._embedding_dim = embeddings.shape[1]
return embeddings
except Exception as e:
raise RuntimeError(f"Failed to encode texts: {e}") from e
def get_model_name(self) -> str:
"""
Return model identifier.
Returns:
String identifier for the embedding model
"""
return self.model_name
def get_embedding_dim(self) -> int:
"""
Return embedding dimension.
Returns:
Integer dimension of embeddings produced by this model
"""
        if self._embedding_dim is None:
            # Prefer the dimension reported by the model; fall back to encoding a dummy text
            reported_dim = self._model.get_sentence_embedding_dimension()
            if reported_dim is None:
                reported_dim = self.encode(["test"]).shape[1]
            self._embedding_dim = int(reported_dim)
return self._embedding_dim
def get_max_sequence_length(self) -> int:
"""
Return maximum sequence length supported by the model.
Returns:
Maximum number of tokens the model can process
"""
if self._max_seq_length is None:
try:
# Get max sequence length from model
self._max_seq_length = self._model.get_max_seq_length()
except AttributeError:
# Fallback for models without this method
self._max_seq_length = 512 # Common default
logger.warning(f"Could not determine max sequence length for {self.model_name}, using default: 512")
return self._max_seq_length
def is_available(self) -> bool:
"""
Check if the model is available and ready for use.
Returns:
True if model is loaded and ready, False otherwise
"""
try:
return self._model is not None and hasattr(self._model, 'encode')
except Exception:
return False
def get_device_info(self) -> Dict[str, Any]:
"""
Get information about the device being used.
Returns:
Dictionary with device information
"""
device_info = {
"device": self.device,
"device_available": True
}
if self.device == "mps":
device_info.update({
"mps_available": torch.backends.mps.is_available(),
"mps_built": torch.backends.mps.is_built()
})
elif self.device == "cuda":
device_info.update({
"cuda_available": torch.cuda.is_available(),
"cuda_device_count": torch.cuda.device_count() if torch.cuda.is_available() else 0
})
return device_info
def get_model_info(self) -> Dict[str, Any]:
"""
Get comprehensive model information.
Returns:
Dictionary with model configuration and status
"""
return {
"model_name": self.model_name,
"embedding_dim": self.get_embedding_dim(),
"max_sequence_length": self.get_max_sequence_length(),
"device": self.device,
"normalize_embeddings": self.normalize_embeddings,
"is_available": self.is_available(),
"cache_folder": self.cache_folder,
"trust_remote_code": self.trust_remote_code,
"component_type": "sentence_transformer_model"
}
@classmethod
def clear_model_cache(cls) -> None:
"""Clear the model cache to free memory."""
cls._model_cache.clear()
logger.info("SentenceTransformer model cache cleared")
@classmethod
def get_cache_info(cls) -> Dict[str, Any]:
"""
Get information about the model cache.
Returns:
Dictionary with cache statistics
"""
return {
"cached_models": list(cls._model_cache.keys()),
"cache_size": len(cls._model_cache)
}
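

if __name__ == "__main__":
    # Minimal smoke test: an illustrative sketch, not part of the component API.
    # It assumes network access (or an already-populated cache) to load the default
    # model, then prints the embedding shape, model info, and cache statistics.
    logging.basicConfig(level=logging.INFO)

    demo_model = SentenceTransformerModel({
        "model_name": "sentence-transformers/all-MiniLM-L6-v2",
        "device": "auto",
        "normalize_embeddings": True,
    })
    demo_vectors = demo_model.encode(["RAG systems combine retrieval with generation."])
    print("Embedding shape:", demo_vectors.shape)
    print("Model info:", demo_model.get_model_info())
    print("Cache info:", SentenceTransformerModel.get_cache_info())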