"""
Cross-encoder re-ranking module for improving search result relevance.
"""

import time
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

from .error_handler import EmbeddingError, ResourceError
from .search_engine import SearchResult


class CrossEncoderReranker:
    """Cross-encoder model for re-ranking search results."""

    def __init__(self, config: Dict[str, Any]):
        self.config = config
        self.reranker_config = config.get("models", {}).get("reranker", {})

        self.model_name = self.reranker_config.get("name", "cross-encoder/ms-marco-MiniLM-L-6-v2")
        self.max_seq_length = self.reranker_config.get("max_seq_length", 512)
        self.batch_size = self.reranker_config.get("batch_size", 16)
        self.enabled = self.reranker_config.get("enabled", True)
        self.device = self._get_device()

        # Model and tokenizer are loaded lazily on first use.
        self.tokenizer: Optional[AutoTokenizer] = None
        self.model: Optional[AutoModelForSequenceClassification] = None
        self._model_loaded = False

        self.stats = {
            "reranking_operations": 0,
            "total_pairs_scored": 0,
            "total_time": 0.0,
            "model_load_time": 0.0,
        }
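
    # Expected configuration shape (illustrative; the keys mirror the .get()
    # calls above, and the values shown are the defaults):
    #
    #   {
    #       "models": {
    #           "reranker": {
    #               "name": "cross-encoder/ms-marco-MiniLM-L-6-v2",
    #               "max_seq_length": 512,
    #               "batch_size": 16,
    #               "enabled": True,
    #           }
    #       }
    #   }
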
    def _get_device(self) -> str:
        """Determine the best available device for computation."""
        if torch.cuda.is_available():
            return "cuda"
        elif torch.backends.mps.is_available():
            return "mps"
        else:
            return "cpu"

    def _load_model(self) -> None:
        """Lazily load the cross-encoder model and tokenizer."""
        if not self.enabled:
            print("Re-ranker is disabled in configuration")
            return

        if self._model_loaded:
            return

        try:
            print(f"Loading re-ranker model: {self.model_name}")
            start_time = time.time()

            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            self.model = AutoModelForSequenceClassification.from_pretrained(self.model_name)

            if self.device != "cpu":
                self.model = self.model.to(self.device)

            # Inference only: disable dropout and other training-time behavior.
            self.model.eval()

            load_time = time.time() - start_time
            self.stats["model_load_time"] = load_time

            print(f"Re-ranker model loaded in {load_time:.2f}s on device: {self.device}")
            self._model_loaded = True

        except Exception as e:
            print(f"Failed to load re-ranker model: {e}")
            self.enabled = False
            raise EmbeddingError(f"Failed to load re-ranker model: {str(e)}") from e

    def rerank(
        self,
        query: str,
        results: List[SearchResult],
        top_k: Optional[int] = None
    ) -> List[SearchResult]:
        """
        Re-rank search results using cross-encoder scores.

        Args:
            query: Original search query
            results: List of search results to re-rank
            top_k: Number of top results to return after re-ranking

        Returns:
            Re-ranked list of SearchResult objects
        """
        if not self.enabled or not results:
            return results

        if not query or not query.strip():
            return results

        start_time = time.time()

        try:
            self._load_model()

            if not self._model_loaded:
                print("Re-ranker model not available, returning original results")
                return results

            # Build (query, document) pairs, truncating long documents so the
            # tokenizer spends no time on text beyond the model's window.
            pairs = []
            for result in results:
                content = result.content
                if len(content) > 500:
                    content = content[:500] + "..."
                pairs.append((query.strip(), content))

            scores = self._score_pairs(pairs)

            # Min-max normalize scores into [0, 1] so they are comparable
            # across queries; if all scores are equal, fall back to 0.5.
            if len(scores) > 0:
                min_score = min(scores)
                max_score = max(scores)
                if max_score > min_score:
                    normalized_scores = [(score - min_score) / (max_score - min_score) for score in scores]
                else:
                    normalized_scores = [0.5] * len(scores)
            else:
                normalized_scores = scores

            # Rebuild the results with the cross-encoder score as final_score.
            reranked_results = []
            for i, result in enumerate(results):
                reranked_result = SearchResult(
                    chunk_id=result.chunk_id,
                    content=result.content,
                    metadata=result.metadata,
                    vector_score=result.vector_score,
                    bm25_score=result.bm25_score,
                    final_score=float(normalized_scores[i]),
                    rank=0
                )
                reranked_results.append(reranked_result)

            reranked_results.sort(key=lambda x: x.final_score, reverse=True)

            for i, result in enumerate(reranked_results):
                result.rank = i + 1

            if top_k is not None:
                reranked_results = reranked_results[:top_k]

            reranking_time = time.time() - start_time
            self.stats["reranking_operations"] += 1
            self.stats["total_pairs_scored"] += len(pairs)
            self.stats["total_time"] += reranking_time

            return reranked_results

        except Exception as e:
            print(f"Re-ranking failed, returning original results: {e}")
            return results
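
    # Illustrative usage (assumes `config` and a list of SearchResult objects
    # from an upstream search stage; the query string is arbitrary):
    #
    #   reranker = CrossEncoderReranker(config)
    #   top = reranker.rerank("what is a cross-encoder?", results, top_k=10)
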
    def _score_pairs(self, pairs: List[Tuple[str, str]]) -> np.ndarray:
        """Score query-document pairs using the cross-encoder."""
        if not pairs:
            return np.array([])

        scores = []

        # Score in fixed-size batches to bound peak memory usage.
        for i in range(0, len(pairs), self.batch_size):
            batch_pairs = pairs[i:i + self.batch_size]
            batch_scores = self._score_batch(batch_pairs)
            scores.extend(batch_scores)

        return np.array(scores)

    def _score_batch(self, batch_pairs: List[Tuple[str, str]]) -> List[float]:
        """Score a batch of query-document pairs."""
        try:
            queries = [pair[0] for pair in batch_pairs]
            documents = [pair[1] for pair in batch_pairs]

            inputs = self.tokenizer(
                queries,
                documents,
                padding=True,
                truncation=True,
                max_length=self.max_seq_length,
                return_tensors="pt"
            )

            if self.device != "cpu":
                inputs = {k: v.to(self.device) for k, v in inputs.items()}

            with torch.no_grad():
                outputs = self.model(**inputs)

            logits = outputs.logits

            if logits.size(-1) == 1:
                # Single-logit regression head: the logit is the relevance score.
                scores = logits.squeeze(-1).cpu().numpy()
            else:
                # Classification head: take the probability of the "relevant" class.
                probs = torch.softmax(logits, dim=-1)
                scores = probs[:, 1].cpu().numpy()

            return scores.tolist()

        except torch.cuda.OutOfMemoryError as e:
            raise ResourceError(
                "GPU memory insufficient for re-ranking. "
                "Try reducing batch_size or using CPU."
            ) from e
        except Exception as e:
            raise EmbeddingError(f"Failed to score batch: {str(e)}") from e
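
    # Note: MS MARCO cross-encoders such as the default model emit a single
    # relevance logit per pair (shape [batch, 1]), which the branch above
    # squeezes; two-label classification heads (shape [batch, 2]) instead
    # contribute the softmax probability of the "relevant" class (index 1).
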
    def score_single_pair(self, query: str, document: str) -> float:
        """Score a single query-document pair."""
        if not self.enabled or not query.strip() or not document.strip():
            return 0.0

        try:
            # Ensure the model is loaded; _score_pairs assumes it already is.
            self._load_model()
            if not self._model_loaded:
                return 0.0
            scores = self._score_pairs([(query, document)])
            return float(scores[0]) if len(scores) > 0 else 0.0
        except Exception as e:
            print(f"Failed to score single pair: {e}")
            return 0.0

    def warmup(self) -> None:
        """Warm up the re-ranker with a sample query-document pair."""
        if not self.enabled:
            return

        self._load_model()

        if not self._model_loaded:
            return

        sample_pairs = [("sample query", "sample document text")]
        try:
            self._score_pairs(sample_pairs)
            print("Re-ranker model warmed up successfully")
        except Exception as e:
            print(f"Re-ranker warmup failed: {e}")

    def get_stats(self) -> Dict[str, Any]:
        """Get re-ranker performance statistics."""
        stats = self.stats.copy()

        if stats["reranking_operations"] > 0:
            stats["avg_time_per_operation"] = stats["total_time"] / stats["reranking_operations"]
            stats["avg_pairs_per_operation"] = stats["total_pairs_scored"] / stats["reranking_operations"]
        else:
            stats["avg_time_per_operation"] = 0
            stats["avg_pairs_per_operation"] = 0

        if stats["total_pairs_scored"] > 0:
            stats["avg_time_per_pair"] = stats["total_time"] / stats["total_pairs_scored"]
        else:
            stats["avg_time_per_pair"] = 0

        stats["model_loaded"] = self._model_loaded
        stats["enabled"] = self.enabled
        stats["device"] = self.device
        stats["model_name"] = self.model_name

        return stats

    def unload_model(self) -> None:
        """Unload the model and tokenizer to free memory."""
        if self.model is not None:
            del self.model
            self.model = None
        if self.tokenizer is not None:
            del self.tokenizer
            self.tokenizer = None
        self._model_loaded = False

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        print("Re-ranker model unloaded")

    def is_available(self) -> bool:
        """Check if the re-ranker is enabled and its model is loaded."""
        return self.enabled and self._model_loaded


class RerankingPipeline:
    """Pipeline for applying re-ranking with fallback options."""

    def __init__(self, config: Dict[str, Any]):
        self.config = config
        self.search_config = config.get("search", {})

        self.rerank_top_k = self.search_config.get("rerank_top_k", 50)
        self.final_top_k = self.search_config.get("final_top_k", 10)
        self.enable_reranking = config.get("models", {}).get("reranker", {}).get("enabled", True)

        self.reranker = CrossEncoderReranker(config) if self.enable_reranking else None

        self.stats = {
            "pipeline_calls": 0,
            "reranking_applied": 0,
            "fallback_used": 0,
            "avg_input_results": 0,
            "avg_output_results": 0
        }

    def process(
        self,
        query: str,
        results: List[SearchResult],
        apply_reranking: bool = True
    ) -> List[SearchResult]:
        """
        Process search results through the re-ranking pipeline.

        Args:
            query: Original search query
            results: Search results to process
            apply_reranking: Whether to apply re-ranking

        Returns:
            Processed results (re-ranked if enabled and successful)
        """
        if not results:
            return results

        input_count = len(results)
        self.stats["pipeline_calls"] += 1
        # Update the running average of input sizes.
        self.stats["avg_input_results"] = (
            (self.stats["avg_input_results"] * (self.stats["pipeline_calls"] - 1) + input_count)
            / self.stats["pipeline_calls"]
        )

        if (apply_reranking and
                self.enable_reranking and
                self.reranker is not None and
                len(results) > 1):

            try:
                # Only re-rank the top candidates; scoring every result with a
                # cross-encoder would be needlessly expensive.
                candidates = results[:self.rerank_top_k]

                reranked_results = self.reranker.rerank(query, candidates)

                # Append any results beyond the re-ranked window, preserving
                # their original order and renumbering their ranks.
                if len(results) > self.rerank_top_k:
                    remaining_results = results[self.rerank_top_k:]
                    for i, result in enumerate(remaining_results):
                        result.rank = len(reranked_results) + i + 1
                    final_results = reranked_results + remaining_results
                else:
                    final_results = reranked_results

                self.stats["reranking_applied"] += 1

            except Exception as e:
                print(f"Re-ranking failed, using original results: {e}")
                final_results = results
                self.stats["fallback_used"] += 1
        else:
            final_results = results

        final_results = final_results[:self.final_top_k]

        output_count = len(final_results)
        self.stats["avg_output_results"] = (
            (self.stats["avg_output_results"] * (self.stats["pipeline_calls"] - 1) + output_count)
            / self.stats["pipeline_calls"]
        )

        return final_results
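
    # Illustrative use of the pipeline (`raw_results` is hypothetical and
    # would come from an upstream hybrid search stage):
    #
    #   pipeline = RerankingPipeline(config)
    #   pipeline.warmup()
    #   final = pipeline.process("what is attention?", raw_results)
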
    def get_stats(self) -> Dict[str, Any]:
        """Get pipeline statistics."""
        stats = self.stats.copy()

        if self.reranker:
            stats["reranker_stats"] = self.reranker.get_stats()

        stats["reranking_enabled"] = self.enable_reranking
        stats["rerank_top_k"] = self.rerank_top_k
        stats["final_top_k"] = self.final_top_k

        if stats["pipeline_calls"] > 0:
            stats["reranking_success_rate"] = stats["reranking_applied"] / stats["pipeline_calls"]
            stats["fallback_rate"] = stats["fallback_used"] / stats["pipeline_calls"]
        else:
            stats["reranking_success_rate"] = 0
            stats["fallback_rate"] = 0

        return stats

    def warmup(self) -> None:
        """Warm up the re-ranking pipeline."""
        if self.reranker:
            self.reranker.warmup()

    def unload_models(self) -> None:
        """Unload re-ranker models to free memory."""
        if self.reranker:
            self.reranker.unload_model()
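

# Minimal smoke test (a sketch, not part of the public API). The config keys
# mirror the .get() calls above; the SearchResult fields mirror the constructor
# call in rerank(). Because this module uses relative imports, run it as a
# module from the package root, e.g. `python -m <package>.<this_module>`.
if __name__ == "__main__":
    demo_config = {
        "models": {"reranker": {"batch_size": 8, "enabled": True}},
        "search": {"rerank_top_k": 20, "final_top_k": 5},
    }
    pipeline = RerankingPipeline(demo_config)
    demo_results = [
        SearchResult(
            chunk_id=str(i),
            content=text,
            metadata={},
            vector_score=0.0,
            bm25_score=0.0,
            final_score=0.0,
            rank=i + 1,
        )
        for i, text in enumerate([
            "Cross-encoders jointly encode the query and document.",
            "Unrelated text about cooking pasta.",
        ])
    ]
    for r in pipeline.process("how do cross-encoders score relevance?", demo_results):
        print(r.rank, f"{r.final_score:.3f}", r.content[:60])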