"""
Adaptive Strategies for Neural Reranking.

This module provides query-type aware reranking strategies that can
adapt model selection and parameters based on query characteristics
to optimize relevance and performance.

Migrated from reranking/ module and simplified for integration with
the enhanced neural reranker in the rerankers/ component.
"""
import logging | |
import re | |
import time | |
from typing import List, Dict, Any, Optional | |
from dataclasses import dataclass | |
logger = logging.getLogger(__name__) | |
@dataclass
class QueryAnalysis:
    """Results of query analysis.

    Attributes:
        query_type: Classified category (e.g. "technical", "general").
        confidence: Classification confidence in [0, 1].
        features: Additional features extracted from the query text.
        recommended_model: Model name recommended for this query type.
    """
    # NOTE: the @dataclass decorator is required — this class is constructed
    # with keyword arguments (QueryAnalysis(query_type=..., ...)); without the
    # decorator those calls raise TypeError.
    query_type: str
    confidence: float
    features: Dict[str, Any]
    recommended_model: str
class QueryTypeDetector: | |
""" | |
Detects query types to enable adaptive reranking strategies. | |
Analyzes queries to classify them into categories like technical, | |
procedural, comparative, etc. to enable optimal model selection. | |
""" | |
def __init__(self, confidence_threshold: float = 0.7): | |
""" | |
Initialize query type detector. | |
Args: | |
confidence_threshold: Minimum confidence for classification | |
""" | |
self.confidence_threshold = confidence_threshold | |
self.stats = { | |
"classifications": 0, | |
"type_counts": {}, | |
"high_confidence": 0, | |
"low_confidence": 0 | |
} | |
# Model strategies for different query types | |
self.strategies = { | |
"technical": "technical_model", | |
"general": "default_model", | |
"comparative": "technical_model", | |
"procedural": "default_model", | |
"factual": "default_model" | |
} | |
# Define patterns for different query types | |
self.patterns = { | |
"technical": [ | |
r'\b(api|protocol|implementation|configuration|architecture)\b', | |
r'\b(install|setup|configure|deploy)\b', | |
r'\b(error|exception|debug|troubleshoot)\b', | |
r'\b(version|compatibility|requirement)\b' | |
], | |
"procedural": [ | |
r'\bhow to\b', | |
r'\bstep by step\b', | |
r'\bguide|tutorial|walkthrough\b', | |
r'\bprocess|procedure|workflow\b' | |
], | |
"comparative": [ | |
r'\bvs\b|\bversus\b', | |
r'\bdifference between\b', | |
r'\bcompare|comparison\b', | |
r'\bbetter|best|worse|worst\b' | |
], | |
"factual": [ | |
r'\bwhat is\b|\bwho is\b|\bwhere is\b', | |
r'\bdefine|definition\b', | |
r'\bexplain|describe\b' | |
], | |
"general": [] # Catch-all for queries that don't match other patterns | |
} | |
logger.info("QueryTypeDetector initialized with built-in patterns") | |
def classify_query(self, query: str) -> QueryAnalysis: | |
""" | |
Classify a query into a type category. | |
Args: | |
query: The search query to classify | |
Returns: | |
Query analysis with type, confidence, and features | |
""" | |
query_lower = query.lower() | |
type_scores = {} | |
# Calculate scores for each query type | |
for query_type, patterns in self.patterns.items(): | |
if not patterns: # Skip empty pattern lists (like general) | |
continue | |
score = 0 | |
matches = 0 | |
for pattern in patterns: | |
if re.search(pattern, query_lower): | |
matches += 1 | |
score += 1 | |
# Normalize score by number of patterns | |
if patterns: | |
type_scores[query_type] = score / len(patterns) | |
# Find the best matching type | |
if type_scores: | |
best_type = max(type_scores.keys(), key=lambda k: type_scores[k]) | |
confidence = type_scores[best_type] | |
else: | |
best_type = "general" | |
confidence = 0.5 # Default confidence for general queries | |
# Apply confidence threshold | |
if confidence < self.confidence_threshold: | |
best_type = "general" | |
confidence = 0.5 | |
# Extract additional features | |
features = self._extract_features(query) | |
# Get recommended model | |
recommended_model = self.strategies.get(best_type, "default_model") | |
# Update statistics | |
self._update_stats(best_type, confidence) | |
return QueryAnalysis( | |
query_type=best_type, | |
confidence=confidence, | |
features=features, | |
recommended_model=recommended_model | |
) | |
def _extract_features(self, query: str) -> Dict[str, Any]: | |
"""Extract additional features from the query.""" | |
features = { | |
"length": len(query), | |
"word_count": len(query.split()), | |
"has_question_mark": "?" in query, | |
"has_quotes": '"' in query or "'" in query, | |
"is_uppercase": query.isupper(), | |
"starts_with_question_word": query.lower().startswith(('what', 'how', 'when', 'where', 'why', 'who')), | |
"technical_terms": len([w for w in query.lower().split() if w in ['api', 'protocol', 'config', 'setup']]) | |
} | |
return features | |
def _update_stats(self, query_type: str, confidence: float): | |
"""Update classification statistics.""" | |
self.stats["classifications"] += 1 | |
self.stats["type_counts"][query_type] = self.stats["type_counts"].get(query_type, 0) + 1 | |
if confidence >= 0.8: | |
self.stats["high_confidence"] += 1 | |
elif confidence < 0.5: | |
self.stats["low_confidence"] += 1 | |
def get_stats(self) -> Dict[str, Any]: | |
"""Get classification statistics.""" | |
return self.stats.copy() | |
class AdaptiveStrategies:
    """
    Adaptive reranking strategies that adjust based on query characteristics.

    This component analyzes queries and selects optimal models and parameters
    to maximize relevance while maintaining performance targets.
    """

    def __init__(
        self,
        enabled: bool = True,
        confidence_threshold: float = 0.7,
        enable_dynamic_switching: bool = False,
        performance_window: int = 100,
        quality_threshold: float = 0.8
    ):
        """
        Initialize adaptive strategies.

        Args:
            enabled: Whether adaptive strategies are enabled
            confidence_threshold: Minimum confidence for query classification
            enable_dynamic_switching: Whether to enable performance-based model switching
            performance_window: Number of queries to track for performance
            quality_threshold: Quality threshold for model switching
        """
        self.enabled = enabled
        self.enable_dynamic_switching = enable_dynamic_switching
        self.performance_window = performance_window
        self.quality_threshold = quality_threshold
        # Detector is only built when enabled; all public methods must handle
        # self.detector being None.
        self.detector = QueryTypeDetector(confidence_threshold) if enabled else None
        self.stats = {
            "model_selections": 0,
            "adaptations": 0,
            "fallbacks": 0
        }
        # Performance tracking for adaptive adjustments (see record_performance)
        self.performance_history: List[Dict[str, Any]] = []
        logger.info(f"AdaptiveStrategies initialized, enabled={enabled}")

    def select_model(
        self,
        query: str,
        available_models: List[str],
        default_model: str
    ) -> str:
        """
        Select the optimal model for a given query.

        Args:
            query: The search query
            available_models: List of available model names
            default_model: Default model to fall back to

        Returns:
            Name of the selected model; never raises — falls back to
            default_model on any internal failure.
        """
        if not self.enabled or not self.detector:
            return default_model
        try:
            # Classify the query and look up its recommended model.
            analysis = self.detector.classify_query(query)
            recommended_model = analysis.recommended_model

            if recommended_model in available_models:
                selected_model = recommended_model
            else:
                logger.warning(f"Recommended model {recommended_model} not available, using default")
                selected_model = default_model
                self.stats["fallbacks"] += 1

            # Optionally override the choice based on recent quality metrics.
            if self.enable_dynamic_switching:
                selected_model = self._consider_performance_adaptation(
                    selected_model, available_models, default_model
                )

            self.stats["model_selections"] += 1
            logger.debug(f"Selected model '{selected_model}' for query type '{analysis.query_type}' "
                         f"(confidence: {analysis.confidence:.2f})")
            return selected_model
        except Exception as e:
            logger.error(f"Model selection failed: {e}")
            self.stats["fallbacks"] += 1
            return default_model

    def _consider_performance_adaptation(
        self,
        current_selection: str,
        available_models: List[str],
        default_model: str
    ) -> str:
        """Switch to default_model if the current selection's recent average
        quality falls below quality_threshold; otherwise keep it."""
        try:
            # Only adapt once a full window of history is available.
            if len(self.performance_history) >= self.performance_window:
                recent_performance = self.performance_history[-self.performance_window:]
                current_model_performance = [
                    p for p in recent_performance
                    if p.get("model") == current_selection
                ]
                if current_model_performance:
                    avg_quality = sum(p.get("quality", 0) for p in current_model_performance) / len(current_model_performance)
                    if avg_quality < self.quality_threshold:
                        logger.info(f"Switching from {current_selection} due to low quality: {avg_quality:.2f}")
                        self.stats["adaptations"] += 1
                        return default_model
            return current_selection
        except Exception as e:
            logger.warning(f"Performance adaptation failed: {e}")
            return current_selection

    def adapt_parameters(
        self,
        query: str,
        model_name: str,
        base_config: Dict[str, Any]
    ) -> Dict[str, Any]:
        """
        Adapt model parameters based on query characteristics.

        Args:
            query: The search query
            model_name: Selected model name
            base_config: Base model configuration

        Returns:
            Adapted configuration (a copy; base_config is never mutated).
            Returns base_config unchanged when disabled or on failure.
        """
        if not self.enabled:
            return base_config
        try:
            adapted_config = base_config.copy()

            # Adapt batch size based on query complexity: halve for complex
            # queries (floor 1), double for simple ones (cap 64).
            query_complexity = self._assess_query_complexity(query)
            if query_complexity == "high":
                adapted_config["batch_size"] = max(1, adapted_config.get("batch_size", 16) // 2)
            elif query_complexity == "low":
                adapted_config["batch_size"] = min(64, adapted_config.get("batch_size", 16) * 2)

            # Adapt number of candidates based on query type.
            if self.detector:
                analysis = self.detector.classify_query(query)
                if analysis.query_type == "technical":
                    # Technical queries might benefit from more candidates.
                    # int() keeps the candidate count integral (50 * 1.5 would
                    # otherwise produce the float 75.0).
                    adapted_config["max_candidates"] = int(min(100, adapted_config.get("max_candidates", 50) * 1.5))
                elif analysis.query_type == "factual":
                    # Factual queries might need fewer candidates
                    adapted_config["max_candidates"] = max(10, adapted_config.get("max_candidates", 50) // 2)

            return adapted_config
        except Exception as e:
            logger.error(f"Parameter adaptation failed: {e}")
            return base_config

    def _assess_query_complexity(self, query: str) -> str:
        """Bucket query complexity by word count: >10 high, <3 low, else medium."""
        word_count = len(query.split())
        if word_count > 10:
            return "high"
        elif word_count < 3:
            return "low"
        else:
            return "medium"

    def record_performance(
        self,
        model: str,
        query_type: str,
        latency_ms: float,
        quality_score: float
    ) -> None:
        """
        Record performance metrics for adaptive learning.

        Args:
            model: Model used
            query_type: Type of query
            latency_ms: Processing latency
            quality_score: Quality metric (0-1)
        """
        performance_record = {
            "model": model,
            "query_type": query_type,
            "latency_ms": latency_ms,
            "quality": quality_score,
            "timestamp": time.time()
        }
        self.performance_history.append(performance_record)

        # Bound memory: keep at most two windows of history.
        max_history = self.performance_window * 2
        if len(self.performance_history) > max_history:
            self.performance_history = self.performance_history[-max_history:]

    def get_stats(self) -> Dict[str, Any]:
        """Get adaptive strategies statistics, including detector stats and a
        summary over the most recent performance records."""
        stats = self.stats.copy()
        if self.detector:
            stats["detector"] = self.detector.get_stats()
        if self.performance_history:
            recent_performance = self.performance_history[-50:]  # Last 50 records
            stats["recent_performance"] = {
                "avg_latency_ms": sum(p["latency_ms"] for p in recent_performance) / len(recent_performance),
                "avg_quality": sum(p["quality"] for p in recent_performance) / len(recent_performance),
                "total_records": len(self.performance_history)
            }
        return stats

    def reset_stats(self) -> None:
        """Reset adaptive strategies statistics (and the detector's, if any)."""
        self.stats = {
            "model_selections": 0,
            "adaptations": 0,
            "fallbacks": 0
        }
        if self.detector:
            self.detector.stats = {
                "classifications": 0,
                "type_counts": {},
                "high_confidence": 0,
                "low_confidence": 0
            }