|
""" |
|
Hybrid search engine combining vector similarity and BM25 keyword search. |
|
""" |
|
|
|
import re |
|
import time |
|
from typing import Any, Dict, List, Optional, Tuple, Set |
|
import numpy as np |
|
from rank_bm25 import BM25Okapi |
|
from collections import defaultdict, Counter |
|
import string |
|
|
|
from .error_handler import SearchError |
|
from .vector_store import VectorStore |
|
from .document_processor import DocumentChunk |
|
|
|
|
|
class SearchResult:
    """One retrieved chunk plus the scores that were used to rank it.

    Instances are plain mutable records: the search engine fills in
    ``vector_score``, ``bm25_score``, ``final_score`` and ``rank`` after
    construction as ranking progresses.
    """

    def __init__(
        self,
        chunk_id: str,
        content: str,
        metadata: Dict[str, Any],
        vector_score: float = 0.0,
        bm25_score: float = 0.0,
        final_score: float = 0.0,
        rank: int = 0
    ):
        # What was found.
        self.chunk_id, self.content, self.metadata = chunk_id, content, metadata
        # How well it matched, per signal and combined.
        self.vector_score, self.bm25_score, self.final_score = (
            vector_score,
            bm25_score,
            final_score,
        )
        # Position in the result list (0 means "not ranked yet").
        self.rank = rank

    def to_dict(self) -> Dict[str, Any]:
        """Return a dictionary view with the three scores nested under ``"scores"``.

        The ``metadata`` value is the same object held by the instance, not
        a copy.
        """
        score_breakdown = dict(
            vector_score=self.vector_score,
            bm25_score=self.bm25_score,
            final_score=self.final_score,
        )
        return dict(
            chunk_id=self.chunk_id,
            content=self.content,
            metadata=self.metadata,
            scores=score_breakdown,
            rank=self.rank,
        )
