""" | |
Cross-encoder re-ranking module for improving search result relevance. | |
""" | |
import time | |
from typing import Any, Dict, List, Optional, Tuple | |
import numpy as np | |
import torch | |
from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
from tqdm import tqdm | |
from .error_handler import EmbeddingError, ResourceError | |
from .search_engine import SearchResult | |
class CrossEncoderReranker:
    """Cross-encoder model for re-ranking search results."""

    def __init__(self, config: Dict[str, Any]):
        self.config = config
        self.reranker_config = config.get("models", {}).get("reranker", {})

        # Configuration
        self.model_name = self.reranker_config.get("name", "cross-encoder/ms-marco-MiniLM-L-6-v2")
        self.max_seq_length = self.reranker_config.get("max_seq_length", 512)
        self.batch_size = self.reranker_config.get("batch_size", 16)
        self.enabled = self.reranker_config.get("enabled", True)
        self.device = self._get_device()

        # Model components
        self.tokenizer: Optional[AutoTokenizer] = None
        self.model: Optional[AutoModelForSequenceClassification] = None
        self._model_loaded = False

        # Performance tracking
        self.stats = {
            "reranking_operations": 0,
            "total_pairs_scored": 0,
            "total_time": 0,
            "model_load_time": 0,
            "avg_batch_size": 0
        }
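
    # Expected configuration shape, sketched from the keys read in __init__
    # above (the values shown are the defaults; only deviations need to be
    # supplied):
    #
    #   config = {
    #       "models": {
    #           "reranker": {
    #               "name": "cross-encoder/ms-marco-MiniLM-L-6-v2",
    #               "max_seq_length": 512,
    #               "batch_size": 16,
    #               "enabled": True,
    #           }
    #       }
    #   }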

    def _get_device(self) -> str:
        """Determine the best available device for computation."""
        if torch.cuda.is_available():
            return "cuda"
        elif torch.backends.mps.is_available():  # Apple Silicon
            return "mps"
        else:
            return "cpu"

    def _load_model(self) -> None:
        """Lazily load the cross-encoder model."""
        if not self.enabled:
            print("Re-ranker is disabled in configuration")
            return
        if self._model_loaded:
            return

        try:
            print(f"Loading re-ranker model: {self.model_name}")
            start_time = time.time()

            # Load tokenizer and model
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            self.model = AutoModelForSequenceClassification.from_pretrained(self.model_name)

            # Move to device
            if self.device != "cpu":
                self.model = self.model.to(self.device)

            # Set to evaluation mode
            self.model.eval()

            load_time = time.time() - start_time
            self.stats["model_load_time"] = load_time
            print(f"Re-ranker model loaded in {load_time:.2f}s on device: {self.device}")
            self._model_loaded = True
        except Exception as e:
            print(f"Failed to load re-ranker model: {e}")
            self.enabled = False
            raise EmbeddingError(f"Failed to load re-ranker model: {str(e)}") from e

    def rerank(
        self,
        query: str,
        results: List[SearchResult],
        top_k: Optional[int] = None
    ) -> List[SearchResult]:
        """
        Re-rank search results using cross-encoder scores.

        Args:
            query: Original search query
            results: List of search results to re-rank
            top_k: Number of top results to return after re-ranking
                (None keeps all results)

        Returns:
            Re-ranked list of SearchResult objects
        """
        if not self.enabled or not results:
            return results
        if not query or not query.strip():
            return results

        start_time = time.time()
        try:
            # Load model if needed
            self._load_model()
            if not self._model_loaded:
                print("Re-ranker model not available, returning original results")
                return results

            # Prepare query-document pairs
            pairs = []
            for result in results:
                # Use content or a reasonable excerpt
                content = result.content
                if len(content) > 500:  # Truncate very long content
                    content = content[:500] + "..."
                pairs.append((query.strip(), content))

            # Score pairs
            scores = self._score_pairs(pairs)

            # Normalize re-ranker scores to the 0-1 range. `scores` is a numpy
            # array, so check its size rather than its truth value (truth-testing
            # a multi-element array raises a ValueError).
            if scores.size > 0:
                min_score = float(scores.min())
                max_score = float(scores.max())
                if max_score > min_score:
                    # Min-max normalize to 0-1
                    normalized_scores = [(score - min_score) / (max_score - min_score) for score in scores]
                else:
                    # All scores are identical; fall back to a neutral 0.5
                    normalized_scores = [0.5] * len(scores)
            else:
                normalized_scores = []

            # Rebuild results with re-ranking scores
            reranked_results = []
            for i, result in enumerate(results):
                # Create a new result carrying the normalized re-ranker score
                reranked_result = SearchResult(
                    chunk_id=result.chunk_id,
                    content=result.content,
                    metadata=result.metadata,
                    vector_score=result.vector_score,
                    bm25_score=result.bm25_score,
                    final_score=float(normalized_scores[i]),  # Normalized re-ranker score
                    rank=0  # Updated after sorting
                )
                reranked_results.append(reranked_result)

            # Sort by re-ranking scores
            reranked_results.sort(key=lambda x: x.final_score, reverse=True)

            # Update ranks
            for i, result in enumerate(reranked_results):
                result.rank = i + 1

            # Apply top_k limit
            if top_k is not None:
                reranked_results = reranked_results[:top_k]

            # Update statistics
            reranking_time = time.time() - start_time
            self.stats["reranking_operations"] += 1
            self.stats["total_pairs_scored"] += len(pairs)
            self.stats["total_time"] += reranking_time

            return reranked_results
        except Exception as e:
            print(f"Re-ranking failed, returning original results: {e}")
            return results

    def _score_pairs(self, pairs: List[Tuple[str, str]]) -> np.ndarray:
        """Score query-document pairs using the cross-encoder."""
        if not pairs:
            return np.array([])

        scores = []
        # Process in batches to bound memory use
        for i in range(0, len(pairs), self.batch_size):
            batch_pairs = pairs[i:i + self.batch_size]
            batch_scores = self._score_batch(batch_pairs)
            scores.extend(batch_scores)
        return np.array(scores)

    def _score_batch(self, batch_pairs: List[Tuple[str, str]]) -> List[float]:
        """Score a batch of query-document pairs."""
        try:
            # Prepare inputs
            queries = [pair[0] for pair in batch_pairs]
            documents = [pair[1] for pair in batch_pairs]

            # Tokenize query-document pairs jointly
            inputs = self.tokenizer(
                queries,
                documents,
                padding=True,
                truncation=True,
                max_length=self.max_seq_length,
                return_tensors="pt"
            )

            # Move to device
            if self.device != "cpu":
                inputs = {k: v.to(self.device) for k, v in inputs.items()}

            # Get predictions
            with torch.no_grad():
                outputs = self.model(**inputs)
            # Convert logits to scores. Single-logit heads (e.g. the MS MARCO
            # cross-encoders) emit a raw relevance score that is used directly;
            # two-class heads are converted to a positive-class probability
            # via softmax.
            logits = outputs.logits
            if logits.size(-1) == 1:
                # Single-output (regression-style) head
                scores = logits.squeeze(-1).cpu().numpy()
            else:
                # Classification head: positive-class probability
                probs = torch.softmax(logits, dim=-1)
                scores = probs[:, 1].cpu().numpy()
            return scores.tolist()
        except torch.cuda.OutOfMemoryError as e:
            raise ResourceError(
                "GPU memory insufficient for re-ranking. "
                "Try reducing batch_size or using CPU."
            ) from e
        except Exception as e:
            raise EmbeddingError(f"Failed to score batch: {str(e)}") from e

    def score_single_pair(self, query: str, document: str) -> float:
        """Score a single query-document pair, returning 0.0 on failure."""
        if not self.enabled or not query.strip() or not document.strip():
            return 0.0
        try:
            # Ensure the model is loaded before scoring
            self._load_model()
            scores = self._score_pairs([(query, document)])
            return float(scores[0]) if len(scores) > 0 else 0.0
        except Exception as e:
            print(f"Failed to score single pair: {e}")
            return 0.0

    def warmup(self) -> None:
        """Warm up the re-ranker with a sample query-document pair."""
        if not self.enabled:
            return
        self._load_model()
        if not self._model_loaded:
            return

        # Run a sample prediction to warm up
        sample_pairs = [("sample query", "sample document text")]
        try:
            self._score_pairs(sample_pairs)
            print("Re-ranker model warmed up successfully")
        except Exception as e:
            print(f"Re-ranker warmup failed: {e}")

    def get_stats(self) -> Dict[str, Any]:
        """Get re-ranker performance statistics."""
        stats = self.stats.copy()
        if stats["reranking_operations"] > 0:
            stats["avg_time_per_operation"] = stats["total_time"] / stats["reranking_operations"]
            stats["avg_pairs_per_operation"] = stats["total_pairs_scored"] / stats["reranking_operations"]
        else:
            stats["avg_time_per_operation"] = 0
            stats["avg_pairs_per_operation"] = 0

        if stats["total_pairs_scored"] > 0:
            stats["avg_time_per_pair"] = stats["total_time"] / stats["total_pairs_scored"]
        else:
            stats["avg_time_per_pair"] = 0

        stats["model_loaded"] = self._model_loaded
        stats["enabled"] = self.enabled
        stats["device"] = self.device
        stats["model_name"] = self.model_name
        return stats

    def unload_model(self) -> None:
        """Unload the model to free memory."""
        if self.model is not None:
            del self.model
            del self.tokenizer
            self.model = None
            self.tokenizer = None
            self._model_loaded = False

            # Clear GPU cache if using CUDA
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            print("Re-ranker model unloaded")

    def is_available(self) -> bool:
        """Check whether the re-ranker is enabled and its model is loaded."""
        return self.enabled and self._model_loaded
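
# Minimal usage sketch (illustrative only; assumes a config dict of the shape
# read by CrossEncoderReranker.__init__ above):
#
#   reranker = CrossEncoderReranker({"models": {"reranker": {"enabled": True}}})
#   reranker.warmup()  # optional: loads the model eagerly
#   score = reranker.score_single_pair("what is bm25?", "BM25 is a ranking function.")
#
# score_single_pair returns the raw (un-normalized) cross-encoder score,
# or 0.0 if scoring fails or the re-ranker is disabled.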

class RerankingPipeline:
    """Pipeline for applying re-ranking with fallback options."""

    def __init__(self, config: Dict[str, Any]):
        self.config = config
        self.search_config = config.get("search", {})

        # Re-ranking configuration
        self.rerank_top_k = self.search_config.get("rerank_top_k", 50)
        self.final_top_k = self.search_config.get("final_top_k", 10)
        self.enable_reranking = config.get("models", {}).get("reranker", {}).get("enabled", True)

        # Initialize re-ranker
        self.reranker = CrossEncoderReranker(config) if self.enable_reranking else None

        # Statistics
        self.stats = {
            "pipeline_calls": 0,
            "reranking_applied": 0,
            "fallback_used": 0,
            "avg_input_results": 0,
            "avg_output_results": 0
        }

    def process(
        self,
        query: str,
        results: List[SearchResult],
        apply_reranking: bool = True
    ) -> List[SearchResult]:
        """
        Process search results through the re-ranking pipeline.

        Args:
            query: Original search query
            results: Search results to process
            apply_reranking: Whether to apply re-ranking

        Returns:
            Processed results (re-ranked if enabled and successful)
        """
        if not results:
            return results

        start_input_count = len(results)
        self.stats["pipeline_calls"] += 1
        # Running average of input result counts
        self.stats["avg_input_results"] = (
            (self.stats["avg_input_results"] * (self.stats["pipeline_calls"] - 1) + start_input_count)
            / self.stats["pipeline_calls"]
        )

        # Apply re-ranking if enabled and requested
        if (apply_reranking and
                self.enable_reranking and
                self.reranker is not None and
                len(results) > 1):
            try:
                # Limit candidates for re-ranking to improve performance
                candidates = results[:self.rerank_top_k]

                # Apply re-ranking
                reranked_results = self.reranker.rerank(query, candidates)

                # Append any results beyond the re-ranking window
                if len(results) > self.rerank_top_k:
                    remaining_results = results[self.rerank_top_k:]
                    # Adjust ranks for remaining results
                    for i, result in enumerate(remaining_results):
                        result.rank = len(reranked_results) + i + 1
                    final_results = reranked_results + remaining_results
                else:
                    final_results = reranked_results
                self.stats["reranking_applied"] += 1
            except Exception as e:
                print(f"Re-ranking failed, using original results: {e}")
                final_results = results
                self.stats["fallback_used"] += 1
        else:
            final_results = results

        # Apply final top-k limit
        final_results = final_results[:self.final_top_k]

        # Running average of output result counts
        output_count = len(final_results)
        self.stats["avg_output_results"] = (
            (self.stats["avg_output_results"] * (self.stats["pipeline_calls"] - 1) + output_count)
            / self.stats["pipeline_calls"]
        )
        return final_results

    def get_stats(self) -> Dict[str, Any]:
        """Get pipeline statistics."""
        stats = self.stats.copy()
        if self.reranker:
            stats["reranker_stats"] = self.reranker.get_stats()
        stats["reranking_enabled"] = self.enable_reranking
        stats["rerank_top_k"] = self.rerank_top_k
        stats["final_top_k"] = self.final_top_k

        if stats["pipeline_calls"] > 0:
            stats["reranking_success_rate"] = stats["reranking_applied"] / stats["pipeline_calls"]
            stats["fallback_rate"] = stats["fallback_used"] / stats["pipeline_calls"]
        else:
            stats["reranking_success_rate"] = 0
            stats["fallback_rate"] = 0
        return stats

    def warmup(self) -> None:
        """Warm up the re-ranking pipeline."""
        if self.reranker:
            self.reranker.warmup()

    def unload_models(self) -> None:
        """Unload re-ranker models to free memory."""
        if self.reranker:
            self.reranker.unload_model()
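
# Smoke-test sketch. Assumes SearchResult accepts the keyword arguments used
# in CrossEncoderReranker.rerank above, and that the model can be downloaded.
# Run as a module (e.g. `python -m <package>.reranker`) so the relative
# imports resolve.
if __name__ == "__main__":
    pipeline = RerankingPipeline({
        "models": {"reranker": {"enabled": True, "batch_size": 8}},
        "search": {"rerank_top_k": 50, "final_top_k": 3},
    })
    candidates = [
        SearchResult(
            chunk_id=str(i), content=text, metadata={},
            vector_score=0.5, bm25_score=0.5, final_score=0.5, rank=i + 1,
        )
        for i, text in enumerate([
            "BM25 is a bag-of-words ranking function used in search engines.",
            "The weather today is sunny with a light breeze.",
            "Cross-encoders jointly encode query and document for scoring.",
        ])
    ]
    top = pipeline.process("how does bm25 ranking work?", candidates)
    for r in top:
        print(f"{r.rank}. {r.final_score:.3f} {r.content[:60]}")
    pipeline.unload_models()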