Spaces:
Running
Running
""" | |
Graph-Enhanced Fusion Strategy for Architecture-Compliant Graph Integration. | |
This module provides a fusion strategy that properly integrates graph-based | |
retrieval signals with dense and sparse retrieval results, following the | |
proper sub-component architecture patterns. | |
This replaces the misplaced graph/ component with proper fusion sub-component | |
enhancement. | |
""" | |
import logging | |
import time | |
from typing import List, Dict, Any, Tuple, Optional, Union | |
from collections import defaultdict | |
import numpy as np | |
from .base import FusionStrategy | |
from .rrf_fusion import RRFFusion | |
from src.core.interfaces import Document, RetrievalResult | |
# Optional dependency: spaCy powers entity extraction when it is installed
# together with the small English pipeline. Start from the "nothing
# available" state and upgrade the flags as each step succeeds.
SPACY_AVAILABLE = False
NLP_MODEL_AVAILABLE = False
_nlp = None
try:
    import spacy
    SPACY_AVAILABLE = True
    # The package can be present while the model itself is not downloaded.
    try:
        _nlp = spacy.load("en_core_web_sm")
        NLP_MODEL_AVAILABLE = True
    except IOError:
        NLP_MODEL_AVAILABLE = False
        _nlp = None
except ImportError:
    # Reached when spaCy is missing, or when loading the model itself
    # triggers an ImportError; either way behave as if spaCy were absent.
    SPACY_AVAILABLE = False
    NLP_MODEL_AVAILABLE = False
    _nlp = None

# Module-level logger shared by everything in this file.
logger = logging.getLogger(__name__)
class GraphEnhancedFusionError(Exception):
    """Error type for failures inside graph-enhanced fusion operations."""
