Spaces:
Running
Running
""" | |
Hybrid search engine combining vector similarity and BM25 keyword search. | |
""" | |
import re | |
import time | |
from typing import Any, Dict, List, Optional, Tuple, Set | |
import numpy as np | |
from rank_bm25 import BM25Okapi | |
from collections import defaultdict, Counter | |
import string | |
from .error_handler import SearchError | |
from .vector_store import VectorStore | |
from .document_processor import DocumentChunk | |
class SearchResult:
    """A single search hit together with the scores that produced it.

    Attributes:
        chunk_id: Identifier of the matched chunk.
        content: Text content of the matched chunk.
        metadata: Metadata dictionary associated with the chunk.
        vector_score: Score contributed by vector similarity search.
        bm25_score: Score contributed by BM25 keyword search.
        final_score: Score used for the final ordering.
        rank: Position of the result in the result list (0 when unset).
    """

    def __init__(
        self,
        chunk_id: str,
        content: str,
        metadata: Dict[str, Any],
        vector_score: float = 0.0,
        bm25_score: float = 0.0,
        final_score: float = 0.0,
        rank: int = 0,
    ):
        self.chunk_id = chunk_id
        self.content = content
        self.metadata = metadata
        self.vector_score = vector_score
        self.bm25_score = bm25_score
        self.final_score = final_score
        self.rank = rank

    def __repr__(self) -> str:
        """Return a compact debug representation (content/metadata omitted)."""
        return (
            f"{type(self).__name__}(chunk_id={self.chunk_id!r}, rank={self.rank!r}, "
            f"vector_score={self.vector_score!r}, bm25_score={self.bm25_score!r}, "
            f"final_score={self.final_score!r})"
        )

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary representation.

        Returns:
            A dict with the keys ``chunk_id``, ``content``, ``metadata``,
            ``scores`` (holding the three score values) and ``rank``. The
            ``metadata`` value is the same dict object held by this result,
            not a copy.
        """
        return {
            "chunk_id": self.chunk_id,
            "content": self.content,
            "metadata": self.metadata,
            "scores": {
                "vector_score": self.vector_score,
                "bm25_score": self.bm25_score,
                "final_score": self.final_score,
            },
            "rank": self.rank,
        }
class HybridSearchEngine:
    """Hybrid search engine combining vector similarity and BM25 keyword search.

    Vector search is delegated to the injected vector store and an embedding
    manager supplied through ``set_embedding_manager``. Keyword search uses a
    ``BM25Okapi`` index built by ``build_bm25_index``. In hybrid mode the
    normalized scores of both methods are combined as a weighted sum.
    """

    # Token stored for chunks that yield no searchable terms. It contains
    # characters the tokenizer strips ("<" and ">"), so no query term can
    # ever be equal to it.
    _PLACEHOLDER_TOKEN = "<empty>"

    def __init__(self, config: Dict[str, Any], vector_store: "VectorStore"):
        """Create the engine.

        Args:
            config: Application configuration. Only the ``"search"`` section
                is read here (``default_k``, ``max_k``, ``vector_weight``,
                ``bm25_weight``); missing keys fall back to the defaults below.
            vector_store: Store used for vector search and for looking up
                chunk data by id.
        """
        self.config = config
        self.search_config = config.get("search", {})
        self.vector_store = vector_store

        # Search parameters
        self.default_k = self.search_config.get("default_k", 10)
        self.max_k = self.search_config.get("max_k", 20)
        self.vector_weight = self.search_config.get("vector_weight", 0.7)
        self.bm25_weight = self.search_config.get("bm25_weight", 0.3)

        # BM25 setup
        self.bm25_index: Optional[BM25Okapi] = None
        self.bm25_corpus: List[List[str]] = []
        self.chunk_id_to_index: Dict[str, int] = {}
        self.index_to_chunk_id: Dict[int, str] = {}
        self._bm25_built = False

        # Injected later through set_embedding_manager(); required for
        # vector search.
        self._embedding_manager = None

        # Query processing
        self.stopwords = self._load_stopwords()

        # Statistics
        self.stats = {
            "searches_performed": 0,
            "total_search_time": 0,
            "vector_searches": 0,
            "bm25_searches": 0,
            "hybrid_searches": 0,
            "avg_results_returned": 0,
            "bm25_index_size": 0,
        }

    def _load_stopwords(self) -> Set[str]:
        """Load common English stopwords."""
        # Basic English stopwords - could be enhanced with NLTK
        return {
            'a', 'an', 'and', 'are', 'as', 'at', 'be', 'by', 'for', 'from',
            'has', 'he', 'in', 'is', 'it', 'its', 'of', 'on', 'that', 'the',
            'to', 'was', 'will', 'with', 'had', 'have', 'this', 'these', 'they',
            'been', 'their', 'said', 'each', 'which', 'she', 'do', 'how', 'her',
            'my', 'me', 'we', 'us', 'our', 'you', 'your', 'him', 'his', 'all',
        }

    def _reset_bm25_index(self) -> None:
        """Drop the BM25 index together with its mappings and size statistic."""
        self.bm25_index = None
        self.bm25_corpus = []
        self.chunk_id_to_index = {}
        self.index_to_chunk_id = {}
        self._bm25_built = False
        self.stats["bm25_index_size"] = 0

    def build_bm25_index(self, chunks: "List[DocumentChunk]") -> None:
        """Build BM25 index from document chunks.

        The new index is assembled in local variables and only committed to
        the instance once ``BM25Okapi`` has been constructed, so a failure
        leaves any previously built index untouched.

        Args:
            chunks: Chunks to index; each must expose ``content`` and
                ``chunk_id``. An empty list clears the index.

        Raises:
            SearchError: If tokenizing or index construction fails.
        """
        if not chunks:
            self._reset_bm25_index()
            return

        try:
            print(f"Building BM25 index for {len(chunks)} chunks...")
            start_time = time.time()

            corpus: List[List[str]] = []
            chunk_id_to_index: Dict[str, int] = {}
            index_to_chunk_id: Dict[int, str] = {}

            for position, chunk in enumerate(chunks):
                # _tokenize_text never returns an empty list, so every
                # document in the corpus has at least one token.
                corpus.append(self._tokenize_text(chunk.content))
                chunk_id_to_index[chunk.chunk_id] = position
                index_to_chunk_id[position] = chunk.chunk_id

            bm25_index = BM25Okapi(corpus)
        except Exception as e:
            raise SearchError(f"Failed to build BM25 index: {str(e)}") from e

        # Commit the fully built index in one go.
        self.bm25_corpus = corpus
        self.chunk_id_to_index = chunk_id_to_index
        self.index_to_chunk_id = index_to_chunk_id
        self.bm25_index = bm25_index
        self._bm25_built = True
        self.stats["bm25_index_size"] = len(corpus)

        build_time = time.time() - start_time
        print(f"BM25 index built in {build_time:.2f}s with {len(corpus)} documents")

    def _extract_terms(self, text: str) -> List[str]:
        """Return the searchable terms of ``text``; may be an empty list.

        The text is lower-cased, every character that is neither a word
        character nor whitespace is replaced by a space, and the result is
        split on whitespace. Tokens of length two or less and stopwords are
        discarded.
        """
        if not text or not text.strip():
            return []

        words = re.sub(r'[^\w\s]', ' ', text.lower()).split()
        return [
            word for word in words
            if len(word) > 2 and word not in self.stopwords
        ]

    def _tokenize_text(self, text: str) -> List[str]:
        """Tokenize text for BM25 indexing.

        Returns the same terms as ``_extract_terms`` but never an empty list:
        text without any searchable term is represented by a single
        placeholder token. The original implementation notes that empty
        token lists cause a division by zero in BM25. The placeholder cannot
        be produced from real text, so it never matches a query.
        """
        return self._extract_terms(text) or [self._PLACEHOLDER_TOKEN]

    @staticmethod
    def _resolve_weight(override: Optional[float], default: float) -> float:
        """Return ``override`` unless it is None (an explicit 0.0 is honoured)."""
        return default if override is None else override

    def search(
        self,
        query: str,
        k: Optional[int] = None,
        search_mode: str = "hybrid",
        metadata_filter: Optional[Dict[str, Any]] = None,
        vector_weight: Optional[float] = None,
        bm25_weight: Optional[float] = None,
    ) -> "List[SearchResult]":
        """
        Perform search using specified mode.

        Args:
            query: Search query
            k: Number of results to return; defaults to ``default_k`` and is
                capped at ``max_k``
            search_mode: "vector", "bm25", or "hybrid"
            metadata_filter: Optional metadata filter
            vector_weight: Weight for vector scores (hybrid mode); None uses
                the configured weight
            bm25_weight: Weight for BM25 scores (hybrid mode); None uses the
                configured weight

        Returns:
            List of SearchResult objects. Empty for a blank query or a
            non-positive ``k``; neither case is counted in the statistics.

        Raises:
            SearchError: For an unknown search mode or any failure inside
                the selected search.
        """
        start_time = time.time()

        k = self.default_k if k is None else k
        k = min(k, self.max_k)
        # A non-positive k would otherwise reach slices such as results[:k],
        # where a negative value silently drops entries from the end.
        if k <= 0:
            return []

        if not query or not query.strip():
            return []
        query = query.strip()

        try:
            if search_mode == "vector":
                results = self._vector_search(query, k, metadata_filter)
                self.stats["vector_searches"] += 1
            elif search_mode == "bm25":
                results = self._bm25_search(query, k, metadata_filter)
                self.stats["bm25_searches"] += 1
            elif search_mode == "hybrid":
                results = self._hybrid_search(
                    query, k, metadata_filter,
                    self._resolve_weight(vector_weight, self.vector_weight),
                    self._resolve_weight(bm25_weight, self.bm25_weight),
                )
                self.stats["hybrid_searches"] += 1
            else:
                raise SearchError(f"Unknown search mode: {search_mode}")

            # Update statistics (running mean of results per search)
            search_time = time.time() - start_time
            self.stats["searches_performed"] += 1
            self.stats["total_search_time"] += search_time
            performed = self.stats["searches_performed"]
            self.stats["avg_results_returned"] = (
                (self.stats["avg_results_returned"] * (performed - 1) + len(results))
                / performed
            )
            return results
        except SearchError:
            raise
        except Exception as e:
            raise SearchError(f"Search failed: {str(e)}") from e

    def _vector_search(
        self,
        query: str,
        k: int,
        metadata_filter: Optional[Dict[str, Any]],
    ) -> "List[SearchResult]":
        """Perform vector similarity search.

        The query is embedded with the injected embedding manager and looked
        up in the vector store (``k*2`` candidates are requested, the first
        ``k`` are kept). Scores are normalized within the returned results.

        Raises:
            SearchError: If no embedding manager has been set.
        """
        # Get embedding manager that was injected via set_embedding_manager
        embedding_manager = getattr(self, '_embedding_manager', None)
        if embedding_manager is None:
            raise SearchError("Embedding manager not available for vector search")

        # Generate query embedding
        query_embeddings = embedding_manager.generate_embeddings([query], show_progress=False)
        if query_embeddings.size == 0:
            return []
        query_embedding = query_embeddings[0]

        # Search vector store
        vector_results = self.vector_store.search(
            query_embedding, k=k*2, metadata_filter=metadata_filter
        )

        # Convert to SearchResult objects
        results = []
        for i, (chunk_id, similarity, metadata) in enumerate(vector_results[:k]):
            content = metadata.get("content", "")
            # Debug: Log content info
            if i < 3:  # Only log first 3 results to avoid spam
                content_preview = content[:100] + "..." if len(content) > 100 else content
                print(f"Vector Result {i}: chunk_id={chunk_id}, content_length={len(content)}, preview='{content_preview}'")
            results.append(SearchResult(
                chunk_id=chunk_id,
                content=content,
                metadata=metadata,
                vector_score=similarity,
                bm25_score=0.0,
                final_score=0.0,  # Will be calculated after normalization
                rank=i + 1,
            ))

        # Normalize scores and calculate final scores for vector-only mode
        if results:
            self._normalize_scores(results)
            for result in results:
                result.final_score = result.vector_score  # For vector-only, final = vector
        return results

    def _bm25_search(
        self,
        query: str,
        k: int,
        metadata_filter: Optional[Dict[str, Any]],
    ) -> "List[SearchResult]":
        """Perform BM25 keyword search.

        Candidates are visited in descending score order until ``k`` results
        pass the metadata filter or the scores stop being positive. Scores
        are normalized within the returned results.

        Raises:
            SearchError: If the BM25 index has not been built.
        """
        if not self._bm25_built or self.bm25_index is None:
            raise SearchError("BM25 index not built. Please add documents first.")

        # A query without searchable terms (blank, stopwords only) matches
        # nothing; it must not be replaced by the indexing placeholder.
        query_tokens = self._extract_terms(query)
        if not query_tokens:
            return []

        # Get BM25 scores
        scores = self.bm25_index.get_scores(query_tokens)

        # Walk the complete ranking instead of a fixed-size prefix so that a
        # selective metadata filter cannot hide matches ranked further down.
        # The loop still stops early at k results or the first score <= 0.
        ranked_indices = np.argsort(scores)[::-1]

        results = []
        for raw_idx in ranked_indices:
            if len(results) >= k:
                break
            idx = int(raw_idx)
            if idx not in self.index_to_chunk_id:
                continue
            chunk_id = self.index_to_chunk_id[idx]
            score = float(scores[idx])
            if score <= 0:
                break

            # Get chunk data from vector store
            chunk_data = self.vector_store.get_by_id(chunk_id)
            if chunk_data is None:
                continue
            _, metadata = chunk_data
            content = metadata.get("content", "")

            # Apply metadata filter
            if metadata_filter and not self._matches_filter(metadata, metadata_filter):
                continue

            results.append(SearchResult(
                chunk_id=chunk_id,
                content=content,
                metadata=metadata,
                vector_score=0.0,
                bm25_score=score,
                final_score=0.0,  # Will be calculated after normalization
                rank=len(results) + 1,
            ))

        # Normalize scores and calculate final scores for BM25-only mode
        if results:
            self._normalize_scores(results)
            for result in results:
                result.final_score = result.bm25_score  # For BM25-only, final = bm25
        return results

    def _hybrid_search(
        self,
        query: str,
        k: int,
        metadata_filter: Optional[Dict[str, Any]],
        vector_weight: float,
        bm25_weight: float,
    ) -> "List[SearchResult]":
        """Perform hybrid search combining vector and BM25 results.

        Both searches are asked for ``k*2`` candidates. A failure of either
        one is printed and treated as "no results" so the other method can
        still answer. Results are merged by chunk id; a chunk missing from
        one method gets a score of 0.0 for that method.
        """
        # Get results from both methods (best effort for each)
        try:
            vector_results = self._vector_search(query, k*2, metadata_filter)
        except Exception as e:
            print(f"Vector search failed: {e}")
            vector_results = []
        try:
            bm25_results = self._bm25_search(query, k*2, metadata_filter)
        except Exception as e:
            print(f"BM25 search failed: {e}")
            bm25_results = []

        if not vector_results and not bm25_results:
            return []

        # Combine results by chunk_id
        combined_results: Dict[str, Any] = {}

        # Add vector results
        for result in vector_results:
            combined_results[result.chunk_id] = SearchResult(
                chunk_id=result.chunk_id,
                content=result.content,
                metadata=result.metadata,
                vector_score=result.vector_score,
                bm25_score=0.0,
                final_score=0.0,
                rank=0,
            )

        # Add/merge BM25 results
        for result in bm25_results:
            existing = combined_results.get(result.chunk_id)
            if existing is not None:
                existing.bm25_score = result.bm25_score
            else:
                combined_results[result.chunk_id] = SearchResult(
                    chunk_id=result.chunk_id,
                    content=result.content,
                    metadata=result.metadata,
                    vector_score=0.0,
                    bm25_score=result.bm25_score,
                    final_score=0.0,
                    rank=0,
                )

        # Both score sets were already normalized by their own search.
        # Normalizing the merged list again would rescale BM25 scores over
        # the positive values only and collapse the second-lowest one to 0,
        # so the scores are combined as they are.
        for result in combined_results.values():
            result.final_score = (
                vector_weight * result.vector_score +
                bm25_weight * result.bm25_score
            )

        # Sort by final score and return top k
        sorted_results = sorted(
            combined_results.values(),
            key=lambda x: x.final_score,
            reverse=True,
        )

        # Update ranks
        for position, result in enumerate(sorted_results, start=1):
            result.rank = position
        return sorted_results[:k]

    def _normalize_scores(self, results: "List[SearchResult]") -> None:
        """Normalize vector and BM25 scores to 0-1 range, in place.

        Vector scores are min-max scaled over all results. BM25 scores are
        min-max scaled over the positive scores only; results with a BM25
        score of 0 or less are left as they are. When all considered scores
        of one kind are equal (and non-zero) they are set to 0.5.
        """
        if not results:
            return

        # Normalize vector scores (handle negative scores like cosine similarity)
        vector_scores = [r.vector_score for r in results]
        min_vector = min(vector_scores)
        max_vector = max(vector_scores)
        if max_vector > min_vector:
            span = max_vector - min_vector
            for result in results:
                result.vector_score = (result.vector_score - min_vector) / span
        elif max_vector == min_vector and max_vector != 0:
            # All scores are the same, normalize to 0.5
            for result in results:
                result.vector_score = 0.5

        # Normalize BM25 scores (these should be positive)
        bm25_scores = [r.bm25_score for r in results if r.bm25_score > 0]
        if bm25_scores:
            min_bm25 = min(bm25_scores)
            max_bm25 = max(bm25_scores)
            if max_bm25 > min_bm25:
                span = max_bm25 - min_bm25
                for result in results:
                    if result.bm25_score > 0:
                        result.bm25_score = (result.bm25_score - min_bm25) / span
            else:
                # All positive scores are equal (e.g. a single hit): use the
                # same 0.5 convention as for vector scores rather than
                # leaving a raw, unbounded BM25 score behind.
                for result in results:
                    if result.bm25_score > 0:
                        result.bm25_score = 0.5

    def _matches_filter(self, metadata: Dict[str, Any], filter_dict: Dict[str, Any]) -> bool:
        """Check if metadata matches filter (same as vector_store implementation).

        Every key of ``filter_dict`` must be present in ``metadata``. A dict
        value is read as operators (``$gte``, ``$lte``, ``$in``), a list
        value as a set of allowed values, anything else as an exact match.
        """
        for key, value in filter_dict.items():
            if key not in metadata:
                return False
            metadata_value = metadata[key]
            if isinstance(value, dict):
                if "$gte" in value and metadata_value < value["$gte"]:
                    return False
                if "$lte" in value and metadata_value > value["$lte"]:
                    return False
                if "$in" in value and metadata_value not in value["$in"]:
                    return False
            elif isinstance(value, list):
                if metadata_value not in value:
                    return False
            else:
                if metadata_value != value:
                    return False
        return True

    def suggest_query_expansion(self, query: str, top_results: "List[SearchResult]") -> List[str]:
        """Suggest query expansion terms based on top results.

        Terms are taken from the content of the first three results. Of the
        ten most frequent terms, those that are not query terms and are
        longer than three characters are kept; at most five are returned.
        """
        if not top_results:
            return []

        # Extract terms from top results (no placeholder tokens here, so a
        # placeholder can never be offered as a suggestion)
        all_text = " ".join([result.content for result in top_results[:3]])
        term_counts = Counter(self._extract_terms(all_text))

        # Filter out query terms and get most frequent
        query_tokens = set(self._extract_terms(query))
        suggestions = [
            term for term, _ in term_counts.most_common(10)
            if term not in query_tokens and len(term) > 3
        ]
        return suggestions[:5]

    def get_stats(self) -> Dict[str, Any]:
        """Get search engine statistics.

        Returns a copy of the counters extended with ``avg_search_time``,
        ``bm25_index_built`` and the vector store's own statistics.
        """
        stats = self.stats.copy()
        if stats["searches_performed"] > 0:
            stats["avg_search_time"] = stats["total_search_time"] / stats["searches_performed"]
        else:
            stats["avg_search_time"] = 0
        stats["bm25_index_built"] = self._bm25_built
        stats["vector_store_stats"] = self.vector_store.get_stats()
        return stats

    def set_embedding_manager(self, embedding_manager) -> None:
        """Set the embedding manager for vector search."""
        self._embedding_manager = embedding_manager

    def optimize_index(self) -> Dict[str, Any]:
        """Optimize search indices.

        Calls ``optimize()`` on the vector store when one is set and reports
        the current BM25 corpus size and build state; the BM25 index itself
        is not modified.
        """
        optimization_results = {}

        # Optimize vector store
        if self.vector_store:
            vector_opt = self.vector_store.optimize()
            optimization_results["vector_store"] = vector_opt

        # Could add BM25 optimization here
        optimization_results["bm25_index"] = {
            "corpus_size": len(self.bm25_corpus),
            "index_built": self._bm25_built,
        }
        return optimization_results