|
|
|
|
|
class HybridSearchEngine:
    """Hybrid search engine combining vector similarity and BM25 keyword search.

    Vector search is delegated to the ``vector_store`` passed to the
    constructor and needs an embedding manager supplied through
    :meth:`set_embedding_manager`. Keyword search uses an in-memory
    ``BM25Okapi`` index built by :meth:`build_bm25_index`. Hybrid mode runs
    both and ranks by a weighted sum of the two normalised scores.
    """

    def __init__(self, config: Dict[str, Any], vector_store: "VectorStore"):
        """Create an engine bound to ``vector_store``.

        Args:
            config: Application configuration. Only the optional ``"search"``
                sub-dictionary is read here (``default_k``, ``max_k``,
                ``vector_weight``, ``bm25_weight``).
            vector_store: Store used for similarity search and for looking
                up chunk metadata by id.
        """
        self.config = config
        self.search_config = config.get("search", {})
        self.vector_store = vector_store

        # Result-count limits and default hybrid weights.
        self.default_k = self.search_config.get("default_k", 10)
        self.max_k = self.search_config.get("max_k", 20)
        self.vector_weight = self.search_config.get("vector_weight", 0.7)
        self.bm25_weight = self.search_config.get("bm25_weight", 0.3)

        # BM25 state; populated by build_bm25_index().
        self.bm25_index: Optional[BM25Okapi] = None
        self.bm25_corpus: List[List[str]] = []
        self.chunk_id_to_index: Dict[str, int] = {}
        self.index_to_chunk_id: Dict[int, str] = {}
        self._bm25_built = False

        # Injected later via set_embedding_manager(); vector search raises
        # SearchError while this is still None.
        self._embedding_manager = None

        self.stopwords = self._load_stopwords()

        self.stats = {
            "searches_performed": 0,
            "total_search_time": 0,
            "vector_searches": 0,
            "bm25_searches": 0,
            "hybrid_searches": 0,
            "avg_results_returned": 0,
            "bm25_index_size": 0
        }

    def _load_stopwords(self) -> Set[str]:
        """Return the built-in set of common English stopwords."""
        return {
            'a', 'an', 'and', 'are', 'as', 'at', 'be', 'by', 'for', 'from',
            'has', 'he', 'in', 'is', 'it', 'its', 'of', 'on', 'that', 'the',
            'to', 'was', 'will', 'with', 'had', 'have', 'this', 'these', 'they',
            'been', 'their', 'said', 'each', 'which', 'she', 'do', 'how', 'her',
            'my', 'me', 'we', 'us', 'our', 'you', 'your', 'him', 'his', 'all'
        }

    def _reset_bm25_index(self) -> None:
        """Drop the BM25 index and every structure derived from it."""
        self.bm25_index = None
        self.bm25_corpus = []
        self.chunk_id_to_index = {}
        self.index_to_chunk_id = {}
        self._bm25_built = False
        self.stats["bm25_index_size"] = 0

    def build_bm25_index(self, chunks: List["DocumentChunk"]) -> None:
        """Build the BM25 index from document chunks.

        The new index is assembled in local variables and only published on
        success, so a failure leaves the previously built index (if any)
        fully usable instead of half-replaced.

        Args:
            chunks: Chunks to index; each must expose ``content`` and
                ``chunk_id``. An empty list clears the index.

        Raises:
            SearchError: If tokenisation or index construction fails.
        """
        if not chunks:
            self._reset_bm25_index()
            return

        print(f"Building BM25 index for {len(chunks)} chunks...")
        start_time = time.perf_counter()

        try:
            corpus: List[List[str]] = []
            chunk_id_to_index: Dict[str, int] = {}
            index_to_chunk_id: Dict[int, str] = {}

            for i, chunk in enumerate(chunks):
                # _tokenize_text never returns an empty list in its default
                # mode, so every document contributes at least one token.
                corpus.append(self._tokenize_text(chunk.content))
                chunk_id_to_index[chunk.chunk_id] = i
                index_to_chunk_id[i] = chunk.chunk_id

            index = BM25Okapi(corpus)
        except Exception as e:
            raise SearchError(f"Failed to build BM25 index: {str(e)}") from e

        # Commit everything together now that construction succeeded.
        self.bm25_corpus = corpus
        self.chunk_id_to_index = chunk_id_to_index
        self.index_to_chunk_id = index_to_chunk_id
        self.bm25_index = index
        self._bm25_built = True

        build_time = time.perf_counter() - start_time
        self.stats["bm25_index_size"] = len(self.bm25_corpus)

        print(f"BM25 index built in {build_time:.2f}s with {len(self.bm25_corpus)} documents")

    def _tokenize_text(self, text: str, *, allow_empty: bool = False) -> List[str]:
        """Tokenize text for BM25 indexing or querying.

        Lower-cases the text, replaces every character that is neither a
        word character nor whitespace with a space, splits on whitespace and
        drops tokens of two characters or fewer and stopwords.

        Args:
            text: Raw text.
            allow_empty: When False (the default), blank text yields
                ``["empty"]`` and text with no surviving tokens yields
                ``["content"]``, so an indexed document is never empty.
                When True, those cases yield ``[]``; use this for queries
                so the placeholders are never searched for.

        Returns:
            The list of tokens.
        """
        if not text or not text.strip():
            return [] if allow_empty else ["empty"]

        cleaned = re.sub(r'[^\w\s]', ' ', text.lower())
        tokens = [
            token for token in cleaned.split()
            if len(token) > 2 and token not in self.stopwords
        ]

        if not tokens and not allow_empty:
            return ["content"]

        return tokens

    def search(
        self,
        query: str,
        k: Optional[int] = None,
        search_mode: str = "hybrid",
        metadata_filter: Optional[Dict[str, Any]] = None,
        vector_weight: Optional[float] = None,
        bm25_weight: Optional[float] = None
    ) -> List["SearchResult"]:
        """
        Perform search using specified mode.

        Args:
            query: Search query
            k: Number of results to return; defaults to ``default_k`` and is
                capped at ``max_k``. Values <= 0 return no results.
            search_mode: "vector", "bm25", or "hybrid"
            metadata_filter: Optional metadata filter
            vector_weight: Weight for vector scores (hybrid mode). ``None``
                selects the configured default; ``0.0`` is honoured.
            bm25_weight: Weight for BM25 scores (hybrid mode). ``None``
                selects the configured default; ``0.0`` is honoured.

        Returns:
            List of SearchResult objects

        Raises:
            SearchError: For an unknown ``search_mode`` or when the
                underlying search fails.
        """
        start_time = time.perf_counter()

        k = k if k is not None else self.default_k
        k = min(k, self.max_k)

        # A non-positive k would otherwise reach slices such as [:k] and
        # silently drop or mis-select results.
        if k <= 0:
            return []

        if not query or not query.strip():
            return []

        query = query.strip()

        try:
            if search_mode == "vector":
                results = self._vector_search(query, k, metadata_filter)
                self.stats["vector_searches"] += 1
            elif search_mode == "bm25":
                results = self._bm25_search(query, k, metadata_filter)
                self.stats["bm25_searches"] += 1
            elif search_mode == "hybrid":
                # "is not None" rather than "or": an explicit weight of 0.0
                # must not be replaced by the configured default.
                results = self._hybrid_search(
                    query, k, metadata_filter,
                    vector_weight if vector_weight is not None else self.vector_weight,
                    bm25_weight if bm25_weight is not None else self.bm25_weight
                )
                self.stats["hybrid_searches"] += 1
            else:
                raise SearchError(f"Unknown search mode: {search_mode}")

            search_time = time.perf_counter() - start_time
            self.stats["searches_performed"] += 1
            self.stats["total_search_time"] += search_time
            # Incremental update of the running mean of results per search.
            self.stats["avg_results_returned"] = (
                (self.stats["avg_results_returned"] * (self.stats["searches_performed"] - 1) + len(results))
                / self.stats["searches_performed"]
            )

            return results

        except SearchError:
            raise
        except Exception as e:
            raise SearchError(f"Search failed: {str(e)}") from e

    def _vector_search(
        self,
        query: str,
        k: int,
        metadata_filter: Optional[Dict[str, Any]]
    ) -> List["SearchResult"]:
        """Perform vector similarity search.

        Embeds the query with the injected embedding manager, asks the
        vector store for ``k * 2`` candidates and keeps the first ``k``.
        Each candidate is unpacked as ``(chunk_id, similarity, metadata)``
        and the chunk text is read from ``metadata["content"]``.

        Returns:
            Up to ``k`` results with normalised ``vector_score`` copied into
            ``final_score``.

        Raises:
            SearchError: If no embedding manager has been set.
        """
        embedding_manager = getattr(self, '_embedding_manager', None)
        if embedding_manager is None:
            raise SearchError("Embedding manager not available for vector search")

        query_embeddings = embedding_manager.generate_embeddings([query], show_progress=False)
        if query_embeddings.size == 0:
            return []

        query_embedding = query_embeddings[0]

        vector_results = self.vector_store.search(
            query_embedding, k=k*2, metadata_filter=metadata_filter
        )

        results = []
        for i, (chunk_id, similarity, metadata) in enumerate(vector_results[:k]):
            content = metadata.get("content", "")

            # Diagnostic preview of the first three hits.
            if i < 3:
                content_preview = content[:100] + "..." if len(content) > 100 else content
                print(f"Vector Result {i}: chunk_id={chunk_id}, content_length={len(content)}, preview='{content_preview}'")

            results.append(SearchResult(
                chunk_id=chunk_id,
                content=content,
                metadata=metadata,
                vector_score=similarity,
                bm25_score=0.0,
                final_score=0.0,
                rank=i + 1
            ))

        if results:
            self._normalize_scores(results)
            for result in results:
                result.final_score = result.vector_score

        return results

    def _bm25_search(
        self,
        query: str,
        k: int,
        metadata_filter: Optional[Dict[str, Any]]
    ) -> List["SearchResult"]:
        """Perform BM25 keyword search.

        Candidates are visited in descending score order until ``k`` results
        are collected or a non-positive score is reached. Chunk metadata is
        fetched from the vector store by id; chunks the store does not know
        are skipped.

        Returns:
            Up to ``k`` results with normalised ``bm25_score`` copied into
            ``final_score``. A query with no usable tokens (blank, or only
            stopwords / very short words) returns ``[]``.

        Raises:
            SearchError: If the BM25 index has not been built.
        """
        if not self._bm25_built or self.bm25_index is None:
            raise SearchError("BM25 index not built. Please add documents first.")

        # allow_empty=True: never search for the indexing placeholders.
        query_tokens = self._tokenize_text(query, allow_empty=True)
        if not query_tokens:
            return []

        scores = self.bm25_index.get_scores(query_tokens)
        ranked_indices = np.argsort(scores)[::-1]

        # Without a filter a small over-fetch is enough to cover chunks that
        # are missing from the store. With a filter the whole ranking has to
        # stay available, otherwise matching chunks ranked below the cut-off
        # would be lost; the loop still stops early.
        if not metadata_filter:
            ranked_indices = ranked_indices[:k*3]

        results = []
        for idx in ranked_indices:
            if len(results) >= k:
                break

            score = float(scores[idx])
            if score <= 0:
                # Sorted descending, so nothing further can match.
                break

            chunk_id = self.index_to_chunk_id.get(int(idx))
            if chunk_id is None:
                continue

            chunk_data = self.vector_store.get_by_id(chunk_id)
            if chunk_data is None:
                continue

            _, metadata = chunk_data
            content = metadata.get("content", "")

            if metadata_filter and not self._matches_filter(metadata, metadata_filter):
                continue

            results.append(SearchResult(
                chunk_id=chunk_id,
                content=content,
                metadata=metadata,
                vector_score=0.0,
                bm25_score=score,
                final_score=0.0,
                rank=len(results) + 1
            ))

        if results:
            self._normalize_scores(results)
            for result in results:
                result.final_score = result.bm25_score

        return results

    def _hybrid_search(
        self,
        query: str,
        k: int,
        metadata_filter: Optional[Dict[str, Any]],
        vector_weight: float,
        bm25_weight: float
    ) -> List["SearchResult"]:
        """Perform hybrid search combining vector and BM25 results.

        Each sub-search is asked for ``k * 2`` results and is best-effort:
        if one raises, the failure is printed and the other's results are
        used alone. Results are merged by ``chunk_id`` and ranked by
        ``vector_weight * vector_score + bm25_weight * bm25_score``.

        Returns:
            The top ``k`` merged results with ``final_score`` and ``rank``
            set.
        """
        try:
            vector_results = self._vector_search(query, k*2, metadata_filter)
        except Exception as e:
            print(f"Vector search failed: {e}")
            vector_results = []

        try:
            bm25_results = self._bm25_search(query, k*2, metadata_filter)
        except Exception as e:
            print(f"BM25 search failed: {e}")
            bm25_results = []

        if not vector_results and not bm25_results:
            return []

        combined_results: Dict[str, "SearchResult"] = {}

        for result in vector_results:
            combined_results[result.chunk_id] = SearchResult(
                chunk_id=result.chunk_id,
                content=result.content,
                metadata=result.metadata,
                vector_score=result.vector_score,
                bm25_score=0.0,
                final_score=0.0,
                rank=0
            )

        for result in bm25_results:
            existing = combined_results.get(result.chunk_id)
            if existing is not None:
                existing.bm25_score = result.bm25_score
            else:
                combined_results[result.chunk_id] = SearchResult(
                    chunk_id=result.chunk_id,
                    content=result.content,
                    metadata=result.metadata,
                    vector_score=0.0,
                    bm25_score=result.bm25_score,
                    final_score=0.0,
                    rank=0
                )

        # Both sub-searches already return normalised scores. They are not
        # normalised again here: the BM25 pass ignores zero scores, so a
        # second pass would push the next-lowest keyword match down to zero.
        for result in combined_results.values():
            result.final_score = (
                vector_weight * result.vector_score +
                bm25_weight * result.bm25_score
            )

        sorted_results = sorted(
            combined_results.values(),
            key=lambda x: x.final_score,
            reverse=True
        )

        for i, result in enumerate(sorted_results):
            result.rank = i + 1

        return sorted_results[:k]

    def _normalize_scores(self, results: List["SearchResult"]) -> None:
        """Normalize vector and BM25 scores to 0-1 range, in place.

        Vector scores are min-max scaled over all results. BM25 scores are
        min-max scaled over the positive scores only, so a result with no
        keyword match keeps a score of 0. When every score in a group is the
        same non-zero value (including the single-result case) there is no
        spread to scale by, and each such score becomes 0.5.
        """
        if not results:
            return

        vector_scores = [r.vector_score for r in results]
        min_vector = min(vector_scores)
        max_vector = max(vector_scores)
        if max_vector > min_vector:
            spread = max_vector - min_vector
            for result in results:
                result.vector_score = (result.vector_score - min_vector) / spread
        elif max_vector != 0:
            for result in results:
                result.vector_score = 0.5

        bm25_scores = [r.bm25_score for r in results if r.bm25_score > 0]
        if bm25_scores:
            min_bm25 = min(bm25_scores)
            max_bm25 = max(bm25_scores)
            if max_bm25 > min_bm25:
                spread = max_bm25 - min_bm25
                for result in results:
                    if result.bm25_score > 0:
                        result.bm25_score = (result.bm25_score - min_bm25) / spread
            else:
                # Tied positive scores: use the same midpoint as the vector
                # branch instead of leaking the raw, unbounded BM25 value.
                for result in results:
                    if result.bm25_score > 0:
                        result.bm25_score = 0.5

    def _matches_filter(self, metadata: Dict[str, Any], filter_dict: Dict[str, Any]) -> bool:
        """Check if metadata matches filter (same as vector_store implementation).

        Every key in ``filter_dict`` must be present in ``metadata``. A dict
        value may carry ``$gte``, ``$lte`` and/or ``$in`` constraints, a
        list value means "any of", and any other value must compare equal.
        """
        for key, value in filter_dict.items():
            if key not in metadata:
                return False

            metadata_value = metadata[key]

            if isinstance(value, dict):
                if "$gte" in value and metadata_value < value["$gte"]:
                    return False
                if "$lte" in value and metadata_value > value["$lte"]:
                    return False
                if "$in" in value and metadata_value not in value["$in"]:
                    return False
            elif isinstance(value, list):
                if metadata_value not in value:
                    return False
            elif metadata_value != value:
                return False

        return True

    def suggest_query_expansion(self, query: str, top_results: List["SearchResult"]) -> List[str]:
        """Suggest query expansion terms based on top results.

        Counts tokens in the content of the first three results and returns
        up to five of the ten most frequent terms that are longer than three
        characters and not already in the query.
        """
        if not top_results:
            return []

        all_text = " ".join([result.content for result in top_results[:3]])

        # allow_empty=True keeps the indexing placeholders ("empty",
        # "content") from being suggested or from masking real terms.
        term_counts = Counter(self._tokenize_text(all_text, allow_empty=True))
        query_tokens = set(self._tokenize_text(query, allow_empty=True))

        suggestions = [
            term for term, _ in term_counts.most_common(10)
            if term not in query_tokens and len(term) > 3
        ]

        return suggestions[:5]

    def get_stats(self) -> Dict[str, Any]:
        """Get search engine statistics.

        Returns a copy of the counters extended with ``avg_search_time``,
        ``bm25_index_built`` and the vector store's own statistics.
        """
        stats = self.stats.copy()

        if stats["searches_performed"] > 0:
            stats["avg_search_time"] = stats["total_search_time"] / stats["searches_performed"]
        else:
            stats["avg_search_time"] = 0

        stats["bm25_index_built"] = self._bm25_built
        stats["vector_store_stats"] = self.vector_store.get_stats()

        return stats

    def set_embedding_manager(self, embedding_manager) -> None:
        """Set the embedding manager for vector search."""
        self._embedding_manager = embedding_manager

    def optimize_index(self) -> Dict[str, Any]:
        """Optimize search indices.

        Calls the vector store's ``optimize()`` when a store is attached and
        reports the current BM25 corpus size and build state.
        """
        optimization_results = {}

        # Identity check: a store object must not be skipped merely because
        # it happens to evaluate as falsy.
        if self.vector_store is not None:
            optimization_results["vector_store"] = self.vector_store.optimize()

        optimization_results["bm25_index"] = {
            "corpus_size": len(self.bm25_corpus),
            "index_built": self._bm25_built
        }

        return optimization_results