class GraphEnhancedRRFFusion(FusionStrategy):
    """
    Graph-enhanced RRF fusion strategy.

    Runs a base ``RRFFusion`` over the dense and sparse result lists and then
    adds two graph-style signals to the fused scores:

    * an *entity boost*: overlap between entities extracted from the current
      query and from each candidate document;
    * a *relationship boost*: how central a document is in the cosine
      similarity graph built from the candidates' embeddings.

    The boosts are scaled relative to the spread of the base scores so they
    cannot swamp the RRF ranking.

    Features:
    - Standard RRF fusion (dense + sparse)
    - Graph signal integration (third stream)
    - Configurable fusion weights
    - Entity-based document scoring
    - Relationship-aware relevance boosting
    - Caching of per-document / per-result-set graph signals
    - Graceful degradation without graph signals

    Architecture notes:
    - Lives in the fusion/ sub-component
    - Extends the FusionStrategy interface
    - Direct implementation (no external APIs)
    - Falls back to plain base fusion when enhancement is disabled or fails

    Example:
        config = {
            "base_fusion": {
                "k": 60,
                "weights": {"dense": 0.6, "sparse": 0.3}
            },
            "graph_enhancement": {
                "enabled": True,
                "graph_weight": 0.1,
                "entity_boost": 0.15,
                "relationship_boost": 0.1
            }
        }
        fusion = GraphEnhancedRRFFusion(config)
        fusion.set_documents_and_query(documents, query)
        results = fusion.fuse_results(dense_results, sparse_results)
    """

    # spaCy entity labels that are kept as entities.
    _ENTITY_LABELS = frozenset(
        {"ORG", "PRODUCT", "PERSON", "GPE", "MONEY", "CARDINAL"}
    )

    # Substrings that mark a spaCy token as a technical term.
    # NOTE(review): this is a substring test, so e.g. "arm" also matches
    # "farm"; kept as-is because compound tokens such as "armv8" rely on it.
    _TECH_SUBSTRINGS = ("risc", "cisc", "cpu", "gpu", "arm", "x86", "isa", "api")

    # Regexes used by the keyword fallback when spaCy is unavailable.
    _TECH_PATTERNS = (
        r"\b[A-Z]{2,}\b",                  # All caps, 2+ chars (RISC, CISC, ARM)
        r"\b[A-Z][a-z]*[A-Z][A-Za-z]*\b",  # CamelCase
        r"\bRV\d+[A-Z]*\b",                # RISC-V variants (RV32I, RV64I)
        r"\b[Aa]rm[vV]\d+\b",              # ARM versions (ARMv8, armv7)
        r"\b[Xx]86\b",                     # x86
    )

    # Plain-word technical vocabulary used by the keyword fallback.
    _TECH_TERMS = frozenset({
        "risc", "cisc", "arm", "intel", "amd", "qualcomm", "apple",
        "samsung", "berkeley", "processor", "cpu", "gpu", "architecture",
        "instruction", "set", "pipelining", "cache", "memory",
    })

    # Upper bound on entries per cache; a full cache is flushed before the
    # next insert so a long-running process cannot grow without limit.
    _MAX_CACHE_ENTRIES = 10000

    def __init__(self, config: Dict[str, Any]):
        """
        Initialize graph-enhanced RRF fusion strategy.

        Args:
            config: Configuration dictionary. Recognised keys are
                ``"base_fusion"`` (passed to ``RRFFusion``) and
                ``"graph_enhancement"`` (``enabled``, ``graph_weight``,
                ``entity_boost``, ``relationship_boost``,
                ``similarity_threshold``, ``max_graph_hops``).
        """
        self.config = config

        # Base RRF fusion over dense + sparse results.
        base_config = config.get("base_fusion", {
            "k": 60,
            "weights": {"dense": 0.7, "sparse": 0.3},
        })
        self.base_fusion = RRFFusion(base_config)

        # Merge user settings over the defaults into a *new* dict so that a
        # partial user config still has every key, and so that
        # enable/disable_graph_enhancement never mutate the caller's dict.
        graph_defaults = {
            "enabled": True,
            "graph_weight": 0.1,
            "entity_boost": 0.15,
            "relationship_boost": 0.1,
            "similarity_threshold": 0.7,
            "max_graph_hops": 3,
        }
        user_graph_config = config.get("graph_enhancement") or {}
        self.graph_config = {**graph_defaults, **user_graph_config}

        # Performance tracking
        self.stats = {
            "total_fusions": 0,
            "graph_enhanced_fusions": 0,
            "entity_boosts_applied": 0,
            "relationship_boosts_applied": 0,
            "avg_graph_latency_ms": 0.0,
            "total_graph_latency_ms": 0.0,
            "fallback_activations": 0,
        }

        # Caches for graph signals (keyed by positional document index).
        self.entity_cache = {}
        self.relationship_cache = {}

        # Document store and query state for entity/relationship analysis.
        self.documents = []
        self.current_query = ""
        self.query_cache = {}

        # Entity extraction setup
        self.nlp = _nlp if NLP_MODEL_AVAILABLE else None
        if not NLP_MODEL_AVAILABLE and self.graph_config.get("enabled", True):
            logger.warning("spaCy model not available, falling back to keyword matching for entity extraction")

        logger.info(
            "GraphEnhancedRRFFusion initialized with graph_enabled=%s",
            self.graph_config.get("enabled", True),
        )

    def set_documents_and_query(self, documents: List[Document], query: str) -> None:
        """
        Set the documents and current query for entity/relationship analysis.

        The entity and relationship caches are keyed by the position of a
        document in ``documents``, so they are dropped whenever a different
        list object is supplied. If the same list is mutated in place, call
        ``clear_caches()`` explicitly.

        Args:
            documents: List of documents being processed
            query: Current query string
        """
        if documents is not self.documents:
            self.entity_cache.clear()
            self.relationship_cache.clear()

        self.documents = documents
        self.current_query = query
        # Query-specific cache never survives a query change.
        self.query_cache = {}

    def fuse_results(
        self,
        dense_results: List[Tuple[int, float]],
        sparse_results: List[Tuple[int, float]]
    ) -> List[Tuple[int, float]]:
        """
        Fuse dense and sparse results with graph enhancement.

        Base RRF fusion always runs. When graph enhancement is enabled the
        fused scores are then adjusted by ``_apply_graph_enhancement``; if
        that step raises, the plain base-fusion result is returned instead.
        Exceptions raised by the base fusion itself propagate to the caller.

        Args:
            dense_results: List of (document_index, score) from dense retrieval
            sparse_results: List of (document_index, score) from sparse retrieval

        Returns:
            List of (document_index, fused_score) tuples sorted by score
        """
        self.stats["total_fusions"] += 1

        # Step 1: base RRF fusion (dense + sparse)
        base_fused = self.base_fusion.fuse_results(dense_results, sparse_results)

        # Step 2: graph enhancement, if enabled
        if not self.graph_config.get("enabled", True):
            return base_fused

        # Time only the graph step, and only when it actually runs, so that
        # avg_graph_latency_ms really is "graph latency per enhanced fusion".
        start_time = time.perf_counter()
        try:
            enhanced_results = self._apply_graph_enhancement(
                base_fused, dense_results, sparse_results
            )
        except Exception as e:
            logger.error("Graph-enhanced fusion failed: %s", e, exc_info=True)
            self.stats["fallback_activations"] += 1
            # Reuse the already computed base result instead of re-running it.
            return base_fused

        # Step 3: statistics (counts every fusion where enhancement was attempted)
        self.stats["graph_enhanced_fusions"] += 1
        self._update_stats((time.perf_counter() - start_time) * 1000)
        return enhanced_results

    @staticmethod
    def _normalize_base_scores(
        base_scores: Dict[int, float]
    ) -> Tuple[Dict[int, float], float]:
        """Stretch a narrow score band onto [0.1, 1.0]; return (scores, spread)."""
        min_base = min(base_scores.values())
        max_base = max(base_scores.values())
        base_range = max_base - min_base
        logger.debug(
            "Base score range: %.6f - %.6f (spread: %.6f)",
            min_base, max_base, base_range,
        )

        # A spread of 0.01 or more already discriminates well enough.
        if base_range >= 0.01:
            return base_scores, base_range

        logger.debug("Small base range detected, applying normalization")
        if base_range > 0:
            normalized = {
                doc_idx: 0.1 + 0.9 * (score - min_base) / base_range
                for doc_idx, score in base_scores.items()
            }
        else:
            # All scores identical: park everything mid-range.
            normalized = dict.fromkeys(base_scores, 0.55)

        min_base = min(normalized.values())
        max_base = max(normalized.values())
        base_range = max_base - min_base
        logger.debug(
            "Normalized score range: %.6f - %.6f (spread: %.6f)",
            min_base, max_base, base_range,
        )
        return normalized, base_range

    def _apply_graph_enhancement(
        self,
        base_results: List[Tuple[int, float]],
        dense_results: List[Tuple[int, float]],
        sparse_results: List[Tuple[int, float]]
    ) -> List[Tuple[int, float]]:
        """
        Apply graph-based enhancement with proper score scaling.

        RRF scores are tiny (~0.016) compared with raw graph boosts (~0.075),
        so the boosts are scaled to at most half of the base score spread.
        If the base spread is below 0.01 the base scores are first stretched
        onto [0.1, 1.0].

        Args:
            base_results: Base RRF fusion results
            dense_results: Original dense retrieval results
            sparse_results: Original sparse retrieval results

        Returns:
            Graph-enhanced results sorted by descending score, or
            ``base_results`` unchanged if it is empty or enhancement fails.
        """
        if not base_results:
            return base_results

        try:
            base_scores, base_range = self._normalize_base_scores(dict(base_results))

            # Graph signals are computed over every document any retriever saw.
            candidates = set(base_scores)
            candidates.update(doc_idx for doc_idx, _ in dense_results)
            candidates.update(doc_idx for doc_idx, _ in sparse_results)
            candidate_list = list(candidates)

            entity_boosts = self._calculate_entity_boosts(candidate_list)
            relationship_boosts = self._calculate_relationship_boosts(candidate_list)

            graph_weight = self.graph_config.get("graph_weight", 0.1)
            # Largest raw boost a document can receive under the current
            # configuration (0.25 with the default settings).
            max_possible_enhancement = (
                self.graph_config.get("entity_boost", 0.15)
                + self.graph_config.get("relationship_boost", 0.1)
            )

            # Cap the enhancement at 50% of the base spread so the graph
            # signal refines the ranking rather than dominating it.
            # NOTE: when every base score is identical the spread is 0, so
            # the scale is 0 and graph signals cannot break the tie.
            if max_possible_enhancement > 0:
                enhancement_scale = min(base_range * 0.5, max_possible_enhancement)
                actual_scale_factor = enhancement_scale / max_possible_enhancement
            else:
                enhancement_scale = 0.0
                actual_scale_factor = 0.0
            logger.debug(
                "Graph enhancement scaling: weight=%s, scale=%.6f, factor=%.3f",
                graph_weight, enhancement_scale, actual_scale_factor,
            )

            enhanced_scores = {}
            enhancements_applied = 0
            for doc_idx, base_score in base_scores.items():
                entity_boost = entity_boosts.get(doc_idx, 0.0)
                relationship_boost = relationship_boosts.get(doc_idx, 0.0)

                raw_enhancement = (entity_boost + relationship_boost) * graph_weight
                scaled_enhancement = raw_enhancement * actual_scale_factor

                # Keep scores within 1.0 for downstream compatibility.
                enhanced_scores[doc_idx] = min(base_score + scaled_enhancement, 1.0)

                if scaled_enhancement > 0:
                    enhancements_applied += 1
                    if entity_boost > 0:
                        self.stats["entity_boosts_applied"] += 1
                    if relationship_boost > 0:
                        self.stats["relationship_boosts_applied"] += 1

            logger.debug("Applied enhancements to %d documents", enhancements_applied)

            enhanced_results = sorted(
                enhanced_scores.items(), key=lambda item: item[1], reverse=True
            )

            # Sorted descending, so the extremes sit at the two ends.
            final_max = enhanced_results[0][1]
            final_min = enhanced_results[-1][1]
            final_range = final_max - final_min
            discrimination_improvement = final_range / base_range if base_range > 0 else 1.0
            logger.debug(
                "Final score range: %.6f - %.6f (spread: %.6f)",
                final_min, final_max, final_range,
            )
            logger.debug("Discrimination improvement: %.2fx", discrimination_improvement)

            return enhanced_results

        except Exception as e:
            logger.error("Graph enhancement failed: %s", e, exc_info=True)
            self.stats["fallback_activations"] += 1
            return base_results

    def _cache_store(self, cache: Dict[Any, Any], key: Any, value: Any) -> None:
        """Insert ``value`` into ``cache``, flushing the cache first if it is full."""
        if len(cache) >= self._MAX_CACHE_ENTRIES:
            cache.clear()
        cache[key] = value

    def _calculate_entity_boosts(self, doc_indices: List[int]) -> Dict[int, float]:
        """
        Calculate entity-based scoring boosts for documents.

        Entities are extracted from the current query and from each
        document's ``content`` via ``_extract_entities``. A document's boost
        is ``(shared entities / query entities) * entity_boost``.

        Args:
            doc_indices: List of document indices to analyze

        Returns:
            Dictionary mapping doc_index to entity boost score. Indices
            outside ``self.documents`` get 0.0; on any error every index
            gets 0.0.
        """
        zero_boosts = {doc_idx: 0.0 for doc_idx in doc_indices}
        try:
            entity_boost_value = self.graph_config.get("entity_boost", 0.15)

            # Extract query entities once per query.
            query = getattr(self, "current_query", "")
            query_cache_key = f"query_entities:{query}"
            query_entities = self.query_cache.get(query_cache_key)
            if query_entities is None:
                query_entities = self._extract_entities(query)
                self.query_cache[query_cache_key] = query_entities

            if not query_entities:
                return zero_boosts

            # Hoisted out of the loop; the frozenset itself is the key so
            # two different entity sets can never collide.
            query_signature = frozenset(query_entities)
            doc_count = len(self.documents)

            entity_boosts = {}
            for doc_idx in doc_indices:
                cache_key = (doc_idx, query_signature)
                overlap_ratio = self.entity_cache.get(cache_key)

                if overlap_ratio is None:
                    overlap_ratio = 0.0
                    # Reject negative indices too: they would silently
                    # address documents from the end of the list.
                    if 0 <= doc_idx < doc_count:
                        doc_entities = self._extract_entities(
                            self.documents[doc_idx].content
                        )
                        if doc_entities:
                            overlap = len(query_entities & doc_entities)
                            overlap_ratio = overlap / len(query_entities)
                    # Cache the ratio, not the boost, so a changed
                    # entity_boost setting takes effect immediately.
                    self._cache_store(self.entity_cache, cache_key, overlap_ratio)

                entity_boosts[doc_idx] = overlap_ratio * entity_boost_value

            return entity_boosts

        except Exception as e:
            logger.warning("Entity boost calculation failed: %s", e)
            return zero_boosts

    def _extract_entities(self, text: str) -> set:
        """
        Extract entities from text using spaCy or fallback to keyword matching.

        With spaCy: named entities of selected labels plus tokens that look
        like technical terms (all caps, mixed case, or containing a known
        technical substring). Without spaCy: regex patterns for acronyms and
        CamelCase plus a fixed technical vocabulary.

        Args:
            text: Text to extract entities from

        Returns:
            Set of entity strings (normalized to lowercase); empty for empty
            input or if extraction fails.
        """
        import re

        if not text:
            return set()

        entities = set()
        try:
            if self.nlp and NLP_MODEL_AVAILABLE:
                doc = self.nlp(text)

                entities.update(
                    ent.text.lower().strip()
                    for ent in doc.ents
                    if ent.label_ in self._ENTITY_LABELS
                )

                for token in doc:
                    raw = token.text
                    lowered = raw.lower()
                    is_acronym = raw.isupper() and len(raw) >= 2
                    # raw[:1] instead of raw[0] so an empty token cannot raise.
                    is_mixed_case = raw[:1].isupper() and any(
                        ch.isupper() for ch in raw[1:]
                    )
                    has_tech_substring = any(
                        pattern in lowered for pattern in self._TECH_SUBSTRINGS
                    )
                    if is_acronym or is_mixed_case or has_tech_substring:
                        entities.add(lowered.strip())
            else:
                for pattern in self._TECH_PATTERNS:
                    entities.update(
                        match.lower() for match in re.findall(pattern, text)
                    )

                # Tokenise on word characters so punctuation glued to a word
                # ("cache.", "cpu,") does not hide a vocabulary hit.
                words = set(re.findall(r"\w+", text.lower()))
                entities.update(self._TECH_TERMS & words)

            entities.discard("")

        except Exception as e:
            logger.warning("Entity extraction failed: %s", e)

        return entities

    @staticmethod
    def _centrality_scores(embeddings: List[Any], similarity_threshold: float) -> List[float]:
        """Return per-row centrality in [0, 1] for the cosine-similarity graph of ``embeddings``."""
        matrix = np.asarray(embeddings, dtype=float)
        if matrix.ndim != 2:
            raise ValueError("embeddings must be equal-length 1-D vectors")

        # Unit-normalise rows; zero vectors keep norm 1 to avoid dividing by zero.
        norms = np.linalg.norm(matrix, axis=1, keepdims=True)
        norms[norms == 0] = 1
        unit = matrix / norms
        similarity = np.dot(unit, unit.T)

        # A document is always perfectly similar to itself; exclude the
        # diagonal so only links to *other* documents count.
        row_count = similarity.shape[0]
        others = ~np.eye(row_count, dtype=bool)

        strong_connections = ((similarity > similarity_threshold) & others).sum(axis=1)
        avg_similarity = np.where(others, similarity, 0.0).sum(axis=1) / (row_count - 1)

        # Negative average similarity must not turn into a penalty.
        centrality = np.clip(
            strong_connections * 0.6 + avg_similarity * 0.4, 0.0, None
        )
        peak = centrality.max()
        if peak > 0:
            centrality = centrality / peak
        return centrality.tolist()

    def _calculate_relationship_boosts(self, doc_indices: List[int]) -> Dict[int, float]:
        """
        Calculate relationship-based scoring boosts using semantic similarity.

        Builds a cosine-similarity graph over the candidate documents that
        have an ``embedding`` and boosts each one by its normalised
        centrality (strong connections above ``similarity_threshold`` plus
        average similarity to the other candidates) times
        ``relationship_boost``.

        Args:
            doc_indices: List of document indices to analyze

        Returns:
            Dictionary mapping doc_index to relationship boost score.
            Documents without an embedding get 0.0; with fewer than two
            usable embeddings, or on any error, every index gets 0.0.
        """
        relationship_boosts = {doc_idx: 0.0 for doc_idx in doc_indices}

        # Need at least 2 documents for relationship analysis.
        if len(doc_indices) < 2:
            return relationship_boosts

        try:
            relationship_boost_value = self.graph_config.get("relationship_boost", 0.1)
            similarity_threshold = self.graph_config.get("similarity_threshold", 0.7)

            doc_count = len(self.documents)
            valid_indices = []
            doc_embeddings = []
            for doc_idx in doc_indices:
                if not 0 <= doc_idx < doc_count:
                    continue
                # The attribute may exist but be None (no embedding computed).
                embedding = getattr(self.documents[doc_idx], "embedding", None)
                if embedding is None:
                    continue
                valid_indices.append(doc_idx)
                doc_embeddings.append(embedding)

            if len(valid_indices) < 2:
                return relationship_boosts

            # Centrality depends on the exact set of documents, so the whole
            # set (plus the threshold) is the cache key, and the cache is
            # consulted before any matrix work.
            cache_key = (frozenset(valid_indices), similarity_threshold)
            centrality_by_doc = self.relationship_cache.get(cache_key)
            if centrality_by_doc is None:
                scores = self._centrality_scores(doc_embeddings, similarity_threshold)
                centrality_by_doc = dict(zip(valid_indices, scores))
                self._cache_store(self.relationship_cache, cache_key, centrality_by_doc)

            for doc_idx, centrality in centrality_by_doc.items():
                relationship_boosts[doc_idx] = centrality * relationship_boost_value

            return relationship_boosts

        except Exception as e:
            logger.warning("Relationship boost calculation failed: %s", e)
            return {doc_idx: 0.0 for doc_idx in doc_indices}

    def _update_stats(self, graph_latency_ms: float) -> None:
        """Accumulate graph latency and refresh the per-enhanced-fusion average."""
        self.stats["total_graph_latency_ms"] += graph_latency_ms

        enhanced_count = self.stats["graph_enhanced_fusions"]
        if enhanced_count > 0:
            self.stats["avg_graph_latency_ms"] = (
                self.stats["total_graph_latency_ms"] / enhanced_count
            )

    def get_strategy_info(self) -> Dict[str, Any]:
        """
        Get information about the graph-enhanced fusion strategy.

        Returns:
            Dictionary with the base strategy info, graph settings, a copy
            of the statistics and, once fusions have run, the enhancement
            rate and average graph latency.
        """
        base_info = self.base_fusion.get_strategy_info()

        enhanced_info = {
            "type": "graph_enhanced_rrf",
            "base_strategy": base_info,
            "graph_enabled": self.graph_config.get("enabled", True),
            "graph_weight": self.graph_config.get("graph_weight", 0.1),
            "entity_boost": self.graph_config.get("entity_boost", 0.15),
            "relationship_boost": self.graph_config.get("relationship_boost", 0.1),
            "statistics": self.stats.copy(),
        }

        # Derived metrics only make sense once something has run.
        if self.stats["total_fusions"] > 0:
            enhanced_info["graph_enhancement_rate"] = (
                self.stats["graph_enhanced_fusions"] / self.stats["total_fusions"]
            )

        if self.stats["graph_enhanced_fusions"] > 0:
            enhanced_info["avg_graph_latency_ms"] = self.stats["avg_graph_latency_ms"]

        return enhanced_info

    def enable_graph_enhancement(self) -> None:
        """Enable graph enhancement features."""
        self.graph_config["enabled"] = True
        logger.info("Graph enhancement enabled")

    def disable_graph_enhancement(self) -> None:
        """Disable graph enhancement features."""
        self.graph_config["enabled"] = False
        logger.info("Graph enhancement disabled")

    def clear_caches(self) -> None:
        """Clear the entity, relationship and per-query caches."""
        self.entity_cache.clear()
        self.relationship_cache.clear()
        self.query_cache.clear()
        logger.info("Graph enhancement caches cleared")

    def get_performance_stats(self) -> Dict[str, Any]:
        """
        Get detailed performance statistics.

        Returns:
            Dictionary with the running statistics, current cache sizes and
            the base fusion strategy info.
        """
        return {
            **self.stats,
            "cache_sizes": {
                "entity_cache": len(self.entity_cache),
                "relationship_cache": len(self.relationship_cache),
            },
            "base_fusion_stats": self.base_fusion.get_strategy_info(),
        }