Spaces:
Running
Running
""" | |
Sentence Transformer embedder adapter for the modular RAG system. | |
This module provides an adapter that wraps the existing embedding generation | |
functionality to conform to the Embedder interface, enabling it to be used | |
in the modular architecture while preserving all existing functionality. | |
""" | |
import sys | |
from pathlib import Path | |
from typing import List, Dict, Any, Optional, TYPE_CHECKING | |
# Add project root to path for imports | |
project_root = Path(__file__).parent.parent.parent.parent.parent | |
sys.path.append(str(project_root)) | |
from src.core.interfaces import Embedder, HealthStatus | |
from shared_utils.embeddings.generator import generate_embeddings | |
if TYPE_CHECKING: | |
from src.core.platform_orchestrator import PlatformOrchestrator | |
class SentenceTransformerEmbedder(Embedder):
    """
    Adapter for the existing sentence transformer embedding generator.

    This class wraps the ``generate_embeddings`` function so it can be used
    through the ``Embedder`` interface, delegating caching, batching and
    device selection to that function.

    Features (as described by the original implementation):
    - Content-based caching for performance
    - Apple Silicon MPS acceleration
    - Batch processing for efficiency
    - 384-dimensional embeddings with the default model
    - 100+ texts/second on M4-Pro

    Example:
        embedder = SentenceTransformerEmbedder(
            model_name="sentence-transformers/multi-qa-MiniLM-L6-cos-v1",
            use_mps=True
        )
        embeddings = embedder.embed(["Hello world", "How are you?"])
    """

    def __init__(
        self,
        model_name: str = "sentence-transformers/multi-qa-MiniLM-L6-cos-v1",
        batch_size: int = 32,
        use_mps: bool = True
    ):
        """
        Initialize the sentence transformer embedder.

        Args:
            model_name: SentenceTransformer model identifier
                (default: multi-qa-MiniLM-L6-cos-v1)
            batch_size: Processing batch size for efficiency (default: 32)
            use_mps: Use Apple Silicon acceleration if available (default: True)
        """
        self.model_name = model_name
        self.batch_size = batch_size
        self.use_mps = use_mps
        # Embedding dimension, learned lazily from the first successful embed().
        self._embedding_dim: Optional[int] = None

        # Platform services (initialized via initialize_services)
        self.platform: Optional['PlatformOrchestrator'] = None

    def embed(self, texts: List[str]) -> List[List[float]]:
        """
        Generate embeddings for a list of texts.

        Delegates to ``generate_embeddings`` with this instance's model,
        batch size and MPS settings.

        Args:
            texts: List of text strings to embed

        Returns:
            List of embedding vectors, where each vector is a list of floats

        Raises:
            ValueError: If texts list is empty
            RuntimeError: If embedding generation (or conversion of its
                result) fails; the original exception is chained as the cause.
        """
        if not texts:
            raise ValueError("Cannot generate embeddings for empty text list")

        try:
            embeddings_array = generate_embeddings(
                texts=texts,
                model_name=self.model_name,
                batch_size=self.batch_size,
                use_mps=self.use_mps,
            )
            # The generator's result is expected to expose .tolist()
            # (presumably a 2-D numpy array) — convert to plain Python lists.
            embeddings_list = embeddings_array.tolist()

            # Remember the dimension so embedding_dim() need not re-embed.
            # Kept inside the try so any failure here still surfaces as
            # RuntimeError, matching the documented contract.
            if self._embedding_dim is None and embeddings_list:
                self._embedding_dim = len(embeddings_list[0])

            return embeddings_list
        except Exception as e:
            raise RuntimeError(f"Failed to generate embeddings: {e}") from e

    def embedding_dim(self) -> int:
        """
        Get the embedding dimension.

        Returns:
            Integer dimension of embeddings (typically 384 for
            multi-qa-MiniLM-L6-cos-v1)

        Note:
            If no embeddings have been generated yet, a one-text probe
            embedding is generated to determine the dimension (which also
            caches it for later calls).
        """
        if self._embedding_dim is not None:
            return self._embedding_dim

        probe = self.embed(["test"])
        return len(probe[0])

    def get_model_info(self) -> dict:
        """
        Get information about the current model configuration.

        Does not trigger embedding generation: the dimension is reported as
        ``"unknown"`` until it has been learned from a previous embed() call.

        Returns:
            Dictionary with model configuration details
        """
        return {
            "model_name": self.model_name,
            "batch_size": self.batch_size,
            "use_mps": self.use_mps,
            "embedding_dimension": self._embedding_dim or "unknown",
            "component_type": "sentence_transformer"
        }

    def supports_batching(self) -> bool:
        """
        Check if this embedder supports batch processing.

        Returns:
            True, as this implementation supports efficient batch processing
        """
        return True

    def get_cache_stats(self) -> dict:
        """
        Get statistics about the embedding cache.

        Note: Real statistics would require access to the cache held by the
        generator module; this returns static information about caching
        support only.

        Returns:
            Dictionary with cache information
        """
        return {
            "caching_enabled": True,
            "cache_type": "content_based",
            "note": "Cache statistics require access to global cache from generator module"
        }

    # ComponentBase interface implementation
    def initialize_services(self, platform: 'PlatformOrchestrator') -> None:
        """Initialize platform services for the component.

        Args:
            platform: PlatformOrchestrator instance providing services
        """
        self.platform = platform

    def get_health_status(self) -> HealthStatus:
        """Get the current health status of the component.

        Delegates to the platform when services are initialized; otherwise
        performs basic local configuration checks.

        Returns:
            HealthStatus object with component health information
        """
        if self.platform:
            return self.platform.check_component_health(self)

        # Fallback if platform services not initialized
        issues: List[str] = []

        if not self.model_name:
            issues.append("Model name not configured")
        # Guard the type first so a misconfigured value (e.g. None) is
        # reported as unhealthy instead of raising TypeError from "<=".
        if not isinstance(self.batch_size, int) or self.batch_size <= 0:
            issues.append("Invalid batch size")

        return HealthStatus(
            is_healthy=not issues,
            issues=issues,
            metrics={
                "model_name": self.model_name,
                "batch_size": self.batch_size,
                "use_mps": self.use_mps,
                "embedding_dim": self._embedding_dim
            },
            component_name=self.__class__.__name__
        )

    def get_metrics(self) -> Dict[str, Any]:
        """Get component-specific metrics.

        Uses the platform analytics service when available. If that service
        is absent or fails, locally known configuration metrics are returned
        instead (the failure is logged, not raised).

        Returns:
            Dictionary containing component metrics
        """
        if self.platform:
            try:
                component_metrics = self.platform.analytics_service.collect_component_metrics(self)
                return {
                    "component_name": component_metrics.component_name,
                    "component_type": component_metrics.component_type,
                    "success_count": component_metrics.success_count,
                    "error_count": component_metrics.error_count,
                    "resource_usage": component_metrics.resource_usage,
                    "performance_metrics": component_metrics.performance_metrics,
                    "timestamp": component_metrics.timestamp
                }
            except Exception:
                # Best-effort: keep falling back, but don't hide the failure.
                import logging
                logging.getLogger(__name__).warning(
                    "Platform metrics collection failed for %s; using fallback metrics",
                    self.__class__.__name__,
                    exc_info=True,
                )

        # Fallback metrics
        return {
            "model_name": self.model_name,
            "batch_size": self.batch_size,
            "use_mps": self.use_mps,
            "embedding_dimension": self._embedding_dim,
            "caching_enabled": True,
            "component_type": "sentence_transformer_embedder"
        }

    def get_capabilities(self) -> List[str]:
        """Get list of component capabilities.

        Returns:
            List of capability strings
        """
        # NOTE(review): "384_dimensional_embeddings" is hard-coded and only
        # holds for the default model; a different model_name may produce a
        # different dimension. Kept as-is for caller compatibility.
        return [
            "text_embedding",
            "batch_processing",
            "content_based_caching",
            "mps_acceleration",
            "sentence_transformer_models",
            "384_dimensional_embeddings"
        ]