"""
Adaptive Strategies for Neural Reranking.
This module provides query-type-aware reranking strategies that adapt
model selection and parameters to query characteristics in order to
optimize relevance and performance.
Migrated from reranking/ module and simplified for integration with
the enhanced neural reranker in the rerankers/ component.
"""
import logging
import re
import time
from typing import List, Dict, Any
from dataclasses import dataclass
logger = logging.getLogger(__name__)
@dataclass
class QueryAnalysis:
"""Results of query analysis."""
query_type: str
confidence: float
features: Dict[str, Any]
recommended_model: str
class QueryTypeDetector:
"""
Detects query types to enable adaptive reranking strategies.
    Analyzes queries and classifies them into categories such as technical,
    procedural, or comparative so that an optimal model can be selected.
"""
def __init__(self, confidence_threshold: float = 0.7):
"""
Initialize query type detector.
Args:
confidence_threshold: Minimum confidence for classification
"""
self.confidence_threshold = confidence_threshold
self.stats = {
"classifications": 0,
"type_counts": {},
"high_confidence": 0,
"low_confidence": 0
}
# Model strategies for different query types
self.strategies = {
"technical": "technical_model",
"general": "default_model",
"comparative": "technical_model",
"procedural": "default_model",
"factual": "default_model"
}
# Define patterns for different query types
self.patterns = {
"technical": [
r'\b(api|protocol|implementation|configuration|architecture)\b',
r'\b(install|setup|configure|deploy)\b',
r'\b(error|exception|debug|troubleshoot)\b',
r'\b(version|compatibility|requirement)\b'
],
"procedural": [
r'\bhow to\b',
r'\bstep by step\b',
                r'\b(guide|tutorial|walkthrough)\b',
                r'\b(process|procedure|workflow)\b'
],
"comparative": [
r'\bvs\b|\bversus\b',
r'\bdifference between\b',
                r'\b(compare|comparison)\b',
                r'\b(better|best|worse|worst)\b'
],
"factual": [
r'\bwhat is\b|\bwho is\b|\bwhere is\b',
                r'\b(define|definition)\b',
                r'\b(explain|describe)\b'
],
"general": [] # Catch-all for queries that don't match other patterns
}
logger.info("QueryTypeDetector initialized with built-in patterns")
def classify_query(self, query: str) -> QueryAnalysis:
"""
Classify a query into a type category.
Args:
query: The search query to classify
Returns:
Query analysis with type, confidence, and features
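
        Example (illustrative; with the default threshold of 0.7 most
        queries fall back to "general", because each matching pattern
        contributes only 1/len(patterns) to the score):

            detector = QueryTypeDetector(confidence_threshold=0.5)
            analysis = detector.classify_query("how to configure the api")
            # "configure" and "api" match two of the four technical
            # patterns, so query_type == "technical" with confidence 0.5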
"""
query_lower = query.lower()
type_scores = {}
# Calculate scores for each query type
for query_type, patterns in self.patterns.items():
if not patterns: # Skip empty pattern lists (like general)
continue
            score = 0
            for pattern in patterns:
                if re.search(pattern, query_lower):
                    score += 1
            # Normalize the score by the number of patterns for this type
            type_scores[query_type] = score / len(patterns)
# Find the best matching type
if type_scores:
            best_type = max(type_scores, key=type_scores.get)
confidence = type_scores[best_type]
else:
best_type = "general"
confidence = 0.5 # Default confidence for general queries
# Apply confidence threshold
if confidence < self.confidence_threshold:
best_type = "general"
confidence = 0.5
# Extract additional features
features = self._extract_features(query)
# Get recommended model
recommended_model = self.strategies.get(best_type, "default_model")
# Update statistics
self._update_stats(best_type, confidence)
return QueryAnalysis(
query_type=best_type,
confidence=confidence,
features=features,
recommended_model=recommended_model
)
def _extract_features(self, query: str) -> Dict[str, Any]:
"""Extract additional features from the query."""
features = {
"length": len(query),
"word_count": len(query.split()),
"has_question_mark": "?" in query,
"has_quotes": '"' in query or "'" in query,
"is_uppercase": query.isupper(),
"starts_with_question_word": query.lower().startswith(('what', 'how', 'when', 'where', 'why', 'who')),
"technical_terms": len([w for w in query.lower().split() if w in ['api', 'protocol', 'config', 'setup']])
}
return features
def _update_stats(self, query_type: str, confidence: float):
"""Update classification statistics."""
self.stats["classifications"] += 1
self.stats["type_counts"][query_type] = self.stats["type_counts"].get(query_type, 0) + 1
if confidence >= 0.8:
self.stats["high_confidence"] += 1
elif confidence < 0.5:
self.stats["low_confidence"] += 1
def get_stats(self) -> Dict[str, Any]:
"""Get classification statistics."""
return self.stats.copy()
class AdaptiveStrategies:
"""
Adaptive reranking strategies that adjust based on query characteristics.
This component analyzes queries and selects optimal models and parameters
to maximize relevance while maintaining performance targets.
"""
def __init__(
self,
enabled: bool = True,
confidence_threshold: float = 0.7,
enable_dynamic_switching: bool = False,
performance_window: int = 100,
quality_threshold: float = 0.8
):
"""
Initialize adaptive strategies.
Args:
enabled: Whether adaptive strategies are enabled
confidence_threshold: Minimum confidence for query classification
enable_dynamic_switching: Whether to enable performance-based model switching
performance_window: Number of queries to track for performance
quality_threshold: Quality threshold for model switching
"""
self.enabled = enabled
self.enable_dynamic_switching = enable_dynamic_switching
self.performance_window = performance_window
self.quality_threshold = quality_threshold
self.detector = QueryTypeDetector(confidence_threshold) if enabled else None
self.stats = {
"model_selections": 0,
"adaptations": 0,
"fallbacks": 0
}
# Performance tracking for adaptive adjustments
self.performance_history = []
logger.info(f"AdaptiveStrategies initialized, enabled={enabled}")
def select_model(
self,
query: str,
available_models: List[str],
default_model: str
) -> str:
"""
Select the optimal model for a given query.
Args:
query: The search query
available_models: List of available model names
default_model: Default model to fall back to
Returns:
Name of the selected model
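
        Example (model names here are illustrative):

            strategies = AdaptiveStrategies()
            model = strategies.select_model(
                "how to configure the api",
                available_models=["default_model", "technical_model"],
                default_model="default_model",
            )
            # with the default 0.7 threshold the query classifies as
            # "general", so "default_model" is returned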
"""
if not self.enabled or not self.detector:
return default_model
try:
# Classify the query
analysis = self.detector.classify_query(query)
# Get recommended model
recommended_model = analysis.recommended_model
# Check if recommended model is available
if recommended_model in available_models:
selected_model = recommended_model
else:
logger.warning(f"Recommended model {recommended_model} not available, using default")
selected_model = default_model
self.stats["fallbacks"] += 1
# Consider performance-based adaptations
if self.enable_dynamic_switching:
selected_model = self._consider_performance_adaptation(
selected_model, available_models, default_model
)
self.stats["model_selections"] += 1
logger.debug(f"Selected model '{selected_model}' for query type '{analysis.query_type}' "
f"(confidence: {analysis.confidence:.2f})")
return selected_model
except Exception as e:
logger.error(f"Model selection failed: {e}")
self.stats["fallbacks"] += 1
return default_model
def _consider_performance_adaptation(
self,
current_selection: str,
available_models: List[str],
default_model: str
) -> str:
"""Consider performance-based model adaptation."""
try:
# Check recent performance history
if len(self.performance_history) >= self.performance_window:
recent_performance = self.performance_history[-self.performance_window:]
# Calculate average quality for current selection
current_model_performance = [
p for p in recent_performance
if p.get("model") == current_selection
]
if current_model_performance:
avg_quality = sum(p.get("quality", 0) for p in current_model_performance) / len(current_model_performance)
# Switch if quality is below threshold
if avg_quality < self.quality_threshold:
logger.info(f"Switching from {current_selection} due to low quality: {avg_quality:.2f}")
self.stats["adaptations"] += 1
return default_model
return current_selection
except Exception as e:
logger.warning(f"Performance adaptation failed: {e}")
return current_selection
def adapt_parameters(
self,
query: str,
model_name: str,
base_config: Dict[str, Any]
) -> Dict[str, Any]:
"""
Adapt model parameters based on query characteristics.
Args:
query: The search query
model_name: Selected model name
base_config: Base model configuration
Returns:
Adapted configuration
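
        Example (illustrative; reuses the strategies instance from
        select_model above):

            adapted = strategies.adapt_parameters(
                "short query",
                "default_model",
                {"batch_size": 16, "max_candidates": 50},
            )
            # a two-word query is "low" complexity, so batch_size becomes 32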
"""
if not self.enabled:
return base_config
try:
adapted_config = base_config.copy()
# Adapt batch size based on query complexity
query_complexity = self._assess_query_complexity(query)
if query_complexity == "high":
adapted_config["batch_size"] = max(1, adapted_config.get("batch_size", 16) // 2)
elif query_complexity == "low":
adapted_config["batch_size"] = min(64, adapted_config.get("batch_size", 16) * 2)
# Adapt number of candidates based on query type
if self.detector:
analysis = self.detector.classify_query(query)
if analysis.query_type == "technical":
# Technical queries might benefit from more candidates
adapted_config["max_candidates"] = min(100, adapted_config.get("max_candidates", 50) * 1.5)
elif analysis.query_type == "factual":
# Factual queries might need fewer candidates
adapted_config["max_candidates"] = max(10, adapted_config.get("max_candidates", 50) // 2)
return adapted_config
except Exception as e:
logger.error(f"Parameter adaptation failed: {e}")
return base_config
def _assess_query_complexity(self, query: str) -> str:
"""Assess query complexity for parameter adaptation."""
word_count = len(query.split())
if word_count > 10:
return "high"
elif word_count < 3:
return "low"
else:
return "medium"
def record_performance(
self,
model: str,
query_type: str,
latency_ms: float,
quality_score: float
):
"""
Record performance metrics for adaptive learning.
Args:
model: Model used
query_type: Type of query
latency_ms: Processing latency
quality_score: Quality metric (0-1)
"""
performance_record = {
"model": model,
"query_type": query_type,
"latency_ms": latency_ms,
"quality": quality_score,
"timestamp": time.time()
}
self.performance_history.append(performance_record)
# Keep only recent history
max_history = self.performance_window * 2
if len(self.performance_history) > max_history:
self.performance_history = self.performance_history[-max_history:]
def get_stats(self) -> Dict[str, Any]:
"""Get adaptive strategies statistics."""
stats = self.stats.copy()
if self.detector:
stats["detector"] = self.detector.get_stats()
# Add performance summary
if self.performance_history:
recent_performance = self.performance_history[-50:] # Last 50 records
stats["recent_performance"] = {
"avg_latency_ms": sum(p["latency_ms"] for p in recent_performance) / len(recent_performance),
"avg_quality": sum(p["quality"] for p in recent_performance) / len(recent_performance),
"total_records": len(self.performance_history)
}
return stats
def reset_stats(self) -> None:
"""Reset adaptive strategies statistics."""
self.stats = {
"model_selections": 0,
"adaptations": 0,
"fallbacks": 0
}
if self.detector:
self.detector.stats = {
"classifications": 0,
"type_counts": {},
"high_confidence": 0,
"low_confidence": 0
}
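

# Minimal usage sketch (not part of the original API surface): exercises the
# full adaptive loop with illustrative model names and a synthetic quality
# score.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    strategies = AdaptiveStrategies(enabled=True, confidence_threshold=0.5)
    query = "how to configure the api"
    model = strategies.select_model(
        query,
        available_models=["default_model", "technical_model"],
        default_model="default_model",
    )
    config = strategies.adapt_parameters(
        query, model, {"batch_size": 16, "max_candidates": 50}
    )
    # Record a synthetic performance measurement for the selected model
    strategies.record_performance(
        model=model, query_type="technical", latency_ms=42.0, quality_score=0.9
    )
    print(f"selected={model} config={config} stats={strategies.get_stats()}")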