"""
BM25 keyword search with cross-encoder re-ranking and hybrid search support.
"""
import numpy as np
import faiss
from typing import Tuple, List, Optional
from openai import OpenAI
from sentence_transformers import CrossEncoder
import config
import utils
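# Attributes this module expects config to provide (all referenced below):
# OPENAI_API_KEY, CROSS_ENCODER_MODEL, BM25_INDEX, DEFAULT_TOP_K,
# RERANK_MULTIPLIER, DEFAULT_HYBRID_ALPHA, OPENAI_CHAT_MODEL,
# DEFAULT_MAX_TOKENS. utils is expected to provide classify_image() and
# EmbeddingGenerator.embed_text_openai().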
# Initialize models
client = OpenAI(api_key=config.OPENAI_API_KEY)
cross_encoder = CrossEncoder(config.CROSS_ENCODER_MODEL)
# Global variables for lazy loading
_bm25_index = None
_texts = None
_metadata = None
_semantic_index = None
def _load_bm25_index():
"""Lazy load BM25 index and metadata."""
global _bm25_index, _texts, _metadata, _semantic_index
if _bm25_index is None:
# Initialize defaults
_texts = []
_metadata = []
_semantic_index = None
try:
import pickle
if config.BM25_INDEX.exists():
with open(config.BM25_INDEX, 'rb') as f:
bm25_data = pickle.load(f)
if isinstance(bm25_data, dict):
_bm25_index = bm25_data.get('index') or bm25_data.get('bm25')
chunks = bm25_data.get('texts', [])
                    # Keep _texts and _metadata index-aligned; filtering each
                    # attribute separately could desynchronize the two lists
                    _texts = [getattr(chunk, 'text', '') for chunk in chunks]
                    _metadata = [getattr(chunk, 'metadata', {}) for chunk in chunks]
# Load semantic embeddings if available for hybrid search
if 'embeddings' in bm25_data:
semantic_embeddings = bm25_data['embeddings']
                        # Build an inner-product FAISS index over L2-normalized
                        # vectors, i.e. cosine similarity; FAISS requires
                        # contiguous float32 input
                        semantic_embeddings = np.ascontiguousarray(
                            semantic_embeddings, dtype=np.float32)
                        dimension = semantic_embeddings.shape[1]
                        _semantic_index = faiss.IndexFlatIP(dimension)
                        faiss.normalize_L2(semantic_embeddings)
                        _semantic_index.add(semantic_embeddings)
else:
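                    # Legacy format: the pickle holds the BM25 index object
                    # directly, with no stored texts, metadata, or embeddings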
_bm25_index = bm25_data
_texts = []
_metadata = []
print(f"Loaded BM25 index with {len(_texts)} documents")
else:
print("BM25 index not found. Run preprocess.py first.")
except Exception as e:
print(f"Error loading BM25 index: {e}")
_bm25_index = None
_texts = []
_metadata = []
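# Sketch of the expected on-disk format at config.BM25_INDEX, inferred from
# the loader above. preprocess.py is the source of truth; BM25Okapi from
# rank_bm25 is an assumption based on the get_scores() calls below:
#
#     from rank_bm25 import BM25Okapi
#     data = {
#         'index': BM25Okapi([t.lower().split() for t in texts]),
#         'texts': chunks,        # objects exposing .text and .metadata
#         'embeddings': embs,     # optional (n_docs, dim) float array;
#     }                           # enables query_hybrid()
#     with open(config.BM25_INDEX, 'wb') as f:
#         pickle.dump(data, f)
#
# A bare BM25 index (no dict wrapper) is also accepted as a legacy format.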
def query(question: str, image_path: Optional[str] = None, top_k: Optional[int] = None) -> Tuple[str, List[dict]]:
"""
Query using BM25 keyword search with re-ranking.
Args:
question: User's question
        image_path: Optional path to an image; it is classified and added to the prompt as context
        top_k: Number of relevant chunks to retrieve (defaults to config.DEFAULT_TOP_K)
Returns:
Tuple of (answer, citations)
"""
if top_k is None:
top_k = config.DEFAULT_TOP_K
# Load index if not already loaded
_load_bm25_index()
if _bm25_index is None or len(_texts) == 0:
return "BM25 index not loaded. Please run preprocess.py first.", []
    # Tokenize the query for BM25; this must mirror the tokenization used when
    # the index was built (assumed: lowercased whitespace split)
tokenized_query = question.lower().split()
# Get BM25 scores
bm25_scores = _bm25_index.get_scores(tokenized_query)
# Get top candidates (retrieve more for re-ranking)
top_indices = np.argsort(bm25_scores)[::-1][:top_k * config.RERANK_MULTIPLIER]
# Prepare candidates for re-ranking
candidates = []
for idx in top_indices:
if idx < len(_texts) and bm25_scores[idx] > 0:
candidates.append({
'text': _texts[idx],
'bm25_score': bm25_scores[idx],
'metadata': _metadata[idx],
'idx': idx
})
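    # A BM25 score of 0 means no query term appears in the chunk, so the
    # filter above keeps such chunks away from the re-ranker entirely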
# Re-rank with cross-encoder
if candidates:
pairs = [[question, cand['text']] for cand in candidates]
cross_scores = cross_encoder.predict(pairs)
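        # predict() scores each (question, passage) pair jointly with full
        # cross-attention, which is slower but typically more accurate than
        # the term- or embedding-based scores used for first-stage retrieval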
# Add cross-encoder scores and sort
for i, score in enumerate(cross_scores):
candidates[i]['cross_score'] = score
candidates = sorted(candidates, key=lambda x: x['cross_score'], reverse=True)[:top_k]
# Collect citations
citations = []
sources_seen = set()
for chunk in candidates:
chunk_meta = chunk['metadata']
if chunk_meta['source'] not in sources_seen:
citation = {
'source': chunk_meta['source'],
'type': chunk_meta['type'],
'bm25_score': round(chunk['bm25_score'], 3),
                'rerank_score': round(chunk.get('cross_score', 0), 3)
}
if chunk_meta['type'] == 'pdf':
citation['path'] = chunk_meta['path']
else:
citation['url'] = chunk_meta.get('url', '')
citations.append(citation)
sources_seen.add(chunk_meta['source'])
# Handle image if provided
image_context = ""
if image_path:
try:
classification = utils.classify_image(image_path)
# classification is a string, not a dict
image_context = f"\n\n[Image Analysis: The image appears to show a {classification}.]"
except Exception as e:
print(f"Error processing image: {e}")
# Build context from retrieved chunks
context = "\n\n---\n\n".join([chunk['text'] for chunk in candidates])
if not context:
return "No relevant documents found for your query.", []
# Generate answer
prompt = f"""Answer the following question using the retrieved documents:
Retrieved Documents:
{context}{image_context}
Question: {question}
Instructions:
1. Provide a comprehensive answer based on the retrieved documents
2. Mention specific details from the sources
3. If the documents don't fully answer the question, indicate what information is missing"""
# For GPT-5, temperature must be default (1.0)
response = client.chat.completions.create(
model=config.OPENAI_CHAT_MODEL,
messages=[
{"role": "system", "content": "You are a technical expert on manufacturing safety and regulations. Provide accurate, detailed answers based on the retrieved documents."},
{"role": "user", "content": prompt}
],
max_completion_tokens=config.DEFAULT_MAX_TOKENS
)
answer = response.choices[0].message.content
return answer, citations
def query_hybrid(question: str, top_k: Optional[int] = None, alpha: Optional[float] = None) -> Tuple[str, List[dict]]:
"""
Hybrid search combining BM25 and semantic search.
Args:
question: User's question
        top_k: Number of relevant chunks to retrieve (defaults to config.DEFAULT_TOP_K)
        alpha: Weight for BM25 scores; semantic scores are weighted 1 - alpha (defaults to config.DEFAULT_HYBRID_ALPHA)
Returns:
Tuple of (answer, citations)
"""
if top_k is None:
top_k = config.DEFAULT_TOP_K
if alpha is None:
alpha = config.DEFAULT_HYBRID_ALPHA
# Load index if not already loaded
_load_bm25_index()
if _bm25_index is None or _semantic_index is None:
return "Hybrid search requires both BM25 and semantic indices. Please run preprocess.py with semantic embeddings.", []
# Get BM25 scores
tokenized_query = question.lower().split()
bm25_scores = _bm25_index.get_scores(tokenized_query)
# Normalize BM25 scores
if bm25_scores.max() > 0:
bm25_scores = bm25_scores / bm25_scores.max()
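    # The semantic scores below are cosine similarities in [-1, 1] (inner
    # products of L2-normalized vectors), so max-normalizing BM25 into [0, 1]
    # puts both signals on a comparable scale before mixing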
# Get semantic scores using FAISS
embedding_generator = utils.EmbeddingGenerator()
query_embedding = embedding_generator.embed_text_openai([question]).astype(np.float32)
faiss.normalize_L2(query_embedding)
    # Search the semantic index for the top candidates (k_search nearest
    # neighbors, which may be less than the full corpus)
k_search = min(len(_texts), top_k * config.RERANK_MULTIPLIER)
distances, indices = _semantic_index.search(query_embedding.reshape(1, -1), k_search)
    # Scatter similarities into a dense per-document array; documents outside
    # the k_search nearest neighbors keep a semantic score of 0
semantic_scores = np.zeros(len(_texts))
for idx, dist in zip(indices[0], distances[0]):
if idx < len(_texts):
semantic_scores[idx] = dist
# Combine scores
hybrid_scores = alpha * bm25_scores + (1 - alpha) * semantic_scores
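    # Worked example: with alpha = 0.5, normalized BM25 = 0.8, and cosine
    # similarity = 0.6, hybrid = 0.5 * 0.8 + 0.5 * 0.6 = 0.7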
# Get top candidates
top_indices = np.argsort(hybrid_scores)[::-1][:top_k * config.RERANK_MULTIPLIER]
# Prepare candidates
candidates = []
for idx in top_indices:
if idx < len(_texts) and hybrid_scores[idx] > 0:
candidates.append({
'text': _texts[idx],
'hybrid_score': hybrid_scores[idx],
'bm25_score': bm25_scores[idx],
'semantic_score': semantic_scores[idx],
'metadata': _metadata[idx],
'idx': idx
})
# Re-rank with cross-encoder
if candidates:
pairs = [[question, cand['text']] for cand in candidates]
cross_scores = cross_encoder.predict(pairs)
for i, score in enumerate(cross_scores):
candidates[i]['cross_score'] = score
# Final ranking using cross-encoder scores
candidates = sorted(candidates, key=lambda x: x['cross_score'], reverse=True)[:top_k]
# Collect citations
citations = []
sources_seen = set()
for chunk in candidates:
chunk_meta = chunk['metadata']
if chunk_meta['source'] not in sources_seen:
citation = {
'source': chunk_meta['source'],
'type': chunk_meta['type'],
'hybrid_score': round(chunk['hybrid_score'], 3),
'rerank_score': round(chunk.get('cross_score', 0), 3)
}
if chunk_meta['type'] == 'pdf':
citation['path'] = chunk_meta['path']
else:
citation['url'] = chunk_meta.get('url', '')
citations.append(citation)
sources_seen.add(chunk_meta['source'])
# Build context
context = "\n\n---\n\n".join([chunk['text'] for chunk in candidates])
if not context:
return "No relevant documents found for your query.", []
# Generate answer
prompt = f"""Using the following retrieved passages, answer the question:
{context}
Question: {question}
Provide a clear, detailed answer based on the information in the passages."""
# For GPT-5, temperature must be default (1.0)
response = client.chat.completions.create(
model=config.OPENAI_CHAT_MODEL,
messages=[
{"role": "system", "content": "You are a safety expert. Answer questions accurately using the provided passages."},
{"role": "user", "content": prompt}
],
max_completion_tokens=config.DEFAULT_MAX_TOKENS
)
answer = response.choices[0].message.content
return answer, citations
if __name__ == "__main__":
# Test BM25 query
test_questions = [
"lockout tagout procedures",
"machine guard requirements OSHA",
"robot safety collaborative workspace"
]
for q in test_questions:
print(f"\nQuestion: {q}")
answer, citations = query(q)
print(f"Answer: {answer[:200]}...")
print(f"Citations: {citations}")
print("-" * 50)