""" Semantic Reranker implementation for Modular Retriever Architecture. This module provides a direct implementation of cross-encoder based reranking to improve retrieval quality by reordering candidates based on query relevance. """ import logging from typing import List, Dict, Any, Tuple, Optional import numpy as np from src.core.interfaces import Document from .base import Reranker logger = logging.getLogger(__name__) class SemanticReranker(Reranker): """ Semantic reranker using cross-encoder models. This is a direct implementation that uses cross-encoder models to rerank retrieval results based on query-document relevance scores. No external API dependencies are required. Features: - Cross-encoder model support (sentence-transformers) - Configurable reranking threshold - Batch processing for efficiency - Optional reranking (can be disabled) - Performance monitoring Example: config = { "model": "cross-encoder/ms-marco-MiniLM-L-6-v2", "enabled": True, "batch_size": 32, "top_k": 10, "score_threshold": 0.0 } reranker = SemanticReranker(config) results = reranker.rerank(query, documents, initial_scores) """ def __init__(self, config: Dict[str, Any]): """ Initialize semantic reranker. Args: config: Configuration dictionary with: - model: Cross-encoder model name (default: "cross-encoder/ms-marco-MiniLM-L-6-v2") - enabled: Whether reranking is enabled (default: True) - batch_size: Batch size for model inference (default: 32) - top_k: Maximum number of documents to rerank (default: 10) - score_threshold: Minimum score threshold for reranking (default: 0.0) - device: Device to run model on (default: "auto") """ self.config = config self.model_name = config.get("model", "cross-encoder/ms-marco-MiniLM-L-6-v2") self.enabled = config.get("enabled", True) self.batch_size = config.get("batch_size", 32) self.top_k = config.get("top_k", 10) self.score_threshold = config.get("score_threshold", 0.0) self.device = config.get("device", "auto") # Initialize model lazily self.model = None self._model_loaded = False self._initialized = True # Always initialized (model loading is lazy) # Validation if self.batch_size <= 0: raise ValueError("batch_size must be positive") if self.top_k <= 0: raise ValueError("top_k must be positive") logger.info(f"SemanticReranker initialized with model={self.model_name}, enabled={self.enabled}") def _load_model(self) -> None: """Load the cross-encoder model lazily.""" if self._model_loaded: return if not self.enabled: logger.info("Reranker disabled, skipping model loading") self._model_loaded = True return try: from sentence_transformers import CrossEncoder logger.info(f"Loading cross-encoder model: {self.model_name}") self.model = CrossEncoder(self.model_name, device=self.device) self._model_loaded = True logger.info("Cross-encoder model loaded successfully") except ImportError: logger.warning("sentence-transformers not available, disabling reranker") self.enabled = False self._model_loaded = True except Exception as e: logger.error(f"Failed to load cross-encoder model: {e}") self.enabled = False self._model_loaded = True def rerank( self, query: str, documents: List[Document], initial_scores: List[float] ) -> List[Tuple[int, float]]: """ Rerank documents based on query relevance. 
    def rerank(
        self,
        query: str,
        documents: List[Document],
        initial_scores: List[float]
    ) -> List[Tuple[int, float]]:
        """
        Rerank documents based on query relevance.

        Args:
            query: The search query
            documents: List of candidate documents
            initial_scores: Initial relevance scores from fusion

        Returns:
            List of (document_index, reranked_score) tuples sorted by score
        """
        if not self.enabled:
            # Return the original ranking if reranking is disabled
            return [(i, score) for i, score in enumerate(initial_scores)]

        if not documents or not query.strip():
            return []

        # Load the model if not already loaded
        self._load_model()

        if not self._model_loaded or self.model is None:
            # Fall back to the original ranking
            return [(i, score) for i, score in enumerate(initial_scores)]

        # Limit to top_k documents for efficiency
        num_docs = min(len(documents), self.top_k)

        # Create query-document pairs for the cross-encoder
        query_doc_pairs = []
        doc_indices = []
        for i in range(num_docs):
            doc_text = documents[i].content
            # Truncate very long documents for efficiency
            if len(doc_text) > 2000:
                doc_text = doc_text[:2000] + "..."
            query_doc_pairs.append([query, doc_text])
            doc_indices.append(i)

        try:
            # Get cross-encoder scores in batches
            cross_encoder_scores = []
            for i in range(0, len(query_doc_pairs), self.batch_size):
                batch = query_doc_pairs[i:i + self.batch_size]
                batch_scores = self.model.predict(batch)
                cross_encoder_scores.extend(batch_scores)

            # Keep reranked documents that clear the score threshold
            reranked_results = []
            for i, score in enumerate(cross_encoder_scores):
                doc_idx = doc_indices[i]
                if score >= self.score_threshold:
                    reranked_results.append((doc_idx, float(score)))

            # Append documents beyond top_k with their initial fusion scores.
            # Note that cross-encoder and fusion scores are on different
            # scales, so the two groups are not strictly comparable when
            # sorted together below.
            for i in range(num_docs, len(documents)):
                if i < len(initial_scores):
                    reranked_results.append((i, initial_scores[i]))

            # Sort by score (descending)
            reranked_results.sort(key=lambda x: x[1], reverse=True)

            return reranked_results

        except Exception as e:
            logger.error(f"Reranking failed: {e}, falling back to original ranking")
            return [(i, score) for i, score in enumerate(initial_scores)]

    def is_enabled(self) -> bool:
        """
        Check if reranking is enabled.

        Returns:
            True if reranking should be performed
        """
        return self.enabled

    def get_reranker_info(self) -> Dict[str, Any]:
        """
        Get information about the reranker.

        Returns:
            Dictionary with reranker configuration and state
        """
        info = {
            "model": self.model_name,
            "enabled": self.enabled,
            "batch_size": self.batch_size,
            "top_k": self.top_k,
            "score_threshold": self.score_threshold,
            "device": self.device,
            "model_loaded": self._model_loaded
        }

        if self.model is not None and hasattr(self.model, "device"):
            info["model_device"] = str(self.model.device)

        return info

    def enable(self) -> None:
        """Enable reranking."""
        self.enabled = True
        logger.info("Reranker enabled")

    def disable(self) -> None:
        """Disable reranking."""
        self.enabled = False
        logger.info("Reranker disabled")

    def set_top_k(self, top_k: int) -> None:
        """
        Set the maximum number of documents to rerank.

        Args:
            top_k: Maximum number of documents to rerank
        """
        if top_k <= 0:
            raise ValueError("top_k must be positive")
        self.top_k = top_k
        logger.info(f"Reranker top_k set to {top_k}")

    def set_score_threshold(self, threshold: float) -> None:
        """
        Set the minimum score threshold for reranking.

        Args:
            threshold: Minimum score threshold
        """
        self.score_threshold = threshold
        logger.info(f"Reranker score threshold set to {threshold}")

    def predict_scores(self, query: str, documents: List[Document]) -> List[float]:
        """
        Get cross-encoder scores for query-document pairs.
        Args:
            query: The search query
            documents: List of documents

        Returns:
            List of relevance scores
        """
        if not self.enabled:
            return [0.0] * len(documents)

        self._load_model()

        if not self._model_loaded or self.model is None:
            return [0.0] * len(documents)

        # Create query-document pairs
        query_doc_pairs = []
        for doc in documents:
            doc_text = doc.content
            # Truncate very long documents for efficiency
            if len(doc_text) > 2000:
                doc_text = doc_text[:2000] + "..."
            query_doc_pairs.append([query, doc_text])

        try:
            # Get scores in batches
            scores = []
            for i in range(0, len(query_doc_pairs), self.batch_size):
                batch = query_doc_pairs[i:i + self.batch_size]
                batch_scores = self.model.predict(batch)
                scores.extend(batch_scores)

            return [float(score) for score in scores]

        except Exception as e:
            logger.error(f"Score prediction failed: {e}")
            return [0.0] * len(documents)

    def get_model_info(self) -> Dict[str, Any]:
        """
        Get information about the loaded model.

        Returns:
            Dictionary with model information
        """
        if not self._model_loaded or self.model is None:
            return {"status": "not_loaded"}

        info = {
            "status": "loaded",
            "model_name": self.model_name,
            "device": str(self.model.device) if hasattr(self.model, "device") else "unknown"
        }

        # Try to collect model-specific details
        try:
            if hasattr(self.model, "model"):
                info["model_type"] = type(self.model.model).__name__
            if hasattr(self.model, "tokenizer"):
                info["tokenizer_type"] = type(self.model.tokenizer).__name__
        except Exception:
            pass

        return info
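
# --- Usage sketch (illustrative, not part of the library API) ---------------
# A minimal, hedged example of running the reranker end to end. The _Doc
# stand-in below only mimics the .content attribute that rerank() reads; in
# real use, pass src.core.interfaces.Document instances. The query, documents,
# and fusion scores are illustrative values, not project requirements. Because
# this module uses a relative import, run it as a package module (python -m
# <package path to this module>) rather than as a bare script.
if __name__ == "__main__":
    from dataclasses import dataclass

    @dataclass
    class _Doc:
        content: str

    reranker = SemanticReranker({
        "model": "cross-encoder/ms-marco-MiniLM-L-6-v2",
        "score_threshold": float("-inf"),  # keep every reranked document
    })

    docs = [
        _Doc("Cross-encoders jointly encode a query-document pair and emit a relevance score."),
        _Doc("Bi-encoders embed the query and document independently and compare the vectors."),
    ]
    initial_scores = [0.42, 0.40]  # e.g. fusion scores from a hybrid retriever

    for doc_idx, score in reranker.rerank(
        "how does a cross-encoder score relevance?", docs, initial_scores
    ):
        print(f"doc {doc_idx}: {score:.4f}")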