""" | |
BM25 keyword search with cross-encoder re-ranking and hybrid search support. | |
""" | |
import pickle

import numpy as np
import faiss
from typing import Tuple, List, Optional
from openai import OpenAI
from sentence_transformers import CrossEncoder

import config
import utils
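# The config module is expected to define the constants referenced in this
# file (names taken from their uses below; the example values are
# illustrative assumptions, not the project's actual settings):
#   OPENAI_API_KEY        - OpenAI API key, e.g. read from the environment
#   CROSS_ENCODER_MODEL   - e.g. a model such as "cross-encoder/ms-marco-MiniLM-L-6-v2"
#   BM25_INDEX            - pathlib.Path to the pickled index written by preprocess.py
#   OPENAI_CHAT_MODEL     - chat model name used for answer generation
#   DEFAULT_TOP_K         - number of chunks kept after re-ranking
#   RERANK_MULTIPLIER     - over-retrieval factor before re-ranking
#   DEFAULT_HYBRID_ALPHA  - BM25 weight in hybrid scoring
#   DEFAULT_MAX_TOKENS    - completion token budget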
# Initialize models
client = OpenAI(api_key=config.OPENAI_API_KEY)
cross_encoder = CrossEncoder(config.CROSS_ENCODER_MODEL)

# Module-level caches for lazy loading
_bm25_index = None
_texts = None
_metadata = None
_semantic_index = None
def _load_bm25_index():
    """Lazily load the BM25 index, chunk texts/metadata, and (if stored)
    the semantic FAISS index used for hybrid search."""
    global _bm25_index, _texts, _metadata, _semantic_index
    if _bm25_index is not None:
        return
    # Initialize defaults
    _texts = []
    _metadata = []
    _semantic_index = None
    try:
        if config.BM25_INDEX.exists():
            with open(config.BM25_INDEX, 'rb') as f:
                bm25_data = pickle.load(f)
            if isinstance(bm25_data, dict):
                _bm25_index = bm25_data.get('index') or bm25_data.get('bm25')
                chunks = bm25_data.get('texts', [])
                _texts = [chunk.text for chunk in chunks if hasattr(chunk, 'text')]
                _metadata = [chunk.metadata for chunk in chunks if hasattr(chunk, 'metadata')]
                # Load semantic embeddings if available for hybrid search
                if 'embeddings' in bm25_data:
                    # FAISS requires float32 input
                    semantic_embeddings = bm25_data['embeddings'].astype(np.float32)
                    # Inner product over L2-normalized vectors equals cosine similarity
                    dimension = semantic_embeddings.shape[1]
                    _semantic_index = faiss.IndexFlatIP(dimension)
                    faiss.normalize_L2(semantic_embeddings)
                    _semantic_index.add(semantic_embeddings)
            else:
                # Legacy format: the pickle is the bare BM25 object itself
                _bm25_index = bm25_data
            print(f"Loaded BM25 index with {len(_texts)} documents")
        else:
            print("BM25 index not found. Run preprocess.py first.")
    except Exception as e:
        print(f"Error loading BM25 index: {e}")
        _bm25_index = None
        _texts = []
        _metadata = []
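# Expected on-disk layout for config.BM25_INDEX, inferred from the loader
# above (the authoritative schema lives in preprocess.py; 'Chunk' is a
# stand-in name for whatever object preprocess.py actually pickles):
#
#   {
#       'index': <rank_bm25-style object exposing get_scores(tokens)>,
#       'texts': [Chunk(text=str, metadata=dict), ...],
#       'embeddings': np.ndarray of shape (n_docs, dim),  # optional
#   }
#
# or, in the legacy case, the bare BM25 object itself.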
def query(question: str, image_path: Optional[str] = None, top_k: Optional[int] = None) -> Tuple[str, List[dict]]:
    """
    Query using BM25 keyword search with cross-encoder re-ranking.

    Args:
        question: User's question
        image_path: Optional path to an image
        top_k: Number of relevant chunks to retrieve (defaults to config.DEFAULT_TOP_K)

    Returns:
        Tuple of (answer, citations)
    """
    if top_k is None:
        top_k = config.DEFAULT_TOP_K

    # Load index if not already loaded
    _load_bm25_index()
    if _bm25_index is None or len(_texts) == 0:
        return "BM25 index not loaded. Please run preprocess.py first.", []
    # Tokenize the query the same way the corpus was tokenized in
    # preprocess.py (assumed here to be lowercase whitespace splitting)
    tokenized_query = question.lower().split()

    # Get BM25 scores for every document
    bm25_scores = _bm25_index.get_scores(tokenized_query)

    # Get top candidates (over-retrieve so the re-ranker has room to work)
    top_indices = np.argsort(bm25_scores)[::-1][:top_k * config.RERANK_MULTIPLIER]

    # Prepare candidates for re-ranking, skipping zero-score documents
    candidates = []
    for idx in top_indices:
        if idx < len(_texts) and bm25_scores[idx] > 0:
            candidates.append({
                'text': _texts[idx],
                'bm25_score': bm25_scores[idx],
                'metadata': _metadata[idx],
                'idx': idx
            })
    # Re-rank with cross-encoder
    if candidates:
        pairs = [[question, cand['text']] for cand in candidates]
        cross_scores = cross_encoder.predict(pairs)
        # Add cross-encoder scores and sort
        for i, score in enumerate(cross_scores):
            candidates[i]['cross_score'] = score
        candidates = sorted(candidates, key=lambda x: x['cross_score'], reverse=True)[:top_k]
    # Collect citations (one per unique source)
    citations = []
    sources_seen = set()
    for chunk in candidates:
        chunk_meta = chunk['metadata']
        if chunk_meta['source'] not in sources_seen:
            citation = {
                'source': chunk_meta['source'],
                'type': chunk_meta['type'],
                # Cast numpy floats to Python floats for clean serialization
                'bm25_score': round(float(chunk['bm25_score']), 3),
                'rerank_score': round(float(chunk['cross_score']), 3)
            }
            if chunk_meta['type'] == 'pdf':
                citation['path'] = chunk_meta['path']
            else:
                citation['url'] = chunk_meta.get('url', '')
            citations.append(citation)
            sources_seen.add(chunk_meta['source'])
    # Handle image if provided
    image_context = ""
    if image_path:
        try:
            classification = utils.classify_image(image_path)
            # classification is a string, not a dict
            image_context = f"\n\n[Image Analysis: The image appears to show a {classification}.]"
        except Exception as e:
            print(f"Error processing image: {e}")

    # Build context from retrieved chunks
    context = "\n\n---\n\n".join([chunk['text'] for chunk in candidates])
    if not context:
        return "No relevant documents found for your query.", []
    # Generate answer
    prompt = f"""Answer the following question using the retrieved documents:
Retrieved Documents:
{context}{image_context}
Question: {question}
Instructions:
1. Provide a comprehensive answer based on the retrieved documents
2. Mention specific details from the sources
3. If the documents don't fully answer the question, indicate what information is missing"""

    # For GPT-5, temperature must be default (1.0)
    response = client.chat.completions.create(
        model=config.OPENAI_CHAT_MODEL,
        messages=[
            {"role": "system", "content": "You are a technical expert on manufacturing safety and regulations. Provide accurate, detailed answers based on the retrieved documents."},
            {"role": "user", "content": prompt}
        ],
        max_completion_tokens=config.DEFAULT_MAX_TOKENS
    )
    answer = response.choices[0].message.content
    return answer, citations
def query_hybrid(question: str, top_k: Optional[int] = None, alpha: Optional[float] = None) -> Tuple[str, List[dict]]:
    """
    Hybrid search combining BM25 and semantic search.

    Args:
        question: User's question
        top_k: Number of relevant chunks to retrieve (defaults to config.DEFAULT_TOP_K)
        alpha: Weight for BM25 scores; semantic scores get weight (1 - alpha)

    Returns:
        Tuple of (answer, citations)
    """
    if top_k is None:
        top_k = config.DEFAULT_TOP_K
    if alpha is None:
        alpha = config.DEFAULT_HYBRID_ALPHA

    # Load indices if not already loaded
    _load_bm25_index()
    if _bm25_index is None or _semantic_index is None:
        return "Hybrid search requires both BM25 and semantic indices. Please run preprocess.py with semantic embeddings.", []
    # Get BM25 scores
    tokenized_query = question.lower().split()
    bm25_scores = _bm25_index.get_scores(tokenized_query)

    # Normalize BM25 scores to [0, 1] so they are comparable to cosine similarities
    if bm25_scores.max() > 0:
        bm25_scores = bm25_scores / bm25_scores.max()

    # Embed the query and L2-normalize it to match the index
    embedding_generator = utils.EmbeddingGenerator()
    query_embedding = embedding_generator.embed_text_openai([question]).astype(np.float32)
    faiss.normalize_L2(query_embedding)

    # Search the semantic index; with IndexFlatIP over normalized vectors,
    # the returned "distances" are cosine similarities
    k_search = min(len(_texts), top_k * config.RERANK_MULTIPLIER)
    distances, indices = _semantic_index.search(query_embedding.reshape(1, -1), k_search)

    # Scatter similarities into a dense per-document score array
    # (FAISS pads missing results with index -1, so guard against it)
    semantic_scores = np.zeros(len(_texts))
    for idx, sim in zip(indices[0], distances[0]):
        if 0 <= idx < len(_texts):
            semantic_scores[idx] = sim
    # Combine scores: alpha weights BM25, (1 - alpha) weights semantic
    hybrid_scores = alpha * bm25_scores + (1 - alpha) * semantic_scores
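    # Worked example (illustrative numbers only): with alpha = 0.5, a
    # document with normalized BM25 score 0.8 and cosine similarity 0.4
    # receives a hybrid score of 0.5 * 0.8 + 0.5 * 0.4 = 0.6.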
    # Get top candidates
    top_indices = np.argsort(hybrid_scores)[::-1][:top_k * config.RERANK_MULTIPLIER]

    # Prepare candidates
    candidates = []
    for idx in top_indices:
        if idx < len(_texts) and hybrid_scores[idx] > 0:
            candidates.append({
                'text': _texts[idx],
                'hybrid_score': hybrid_scores[idx],
                'bm25_score': bm25_scores[idx],
                'semantic_score': semantic_scores[idx],
                'metadata': _metadata[idx],
                'idx': idx
            })
    # Re-rank with cross-encoder
    if candidates:
        pairs = [[question, cand['text']] for cand in candidates]
        cross_scores = cross_encoder.predict(pairs)
        for i, score in enumerate(cross_scores):
            candidates[i]['cross_score'] = score
        # Final ranking using cross-encoder scores
        candidates = sorted(candidates, key=lambda x: x['cross_score'], reverse=True)[:top_k]
    # Collect citations (one per unique source)
    citations = []
    sources_seen = set()
    for chunk in candidates:
        chunk_meta = chunk['metadata']
        if chunk_meta['source'] not in sources_seen:
            citation = {
                'source': chunk_meta['source'],
                'type': chunk_meta['type'],
                # Cast numpy floats to Python floats for clean serialization
                'hybrid_score': round(float(chunk['hybrid_score']), 3),
                'rerank_score': round(float(chunk.get('cross_score', 0)), 3)
            }
            if chunk_meta['type'] == 'pdf':
                citation['path'] = chunk_meta['path']
            else:
                citation['url'] = chunk_meta.get('url', '')
            citations.append(citation)
            sources_seen.add(chunk_meta['source'])
    # Build context
    context = "\n\n---\n\n".join([chunk['text'] for chunk in candidates])
    if not context:
        return "No relevant documents found for your query.", []

    # Generate answer
    prompt = f"""Using the following retrieved passages, answer the question:
{context}
Question: {question}
Provide a clear, detailed answer based on the information in the passages."""
    # For GPT-5, temperature must be default (1.0)
    response = client.chat.completions.create(
        model=config.OPENAI_CHAT_MODEL,
        messages=[
            {"role": "system", "content": "You are a safety expert. Answer questions accurately using the provided passages."},
            {"role": "user", "content": prompt}
        ],
        max_completion_tokens=config.DEFAULT_MAX_TOKENS
    )
    answer = response.choices[0].message.content
    return answer, citations
if __name__ == "__main__":
    # Test BM25 query
    test_questions = [
        "lockout tagout procedures",
        "machine guard requirements OSHA",
        "robot safety collaborative workspace"
    ]
    for q in test_questions:
        print(f"\nQuestion: {q}")
        answer, citations = query(q)
        print(f"Answer: {answer[:200]}...")
        print(f"Citations: {citations}")
        print("-" * 50)