""" | |
BM25 Sparse Retriever implementation for Modular Retriever Architecture. | |
This module provides a direct implementation of BM25 sparse retrieval | |
extracted from the existing sparse retrieval system for improved modularity. | |
""" | |
import logging | |
import re | |
import time | |
from typing import List, Dict, Any, Tuple, Optional | |
import numpy as np | |
from rank_bm25 import BM25Okapi | |
from src.core.interfaces import Document | |
from .base import SparseRetriever | |
logger = logging.getLogger(__name__) | |


class BM25Retriever(SparseRetriever):
    """
    BM25-based sparse retrieval implementation.

    This is a direct implementation that handles BM25 keyword search
    without external adapters. It provides efficient sparse retrieval
    for technical documentation with optimized tokenization.

    Features:
    - Technical term preservation (handles RISC-V, ARM Cortex-M, etc.)
    - Configurable BM25 parameters (k1, b)
    - Normalized scoring for fusion compatibility
    - Efficient preprocessing and indexing
    - Performance monitoring

    Example:
        config = {
            "k1": 1.2,
            "b": 0.75,
            "lowercase": True,
            "preserve_technical_terms": True
        }
        retriever = BM25Retriever(config)
        retriever.index_documents(documents)
        results = retriever.search("RISC-V processor", k=5)
    """

    def __init__(self, config: Dict[str, Any]):
        """
        Initialize BM25 sparse retriever.

        Args:
            config: Configuration dictionary with:
                - k1: Term frequency saturation parameter (default: 1.2)
                - b: Document length normalization factor (default: 0.75)
                - lowercase: Whether to lowercase text (default: True)
                - preserve_technical_terms: Whether to preserve technical terms (default: True)
                - filter_stop_words: Whether to filter common stop words (default: True)
                - stop_word_sets: List of predefined stopword sets to use (default: ["english_common"])
                - custom_stop_words: Additional stop words to filter (default: empty list)
                - min_word_length: Minimum word length to preserve (default: 2)
                - debug_stop_words: Enable debug logging for stopword filtering (default: False)
                - min_score: Minimum normalized score threshold for results (default: 0.0)
        """
        self.config = config
        self.k1 = config.get("k1", 1.2)
        self.b = config.get("b", 0.75)
        self.lowercase = config.get("lowercase", True)
        self.preserve_technical_terms = config.get("preserve_technical_terms", True)
        self.filter_stop_words = config.get("filter_stop_words", True)
        self.stop_word_sets = config.get("stop_word_sets", ["english_common"])
        self.custom_stop_words = set(config.get("custom_stop_words", []))
        self.min_word_length = config.get("min_word_length", 2)
        self.debug_stop_words = config.get("debug_stop_words", False)
        self.min_score = config.get("min_score", 0.0)

        # Initialize stopword sets
        self._initialize_stopword_sets()

        # Combine all stopword sets
        self.stop_words = set()
        if self.filter_stop_words:
            for set_name in self.stop_word_sets:
                if set_name in self.available_stop_word_sets:
                    self.stop_words.update(self.available_stop_word_sets[set_name])
                else:
                    logger.warning(f"Unknown stopword set: {set_name}")
            # Add custom stop words
            self.stop_words.update(self.custom_stop_words)
        else:
            # Only use custom stop words if filtering is disabled
            self.stop_words = self.custom_stop_words.copy()

        # Validation
        if self.k1 <= 0:
            raise ValueError("k1 must be positive")
        if not 0 <= self.b <= 1:
            raise ValueError("b must be between 0 and 1")

        # BM25 components
        self.bm25: Optional[BM25Okapi] = None
        self.documents: List[Document] = []
        self.tokenized_corpus: List[List[str]] = []
        self.chunk_mapping: List[int] = []

        # Deferred indexing control
        self._index_dirty = False    # Track if index needs rebuilding
        self._deferred_mode = False  # Enable deferred indexing mode

        # Compile regex patterns for technical term preservation
        if self.preserve_technical_terms:
            self._tech_pattern = re.compile(r'[a-zA-Z0-9][\w\-_.]*[a-zA-Z0-9]|[a-zA-Z0-9]')
            self._punctuation_pattern = re.compile(r'[^\w\s\-_.]')
        else:
            self._tech_pattern = re.compile(r'\b\w+\b')
            self._punctuation_pattern = re.compile(r'[^\w\s]')

        logger.info(f"BM25Retriever initialized with k1={self.k1}, b={self.b}, stop_word_sets={self.stop_word_sets}, stop_words={len(self.stop_words)}")

    def _initialize_stopword_sets(self) -> None:
        """
        Initialize predefined stopword sets for different filtering strategies.
        """
        # Standard English stop words (articles, prepositions, common verbs)
        english_common = {
            'a', 'an', 'and', 'are', 'as', 'at', 'be', 'by', 'for', 'from', 'has', 'he', 'in', 'is', 'it',
            'its', 'of', 'on', 'that', 'the', 'to', 'was', 'were', 'will', 'with', 'this', 'but',
            'they', 'have', 'had', 'what', 'said', 'each', 'which', 'she', 'do', 'how', 'their', 'if',
            'up', 'out', 'many', 'then', 'them', 'these', 'so', 'some', 'her', 'would', 'make', 'like',
            'into', 'him', 'time', 'two', 'more', 'go', 'no', 'way', 'could', 'my', 'than', 'first',
            'been', 'call', 'who', 'sit', 'now', 'find', 'down', 'day', 'did', 'get', 'come',
            'made', 'may', 'part', 'much', 'too', 'any', 'after', 'back', 'other', 'see',
            'want', 'just', 'also', 'when', 'here', 'all', 'well', 'can', 'should', 'must', 'might',
            'shall', 'about', 'before', 'through', 'over', 'under', 'above', 'below', 'between', 'among'
        }

        # NOTE: Removed interrogative_words and irrelevant_entities sets.
        # These contained discriminative terms that should be preserved for
        # proper BM25 behavior; BM25 is designed for lexical matching only,
        # not semantic analysis.

        # Extended set for comprehensive filtering
        english_extended = english_common | {
            'very', 'quite', 'really', 'actually', 'basically', 'essentially', 'generally',
            'specifically', 'particularly', 'especially', 'exactly', 'precisely', 'approximately',
            'roughly', 'mostly', 'mainly', 'primarily', 'largely', 'completely', 'totally',
            'absolutely', 'definitely', 'certainly', 'probably', 'possibly', 'perhaps',
            'maybe', 'sometimes', 'often', 'usually', 'always', 'never', 'rarely', 'seldom',
            'frequently', 'occasionally', 'constantly', 'continuously', 'immediately', 'suddenly',
            'quickly', 'slowly', 'carefully', 'easily', 'simply', 'clearly', 'obviously'
        }

        # Minimal set for technical domains (preserves more terms)
        technical_minimal = {
            'a', 'an', 'the', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', 'of', 'with',
            'by', 'from', 'up', 'out', 'down', 'off', 'over', 'under', 'again', 'further',
            'then', 'once', 'here', 'there', 'when', 'where', 'why', 'how', 'all', 'any',
            'both', 'each', 'few', 'more', 'most', 'other', 'some', 'such', 'no', 'nor',
            'not', 'only', 'own', 'same', 'so', 'than', 'too', 'very'
        }

        self.available_stop_word_sets = {
            "english_common": english_common,
            "english_extended": english_extended,
            "technical_minimal": technical_minimal
        }
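
    # Illustrative example of how these sets combine (values are assumptions,
    # not shipped defaults): with config {"stop_word_sets": ["technical_minimal"],
    # "custom_stop_words": ["chapter"]}, __init__ builds self.stop_words as the
    # union of the technical_minimal set and {"chapter"}.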

    def enable_deferred_indexing(self) -> None:
        """Enable deferred indexing mode to avoid rebuilding the index on every batch."""
        self._deferred_mode = True
        logger.debug("Deferred indexing mode enabled")

    def disable_deferred_indexing(self) -> None:
        """Disable deferred indexing mode and rebuild the index if needed."""
        self._deferred_mode = False
        if self._index_dirty:
            self._rebuild_index()
        logger.debug("Deferred indexing mode disabled")

    def force_rebuild_index(self) -> None:
        """Force a rebuild of the BM25 index with all accumulated documents."""
        if self.tokenized_corpus:
            self._rebuild_index()
        else:
            logger.warning("No documents available to rebuild the index")

    def _rebuild_index(self) -> None:
        """Internal method to rebuild the BM25 index."""
        if not self.tokenized_corpus:
            logger.warning("No tokenized corpus available for index rebuild")
            return
        start_time = time.time()
        self.bm25 = BM25Okapi(self.tokenized_corpus, k1=self.k1, b=self.b)
        self._index_dirty = False
        elapsed = time.time() - start_time
        valid_doc_count = len([tokens for tokens in self.tokenized_corpus if tokens])
        logger.info(f"Rebuilt BM25 index with {valid_doc_count} documents in {elapsed:.3f}s")

    def index_documents(self, documents: List[Document]) -> None:
        """
        Index documents for BM25 sparse retrieval.

        Args:
            documents: List of documents to index
        """
        if not documents:
            raise ValueError("Cannot index empty document list")

        start_time = time.time()

        # Store documents (extend existing instead of replacing)
        if not hasattr(self, 'documents') or self.documents is None:
            self.documents = []
        if not hasattr(self, 'tokenized_corpus') or self.tokenized_corpus is None:
            self.tokenized_corpus = []
        if not hasattr(self, 'chunk_mapping') or self.chunk_mapping is None:
            self.chunk_mapping = []

        # Keep track of the starting index for new documents
        start_idx = len(self.documents)

        # Add new documents
        self.documents.extend(documents)

        # Extract and preprocess texts for new documents only
        texts = [doc.content for doc in documents]
        new_tokenized = [self._preprocess_text(text) for text in texts]

        # Filter out empty tokenized texts and track mapping for new documents
        for i, tokens in enumerate(new_tokenized):
            if tokens:  # Only include non-empty tokenized texts
                self.tokenized_corpus.append(tokens)
                self.chunk_mapping.append(start_idx + i)

        if not self.tokenized_corpus:
            raise ValueError("No valid text content found in documents")

        # Rebuild the BM25 index unless in deferred mode
        if self._deferred_mode:
            # Mark the index as dirty but don't rebuild yet
            self._index_dirty = True
            logger.debug(f"Added {len(documents)} documents to corpus (deferred mode - index not rebuilt)")
        else:
            # Rebuild the index immediately (original behavior)
            self._rebuild_index()

        elapsed = time.time() - start_time
        total_tokens = sum(len(tokens) for tokens in self.tokenized_corpus)
        tokens_per_sec = total_tokens / elapsed if elapsed > 0 else 0
        valid_doc_count = len([tokens for tokens in self.tokenized_corpus if tokens])
        logger.info(f"Indexed {len(documents)} new documents ({valid_doc_count} total valid) in {elapsed:.3f}s")
        logger.debug(f"Processing rate: {tokens_per_sec:.1f} tokens/second")

    def search(self, query: str, k: int = 5) -> List[Tuple[int, float]]:
        """
        Search for documents using BM25 sparse retrieval.

        Args:
            query: Search query string
            k: Number of results to return

        Returns:
            List of (document_index, score) tuples sorted by relevance
        """
        # Ensure the index is built before searching
        if self.bm25 is None or self._index_dirty:
            if self._index_dirty:
                logger.debug("Rebuilding BM25 index before search (was dirty)")
                self._rebuild_index()
            else:
                raise ValueError("Must call index_documents() before searching")

        if not query or not query.strip():
            return []
        if k <= 0:
            raise ValueError("k must be positive")

        # Preprocess the query using the same method as documents
        query_tokens = self._preprocess_text(query)
        if not query_tokens:
            return []

        # Get BM25 scores for all documents
        scores = self.bm25.get_scores(query_tokens)
        if len(scores) == 0:
            return []
        # Guard against negative scores from rank_bm25: the Okapi IDF can go
        # negative for terms that appear in most documents. Shift all scores
        # so they are non-negative before normalizing.
        min_raw_score = np.min(scores)
        if min_raw_score < 0:
            scores = scores - min_raw_score  # Shift all scores to be non-negative
            logger.debug(f"Shifted negative BM25 scores by {-min_raw_score:.6f}")

        # Normalize scores to the [0, 1] range for fusion compatibility
        max_score = np.max(scores)
        min_score = np.min(scores)
        if max_score > min_score:
            # Standard min-max normalization to [0, 1]
            normalized_scores = (scores - min_score) / (max_score - min_score)
        else:
            # All scores are the same - check whether any actual matches exist
            if np.any(scores != 0):
                # Scores are equal and non-zero (all docs equally relevant)
                normalized_scores = np.ones_like(scores)
            else:
                # All scores are exactly zero (no matches)
                normalized_scores = np.zeros_like(scores)
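
        # Worked example (illustrative): raw scores [4.0, 1.0, 0.0] have
        # min=0.0 and max=4.0, so they normalize to [1.0, 0.25, 0.0]; the
        # lowest-scoring document always maps to 0.0 and is removed by the
        # threshold filter below.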
        # Create results with original document indices
        results = [
            (self.chunk_mapping[i], float(normalized_scores[i]))
            for i in range(len(scores))
        ]

        # Filter out zero scores (no matches) and apply the minimum score threshold
        threshold = max(self.min_score, 0.001)  # Always filter out (near-)zero scores
        filtered_results = [(doc_idx, score) for doc_idx, score in results if score >= threshold]
        if not filtered_results:
            logger.debug(f"No BM25 results above score threshold {threshold}")
            return []
        results = filtered_results

        # Sort by score (descending) and return the top k
        results.sort(key=lambda x: x[1], reverse=True)
        return results[:k]

    def get_document_count(self) -> int:
        """Get the number of indexed documents."""
        return len(self.documents)

    def clear(self) -> None:
        """Clear all indexed documents."""
        self.documents.clear()
        self.tokenized_corpus.clear()
        self.chunk_mapping.clear()
        self.bm25 = None
        self._index_dirty = False  # Reset so a later search fails cleanly instead of rebuilding an empty index
        logger.info("BM25 index cleared")

    def get_stats(self) -> Dict[str, Any]:
        """
        Get statistics about the BM25 retriever.

        Returns:
            Dictionary with retriever statistics
        """
        stats = {
            "k1": self.k1,
            "b": self.b,
            "lowercase": self.lowercase,
            "preserve_technical_terms": self.preserve_technical_terms,
            "filter_stop_words": self.filter_stop_words,
            "stop_word_sets": self.stop_word_sets,
            "stop_words_count": len(self.stop_words) if self.stop_words else 0,
            "min_word_length": self.min_word_length,
            "debug_stop_words": self.debug_stop_words,
            "min_score": self.min_score,
            "total_documents": len(self.documents),
            "valid_documents": len(self.chunk_mapping),
            "is_indexed": self.bm25 is not None
        }
        if self.tokenized_corpus:
            total_tokens = sum(len(tokens) for tokens in self.tokenized_corpus)
            stats.update({
                "total_tokens": total_tokens,
                "avg_tokens_per_doc": total_tokens / len(self.tokenized_corpus)
            })
        return stats

    def _preprocess_text(self, text: str) -> List[str]:
        """
        Preprocess text with standard BM25 stopword filtering.

        Args:
            text: Raw text to tokenize

        Returns:
            List of preprocessed tokens
        """
        if not text or not text.strip():
            return []

        original_text = text

        # Convert to lowercase while preserving structure
        if self.lowercase:
            text = text.lower()

        # Remove punctuation except hyphens, underscores, and periods in technical terms
        text = self._punctuation_pattern.sub(' ', text)

        # Extract tokens using the appropriate pattern
        tokens = self._tech_pattern.findall(text)

        # Filter out tokens shorter than the minimum length
        if self.min_word_length > 1:
            length_filtered = [token for token in tokens if len(token) >= self.min_word_length]
        else:
            length_filtered = [token for token in tokens if len(token) > 0]

        # Apply standard stopword filtering (linguistic noise words only)
        if self.stop_words:
            filtered_tokens = []
            stop_words_removed = []
            for token in length_filtered:
                token_lower = token.lower()
                # Simple standard stopword filtering - no semantic analysis
                if token_lower in self.stop_words:
                    stop_words_removed.append(token)
                else:
                    filtered_tokens.append(token)

            # Debug logging if enabled
            if self.debug_stop_words and stop_words_removed:
                logger.info(f"[BM25_DEBUG] Text: \"{original_text[:50]}{'...' if len(original_text) > 50 else ''}\"")
                logger.info(f"[BM25_DEBUG] Tokens before filtering: {length_filtered}")
                logger.info(f"[BM25_DEBUG] Stop words removed: {stop_words_removed}")
                logger.info(f"[BM25_DEBUG] Tokens after filtering: {filtered_tokens}")
                if length_filtered:
                    filter_rate = len(stop_words_removed) / len(length_filtered) * 100
                    logger.info(f"[BM25_DEBUG] Filtering impact: {filter_rate:.1f}% tokens removed")
                logger.info("[BM25_DEBUG] ---")
            return filtered_tokens
        else:
            return length_filtered
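
    # Illustrative tokenization with the defaults (lowercase=True,
    # preserve_technical_terms=True, english_common stop words):
    #     _preprocess_text("The RISC-V processor") -> ["risc-v", "processor"]
    # "The" is dropped as a stop word; the hyphenated technical term survives intact.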

    def get_query_tokens(self, query: str) -> List[str]:
        """
        Get preprocessed tokens for a query (useful for debugging).

        Args:
            query: Query string

        Returns:
            List of preprocessed tokens
        """
        return self._preprocess_text(query)

    def get_document_tokens(self, doc_index: int) -> List[str]:
        """
        Get preprocessed tokens for a document (useful for debugging).

        Args:
            doc_index: Document index

        Returns:
            List of preprocessed tokens
        """
        if 0 <= doc_index < len(self.tokenized_corpus):
            return self.tokenized_corpus[doc_index]
        else:
            raise IndexError(f"Document index {doc_index} out of range")

    def get_bm25_scores(self, query: str) -> List[float]:
        """
        Get raw BM25 scores for all documents (useful for debugging).

        Args:
            query: Query string

        Returns:
            List of BM25 scores (not normalized)
        """
        if self.bm25 is None:
            raise ValueError("Must call index_documents() before getting scores")
        query_tokens = self._preprocess_text(query)
        if not query_tokens:
            return []
        scores = self.bm25.get_scores(query_tokens)
        return scores.tolist()

    def get_term_frequencies(self, query: str) -> Dict[str, int]:
        """
        Get term frequencies for a query (useful for analysis).

        Args:
            query: Query string

        Returns:
            Dictionary mapping terms to frequencies
        """
        query_tokens = self._preprocess_text(query)
        term_freqs: Dict[str, int] = {}
        for token in query_tokens:
            term_freqs[token] = term_freqs.get(token, 0) + 1
        return term_freqs
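

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the module's public API. It
    # assumes Document can be constructed with a `content` keyword argument;
    # adjust to the actual src.core.interfaces.Document signature if it differs.
    logging.basicConfig(level=logging.INFO)

    sample_docs = [
        Document(content="RISC-V is an open instruction set architecture."),
        Document(content="The ARM Cortex-M series targets embedded systems."),
        Document(content="BM25 ranks documents by term frequency and document length."),
    ]

    retriever = BM25Retriever({"k1": 1.2, "b": 0.75})
    retriever.index_documents(sample_docs)

    for doc_idx, score in retriever.search("RISC-V processor", k=3):
        print(f"doc={doc_idx} score={score:.3f} text={sample_docs[doc_idx].content!r}")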