""" | |
Model Management for Neural Reranking. | |
This module provides sophisticated model management capabilities for neural | |
reranking including multi-backend support, lazy loading, caching, and | |
performance optimization for cross-encoder transformer models. | |
Simplified from reranking/cross_encoder_models.py for integration with | |
the enhanced neural reranker in the rerankers/ component. | |
""" | |
import logging | |
import time | |
import os | |
from typing import Dict, List, Optional, Any, Union | |
import threading | |
from dataclasses import dataclass | |
import numpy as np | |
logger = logging.getLogger(__name__) | |


@dataclass
class ModelConfig:
    """Configuration for an individual neural reranking model."""

    # Model identification
    name: str = "cross-encoder/ms-marco-MiniLM-L6-v2"
    backend: str = "sentence_transformers"  # "sentence_transformers", "tensorflow", "keras", "huggingface_api"
    model_type: str = "cross_encoder"  # "cross_encoder", "bi_encoder", "ensemble"

    # Model parameters
    max_length: int = 512
    device: str = "auto"  # "auto", "cpu", "cuda", "mps"
    cache_size: int = 1000

    # Performance settings
    batch_size: int = 16
    optimization_level: str = "balanced"  # "speed", "balanced", "quality"
    enable_quantization: bool = False

    # Model-specific settings
    trust_remote_code: bool = False
    local_files_only: bool = False
    revision: Optional[str] = None

    # HuggingFace API settings (for backend="huggingface_api")
    api_token: Optional[str] = None
    timeout: int = 30
    fallback_to_local: bool = True
    max_candidates: int = 100
    score_threshold: float = 0.0
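
# Illustrative sketch: constructing a config and handing it to a manager.
# The overrides below (CPU device, small batch) are example values for the
# sketch, not recommendations.
#
#     config = ModelConfig(device="cpu", batch_size=8)
#     manager = ModelManager("ms-marco-minilm", config)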


@dataclass
class ModelInfo:
    """Information about a loaded model."""

    name: str
    backend: str
    device: str
    loaded: bool = False
    load_time: float = 0.0
    inference_count: int = 0
    total_inference_time: float = 0.0
    last_used: float = 0.0
    memory_usage_mb: float = 0.0
    error_count: int = 0


class ModelManager:
    """
    Manager for an individual cross-encoder model.

    Handles model loading, caching, and lifecycle management for a single
    cross-encoder model, with support for multiple backends.
    """

    def __init__(self, name: str, config: ModelConfig):
        """
        Initialize the model manager.

        Args:
            name: Model identifier
            config: Model configuration
        """
        self.name = name
        self.config = config
        self.model = None
        self.tokenizer = None
        # Populated by _load_huggingface_api for the API backend
        self._api_token = None
        self.api_model_name = None
        self._lock = threading.Lock()

        self.info = ModelInfo(
            name=name,
            backend=config.backend,
            device=config.device,
        )

        logger.info(f"ModelManager created for {name} ({config.backend})")

    def load_model(self) -> bool:
        """
        Load the model if it is not already loaded.

        Returns:
            True if the model loaded successfully
        """
        with self._lock:
            if self.info.loaded:
                return True

            try:
                start_time = time.time()

                if self.config.backend == "sentence_transformers":
                    self._load_sentence_transformer()
                elif self.config.backend == "huggingface_api":
                    self._load_huggingface_api()
                else:
                    raise ValueError(f"Unsupported backend: {self.config.backend}")

                load_time = time.time() - start_time
                self.info.load_time = load_time
                self.info.loaded = True
                self.info.last_used = time.time()

                logger.info(f"Model {self.name} loaded in {load_time:.2f}s")
                return True

            except Exception as e:
                logger.error(f"Failed to load model {self.name}: {e}")
                self.info.error_count += 1
                return False

    def _resolve_device(self) -> str:
        """
        Resolve an 'auto' device to the appropriate device for the current system.

        Returns:
            Resolved device string ('cpu', 'cuda', 'mps', etc.)
        """
        if self.config.device != "auto":
            return self.config.device

        try:
            import torch

            # Prefer MPS (Metal Performance Shaders) on Apple Silicon
            if torch.backends.mps.is_available() and torch.backends.mps.is_built():
                return "mps"
            # Then CUDA
            if torch.cuda.is_available():
                return "cuda"
            # Fall back to CPU
            return "cpu"
        except ImportError:
            # If torch is not available, default to CPU
            return "cpu"

    def _load_sentence_transformer(self):
        """Load the model using sentence-transformers."""
        try:
            from sentence_transformers import CrossEncoder

            # Resolve device if set to 'auto'
            device = self._resolve_device()

            self.model = CrossEncoder(
                self.config.name,
                max_length=self.config.max_length,
                device=device,
                trust_remote_code=self.config.trust_remote_code,
            )
            logger.debug(f"Sentence transformer model loaded: {self.config.name} on device: {device}")
        except ImportError:
            raise ImportError("sentence-transformers library not available")
        except Exception as e:
            raise RuntimeError(f"Failed to load sentence transformer: {e}")

    def _load_huggingface_api(self):
        """Initialize the HuggingFace Inference API client."""
        try:
            from huggingface_hub import InferenceClient

            # Get the API token from config or environment
            api_token = (
                self.config.api_token
                or os.getenv("HF_TOKEN")
                or os.getenv("HUGGINGFACE_API_TOKEN")
                or os.getenv("HF_API_TOKEN")
            )
            if not api_token:
                raise ValueError("HuggingFace API token required for huggingface_api backend")

            # Create the inference client and keep the resolved token for the
            # raw HTTP requests in _predict_api (config.api_token may be None
            # when the token came from the environment).
            self.model = InferenceClient(token=api_token)
            self._api_token = api_token
            self.api_model_name = self.config.name

            logger.debug(f"HuggingFace API client initialized for model: {self.config.name}")
        except ImportError:
            raise ImportError("huggingface_hub library not available. Install with: pip install huggingface-hub")
        except Exception as e:
            raise RuntimeError(f"Failed to initialize HuggingFace API client: {e}")

    def predict(self, query_doc_pairs: List[List[str]]) -> List[float]:
        """
        Generate predictions for query-document pairs.

        Args:
            query_doc_pairs: List of [query, document] pairs

        Returns:
            List of relevance scores
        """
        if not self.info.loaded and not self.load_model():
            raise RuntimeError(f"Model {self.name} not available")

        start_time = time.time()

        try:
            if self.config.backend == "huggingface_api":
                scores = self._predict_api(query_doc_pairs)
            else:
                scores = self._predict_local(query_doc_pairs)

            # Update statistics
            inference_time = time.time() - start_time
            self.info.inference_count += 1
            self.info.total_inference_time += inference_time
            self.info.last_used = time.time()

            return scores

        except Exception as e:
            self.info.error_count += 1
            logger.error(f"Model prediction failed for {self.name}: {e}")

            # Try falling back to a local model if the API failed and fallback is enabled
            if self.config.backend == "huggingface_api" and self.config.fallback_to_local:
                logger.warning("API prediction failed, attempting fallback to local model")
                try:
                    return self._fallback_to_local(query_doc_pairs)
                except Exception as fallback_error:
                    logger.error(f"Fallback to local model also failed: {fallback_error}")

            raise

    def _predict_local(self, query_doc_pairs: List[List[str]]) -> List[float]:
        """
        Generate predictions using the local model.

        Args:
            query_doc_pairs: List of [query, document] pairs

        Returns:
            List of relevance scores
        """
        scores = self.model.predict(query_doc_pairs)

        # Convert to a plain list if the result is a numpy array
        if hasattr(scores, "tolist"):
            scores = scores.tolist()

        return scores

    def _predict_api(self, query_doc_pairs: List[List[str]]) -> List[float]:
        """
        Generate predictions using the HuggingFace Inference API.

        Args:
            query_doc_pairs: List of [query, document] pairs

        Returns:
            List of relevance scores, one per input pair (pairs beyond
            max_candidates keep a score of 0.0)
        """
        import requests

        # Score at most max_candidates pairs; the rest keep a score of 0.0 so
        # the returned list stays aligned with the input.
        all_scores = [0.0] * len(query_doc_pairs)
        scored_pairs = query_doc_pairs
        if 0 < self.config.max_candidates < len(query_doc_pairs):
            scored_pairs = query_doc_pairs[: self.config.max_candidates]
            logger.debug(f"Limited to {self.config.max_candidates} candidates for API efficiency")

        # Group by query for efficient batch processing
        query_groups = {}
        for i, (query, document) in enumerate(scored_pairs):
            query_groups.setdefault(query, []).append((i, document))

        # Process each query group
        for query, doc_pairs in query_groups.items():
            # Prepare the documents for this query
            documents = []
            indices = []
            for idx, document in doc_pairs:
                # Truncate overlong documents
                if len(document) > self.config.max_length:
                    document = document[: self.config.max_length - 50] + "..."
                documents.append(document)
                indices.append(idx)

            try:
                # Call the HuggingFace Inference API for cross-encoder text ranking.
                # Payload format: {"inputs": {"source_sentence": "query", "sentences": ["doc1", "doc2", ...]}}
                api_url = f"https://api-inference.huggingface.co/models/{self.api_model_name}"
                headers = {"Authorization": f"Bearer {self._api_token}"}
                payload = {
                    "inputs": {
                        "source_sentence": query,
                        "sentences": documents,
                    }
                }

                response = requests.post(api_url, headers=headers, json=payload, timeout=self.config.timeout)

                if response.status_code == 200:
                    result = response.json()
                    # Extract scores from the response; failed entries keep 0.0
                    if isinstance(result, list) and len(result) == len(documents):
                        for i, score in enumerate(result):
                            if isinstance(score, dict) and "score" in score:
                                all_scores[indices[i]] = float(score["score"])
                            elif isinstance(score, (int, float)):
                                all_scores[indices[i]] = float(score)
                    else:
                        logger.warning(f"Unexpected API response format: {result}")
                else:
                    logger.warning(f"API request failed: {response.status_code} - {response.text}")
            except Exception as e:
                logger.warning(f"API prediction failed for query '{query}': {e}")

        # Zero out scores below the configured threshold
        if self.config.score_threshold > 0:
            all_scores = [
                score if score >= self.config.score_threshold else 0.0
                for score in all_scores
            ]

        return all_scores
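
    # For reference, an assumed request/response exchange for the payload
    # format above. The response shape is inferred from the parsing logic in
    # _predict_api, not from an official API contract:
    #
    #   POST https://api-inference.huggingface.co/models/<model>
    #   {"inputs": {"source_sentence": "what is rust?",
    #               "sentences": ["Rust is a systems language.", "Iron oxide."]}}
    #
    #   200 OK
    #   [0.91, 0.12]            # or [{"score": 0.91}, {"score": 0.12}]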

    def _fallback_to_local(self, query_doc_pairs: List[List[str]]) -> List[float]:
        """
        Fall back to a local model when the API fails.

        Args:
            query_doc_pairs: List of [query, document] pairs

        Returns:
            List of relevance scores
        """
        logger.info("Attempting fallback to local sentence-transformers model")

        # Temporarily switch to the local backend
        original_backend = self.config.backend
        self.config.backend = "sentence_transformers"

        try:
            # Unload the API client
            self.model = None
            self.info.loaded = False

            # Load the local model
            if self.load_model():
                scores = self._predict_local(query_doc_pairs)
                logger.info("Successfully fell back to local model")
                return scores
            else:
                raise RuntimeError("Failed to load local fallback model")
        finally:
            # Restore the original backend
            self.config.backend = original_backend

    def unload_model(self):
        """Unload the model to free memory."""
        with self._lock:
            if self.info.loaded:
                self.model = None
                self.tokenizer = None
                self.info.loaded = False
                logger.info(f"Model {self.name} unloaded")

    def get_info(self) -> ModelInfo:
        """Get model information."""
        return self.info

    def get_stats(self) -> Dict[str, Any]:
        """Get model statistics."""
        avg_inference_time = 0.0
        if self.info.inference_count > 0:
            avg_inference_time = self.info.total_inference_time / self.info.inference_count

        return {
            "name": self.name,
            "loaded": self.info.loaded,
            "inference_count": self.info.inference_count,
            "avg_inference_time_ms": avg_inference_time * 1000,
            "total_inference_time": self.info.total_inference_time,
            "error_count": self.info.error_count,
            "last_used": self.info.last_used,
        }
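

# Minimal single-model usage sketch, assuming sentence-transformers is
# installed and the default model can be downloaded:
#
#     manager = ModelManager("default", ModelConfig())
#     scores = manager.predict([["query", "candidate document"]])
#     print(manager.get_stats()["avg_inference_time_ms"])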


class CrossEncoderModels:
    """
    Multi-model manager for cross-encoder models.

    Manages multiple cross-encoder models with lazy loading, caching,
    and automatic model selection based on configuration.
    """

    def __init__(self, models_config: Dict[str, ModelConfig]):
        """
        Initialize the cross-encoder models manager.

        Args:
            models_config: Dictionary of model configurations
        """
        self.models_config = models_config
        self.managers: Dict[str, ModelManager] = {}
        self.default_model = None

        # Create model managers
        for name, config in models_config.items():
            self.managers[name] = ModelManager(name, config)

        # Use the first configured model as the default
        if models_config:
            self.default_model = next(iter(models_config))

        self.stats = {
            "total_predictions": 0,
            "model_switches": 0,
            "cache_hits": 0,
            "cache_misses": 0,
        }

        logger.info(f"CrossEncoderModels initialized with {len(models_config)} models")

    def predict(
        self,
        query_doc_pairs: List[List[str]],
        model_name: Optional[str] = None,
    ) -> List[float]:
        """
        Generate predictions using the specified or default model.

        Args:
            query_doc_pairs: List of [query, document] pairs
            model_name: Name of the model to use (defaults to default_model)

        Returns:
            List of relevance scores
        """
        if not query_doc_pairs:
            return []

        # Select the model
        selected_model = model_name or self.default_model
        if selected_model not in self.managers:
            logger.warning(f"Model {selected_model} not found, using default")
            selected_model = self.default_model

        if not selected_model:
            raise RuntimeError("No models available")

        try:
            manager = self.managers[selected_model]
            scores = manager.predict(query_doc_pairs)

            self.stats["total_predictions"] += 1
            return scores

        except Exception as e:
            logger.error(f"Prediction failed with model {selected_model}: {e}")

            # Fall back to the default model if a different one was requested
            if selected_model != self.default_model:
                logger.info(f"Trying fallback to default model: {self.default_model}")
                return self.predict(query_doc_pairs, self.default_model)
            else:
                raise

    def get_available_models(self) -> List[str]:
        """Get the list of available model names."""
        return list(self.managers.keys())

    def is_model_loaded(self, model_name: str) -> bool:
        """Check whether a model is loaded."""
        if model_name in self.managers:
            return self.managers[model_name].info.loaded
        return False

    def load_model(self, model_name: str) -> bool:
        """
        Load a specific model.

        Args:
            model_name: Name of the model to load

        Returns:
            True if loaded successfully
        """
        if model_name in self.managers:
            return self.managers[model_name].load_model()
        return False

    def unload_model(self, model_name: str):
        """Unload a specific model."""
        if model_name in self.managers:
            self.managers[model_name].unload_model()

    def unload_all_models(self):
        """Unload all models to free memory."""
        for manager in self.managers.values():
            manager.unload_model()

    def get_model_stats(self) -> Dict[str, Dict[str, Any]]:
        """Get statistics for all models."""
        return {name: manager.get_stats() for name, manager in self.managers.items()}

    def get_stats(self) -> Dict[str, Any]:
        """Get overall statistics."""
        stats = self.stats.copy()
        stats["models"] = self.get_model_stats()
        stats["total_models"] = len(self.managers)
        stats["loaded_models"] = sum(1 for m in self.managers.values() if m.info.loaded)
        return stats

    def reset_stats(self) -> None:
        """Reset statistics."""
        self.stats = {
            "total_predictions": 0,
            "model_switches": 0,
            "cache_hits": 0,
            "cache_misses": 0,
        }
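

if __name__ == "__main__":
    # Minimal usage sketch, assuming sentence-transformers is installed and the
    # default MS MARCO cross-encoder can be downloaded. The queries, documents,
    # and configuration values below are illustrative only.
    logging.basicConfig(level=logging.INFO)

    models = CrossEncoderModels({
        "default": ModelConfig(device="cpu", batch_size=8),
    })

    pairs = [
        ["what is a cross-encoder?",
         "A cross-encoder scores a query and a document jointly in one pass."],
        ["what is a cross-encoder?",
         "Bananas are a good source of potassium."],
    ]
    scores = models.predict(pairs)

    for (query, document), score in zip(pairs, scores):
        print(f"{score:8.4f}  {document[:60]}")

    models.unload_all_models()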