Spaces:
Paused
Paused
File size: 11,753 Bytes
ef821d9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 |
"""
BM25 keyword search with cross-encoder re-ranking and hybrid search support.
"""
import numpy as np
import faiss
from typing import Tuple, List, Optional
from openai import OpenAI
from sentence_transformers import CrossEncoder
import config
import utils
# Shared model clients, created once when this module is imported.
client = OpenAI(api_key=config.OPENAI_API_KEY)  # chat completions used to write answers
cross_encoder = CrossEncoder(config.CROSS_ENCODER_MODEL)  # scores (question, passage) pairs for re-ranking

# Retrieval state, populated lazily by _load_bm25_index() on first use:
#   _bm25_index     - BM25 scorer, queried through get_scores()
#   _texts          - chunk texts, indexed by BM25 document position
#   _metadata       - per-chunk metadata dicts, parallel to _texts
#   _semantic_index - optional FAISS index, required only by query_hybrid
_bm25_index = _texts = _metadata = _semantic_index = None
def _load_bm25_index():
    """Lazily load the BM25 index, chunk texts/metadata and optional embeddings.

    Populates the module-level globals ``_bm25_index``, ``_texts``,
    ``_metadata`` and ``_semantic_index`` from the pickle at
    ``config.BM25_INDEX``. Does nothing once ``_bm25_index`` is set. If the
    file is missing or loading fails, the globals are left empty (index
    ``None``), so the next call retries.

    Two pickle layouts are accepted:

    * a dict with the index under ``'index'`` (or ``'bm25'``), a list of chunk
      objects under ``'texts'`` (read via their ``.text`` / ``.metadata``
      attributes), and optionally a 2-D ``'embeddings'`` matrix used to build
      a FAISS inner-product index for hybrid search;
    * the bare index object itself (no texts/metadata are available then).
    """
    global _bm25_index, _texts, _metadata, _semantic_index
    if _bm25_index is not None:
        return

    # Start from a clean, empty state so a partial failure leaves nothing stale.
    _texts = []
    _metadata = []
    _semantic_index = None
    try:
        import pickle

        if not config.BM25_INDEX.exists():
            print("BM25 index not found. Run preprocess.py first.")
            return

        # NOTE(security): pickle.load can execute arbitrary code. This file
        # must only ever come from the local preprocess.py run, never from users.
        with open(config.BM25_INDEX, 'rb') as f:
            bm25_data = pickle.load(f)

        if isinstance(bm25_data, dict):
            _bm25_index = bm25_data.get('index') or bm25_data.get('bm25')
            chunks = bm25_data.get('texts') or []
            # Build both lists with exactly one entry per chunk. Position i must
            # refer to the same chunk as BM25 row i (and embedding row i);
            # filtering the two lists independently could shift them apart.
            _texts = [getattr(chunk, 'text', '') for chunk in chunks]
            _metadata = [getattr(chunk, 'metadata', {}) for chunk in chunks]

            embeddings = bm25_data.get('embeddings')
            if embeddings is not None:
                # FAISS requires a C-contiguous float32 (n, d) matrix. After L2
                # normalisation, inner product equals cosine similarity.
                embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)
                faiss.normalize_L2(embeddings)
                semantic_index = faiss.IndexFlatIP(embeddings.shape[1])
                semantic_index.add(embeddings)
                _semantic_index = semantic_index
        else:
            # Legacy layout: the pickle is the index itself, without chunk data.
            _bm25_index = bm25_data

        print(f"Loaded BM25 index with {len(_texts)} documents")
    except Exception as e:
        # Best effort: report and fall back to the "not loaded" state.
        print(f"Error loading BM25 index: {e}")
        _bm25_index = None
        _texts = []
        _metadata = []
        _semantic_index = None
def query(question: str, image_path: Optional[str] = None, top_k: Optional[int] = None) -> Tuple[str, List[dict]]:
    """
    Query using BM25 keyword search with cross-encoder re-ranking.

    ``top_k * config.RERANK_MULTIPLIER`` BM25 candidates with a positive score
    are re-scored by the cross-encoder. The best ``top_k`` become the LLM
    context.

    Args:
        question: User's question.
        image_path: Optional path to an image. Its label from
            ``utils.classify_image`` is appended to the prompt context.
        top_k: Number of relevant chunks to retrieve (defaults to
            ``config.DEFAULT_TOP_K``).

    Returns:
        Tuple of (answer, citations). Citations hold one entry per distinct
        source among the selected chunks, in re-ranked order.
    """
    if top_k is None:
        top_k = config.DEFAULT_TOP_K

    # Load index if not already loaded
    _load_bm25_index()
    if _bm25_index is None or not _texts:
        return "BM25 index not loaded. Please run preprocess.py first.", []

    # Lower-cased whitespace tokenisation. Presumably this must match how the
    # corpus was tokenised in preprocess.py — TODO confirm.
    tokenized_query = question.lower().split()
    bm25_scores = _bm25_index.get_scores(tokenized_query)

    # Over-fetch so the cross-encoder has more candidates to choose from.
    top_indices = np.argsort(bm25_scores)[::-1][:top_k * config.RERANK_MULTIPLIER]

    # Only positions that have both a text and a metadata entry are usable.
    n_docs = min(len(_texts), len(_metadata))
    candidates = [
        {
            'text': _texts[idx],
            'bm25_score': bm25_scores[idx],
            'metadata': _metadata[idx],
            'idx': idx,
        }
        for idx in top_indices
        # Skip out-of-range rows and documents sharing no term with the query.
        if idx < n_docs and bm25_scores[idx] > 0
    ]

    # Re-rank with the cross-encoder and keep the best top_k.
    if candidates:
        pairs = [[question, cand['text']] for cand in candidates]
        cross_scores = cross_encoder.predict(pairs)
        for cand, score in zip(candidates, cross_scores):
            # Plain float: numpy float32 is not JSON-serialisable in citations.
            cand['cross_score'] = float(score)
        candidates = sorted(candidates, key=lambda c: c['cross_score'], reverse=True)[:top_k]

    # Build context from retrieved chunks. Return early (before any image
    # analysis or LLM call) when nothing was retrieved.
    context = "\n\n---\n\n".join([chunk['text'] for chunk in candidates])
    if not context:
        return "No relevant documents found for your query.", []

    # Collect one citation per distinct source, in ranked order.
    citations = []
    sources_seen = set()
    for chunk in candidates:
        chunk_meta = chunk['metadata']
        source = chunk_meta.get('source')
        if source in sources_seen:
            continue
        sources_seen.add(source)
        citation = {
            'source': source,
            'type': chunk_meta.get('type'),
            'bm25_score': round(float(chunk['bm25_score']), 3),
            'rerank_score': round(chunk['cross_score'], 3),
        }
        if citation['type'] == 'pdf':
            citation['path'] = chunk_meta.get('path', '')
        else:
            citation['url'] = chunk_meta.get('url', '')
        citations.append(citation)

    # Optional image context. classify_image returns a plain string label.
    image_context = ""
    if image_path:
        try:
            classification = utils.classify_image(image_path)
            image_context = f"\n\n[Image Analysis: The image appears to show a {classification}.]"
        except Exception as e:
            # Best effort: answer from the documents alone if the image fails.
            print(f"Error processing image: {e}")

    # Generate answer
    prompt = f"""Answer the following question using the retrieved documents:
Retrieved Documents:
{context}{image_context}
Question: {question}
Instructions:
1. Provide a comprehensive answer based on the retrieved documents
2. Mention specific details from the sources
3. If the documents don't fully answer the question, indicate what information is missing"""
    # For GPT-5, temperature must be default (1.0)
    response = client.chat.completions.create(
        model=config.OPENAI_CHAT_MODEL,
        messages=[
            {"role": "system", "content": "You are a technical expert on manufacturing safety and regulations. Provide accurate, detailed answers based on the retrieved documents."},
            {"role": "user", "content": prompt},
        ],
        max_completion_tokens=config.DEFAULT_MAX_TOKENS,
    )
    # The SDK types message.content as Optional[str]; always return a str.
    answer = response.choices[0].message.content or ""
    return answer, citations
def query_hybrid(question: str, top_k: Optional[int] = None, alpha: Optional[float] = None) -> Tuple[str, List[dict]]:
    """
    Hybrid search combining BM25 and semantic search, with cross-encoder re-ranking.

    Every document is scored ``alpha * bm25 + (1 - alpha) * semantic``:

    * ``bm25`` is the BM25 score divided by its maximum, when that maximum is
      positive;
    * ``semantic`` is the FAISS inner product between the L2-normalised
      question embedding and each indexed document embedding.

    The best ``top_k * config.RERANK_MULTIPLIER`` documents with a positive
    hybrid score are re-ranked by the cross-encoder. The best ``top_k`` are
    given to the LLM as context.

    Args:
        question: User's question.
        top_k: Number of relevant chunks to retrieve (defaults to
            ``config.DEFAULT_TOP_K``).
        alpha: Weight for BM25 scores; ``1 - alpha`` weights the semantic
            scores (defaults to ``config.DEFAULT_HYBRID_ALPHA``).

    Returns:
        Tuple of (answer, citations). Citations hold one entry per distinct
        source among the selected chunks, in re-ranked order.
    """
    if top_k is None:
        top_k = config.DEFAULT_TOP_K
    if alpha is None:
        alpha = config.DEFAULT_HYBRID_ALPHA

    # Load index if not already loaded
    _load_bm25_index()
    if _bm25_index is None or _semantic_index is None:
        return "Hybrid search requires both BM25 and semantic indices. Please run preprocess.py with semantic embeddings.", []
    if not _texts:
        # Without chunk texts nothing can be returned; the score arrays below
        # would also be empty and the FAISS search would get k=0.
        return "BM25 index not loaded. Please run preprocess.py first.", []

    # Only positions that have both a text and a metadata entry are usable.
    n_docs = min(len(_texts), len(_metadata))

    # Lexical scores, scaled into [0, 1] when any document matches.
    tokenized_query = question.lower().split()
    bm25_scores = np.asarray(_bm25_index.get_scores(tokenized_query), dtype=np.float64)
    if bm25_scores.size and bm25_scores.max() > 0:
        bm25_scores = bm25_scores / bm25_scores.max()

    # Semantic scores for every indexed document. Scoring all documents (not
    # just the nearest few) lets a strong BM25 match with a lower semantic rank
    # keep its real similarity instead of an artificial 0. A flat index
    # compares the query against every stored vector anyway.
    embedding_generator = utils.EmbeddingGenerator()
    query_embedding = np.ascontiguousarray(
        embedding_generator.embed_text_openai([question]), dtype=np.float32
    ).reshape(1, -1)  # FAISS expects a C-contiguous float32 (n, d) matrix
    faiss.normalize_L2(query_embedding)

    # Sized like the BM25 scores so the weighted sum below always broadcasts.
    semantic_scores = np.zeros_like(bm25_scores)
    k_search = min(len(semantic_scores), _semantic_index.ntotal)
    if k_search > 0:
        distances, indices = _semantic_index.search(query_embedding, k_search)
        ids, sims = indices[0], distances[0]
        # FAISS pads missing results with id -1. Drop those and any id outside
        # the BM25 corpus.
        valid = (ids >= 0) & (ids < len(semantic_scores))
        semantic_scores[ids[valid]] = sims[valid]

    # Combine scores and keep the best candidates for re-ranking.
    hybrid_scores = alpha * bm25_scores + (1 - alpha) * semantic_scores
    top_indices = np.argsort(hybrid_scores)[::-1][:top_k * config.RERANK_MULTIPLIER]
    candidates = [
        {
            'text': _texts[idx],
            'hybrid_score': hybrid_scores[idx],
            'bm25_score': bm25_scores[idx],
            'semantic_score': semantic_scores[idx],
            'metadata': _metadata[idx],
            'idx': idx,
        }
        for idx in top_indices
        if idx < n_docs and hybrid_scores[idx] > 0
    ]

    # Final ranking using cross-encoder scores.
    if candidates:
        pairs = [[question, cand['text']] for cand in candidates]
        cross_scores = cross_encoder.predict(pairs)
        for cand, score in zip(candidates, cross_scores):
            # Plain float: numpy float32 is not JSON-serialisable in citations.
            cand['cross_score'] = float(score)
        candidates = sorted(candidates, key=lambda c: c['cross_score'], reverse=True)[:top_k]

    # Build context. Return early before calling the LLM if nothing was found.
    context = "\n\n---\n\n".join([chunk['text'] for chunk in candidates])
    if not context:
        return "No relevant documents found for your query.", []

    # Collect one citation per distinct source, in ranked order.
    citations = []
    sources_seen = set()
    for chunk in candidates:
        chunk_meta = chunk['metadata']
        source = chunk_meta.get('source')
        if source in sources_seen:
            continue
        sources_seen.add(source)
        citation = {
            'source': source,
            'type': chunk_meta.get('type'),
            'hybrid_score': round(float(chunk['hybrid_score']), 3),
            'rerank_score': round(chunk.get('cross_score', 0.0), 3),
        }
        if citation['type'] == 'pdf':
            citation['path'] = chunk_meta.get('path', '')
        else:
            citation['url'] = chunk_meta.get('url', '')
        citations.append(citation)

    # Generate answer
    prompt = f"""Using the following retrieved passages, answer the question:
{context}
Question: {question}
Provide a clear, detailed answer based on the information in the passages."""
    # For GPT-5, temperature must be default (1.0)
    response = client.chat.completions.create(
        model=config.OPENAI_CHAT_MODEL,
        messages=[
            {"role": "system", "content": "You are a safety expert. Answer questions accurately using the provided passages."},
            {"role": "user", "content": prompt},
        ],
        max_completion_tokens=config.DEFAULT_MAX_TOKENS,
    )
    # The SDK types message.content as Optional[str]; always return a str.
    answer = response.choices[0].message.content or ""
    return answer, citations
if __name__ == "__main__":
    # Smoke test: run a few sample questions through BM25 retrieval + re-ranking.
    test_questions = [
        "lockout tagout procedures",
        "machine guard requirements OSHA",
        "robot safety collaborative workspace",
    ]
    for q in test_questions:
        print(f"\nQuestion: {q}")
        answer, citations = query(q)
        # Guard against a None answer so the preview slice cannot raise.
        print(f"Answer: {(answer or '')[:200]}...")
        print(f"Citations: {citations}")
        print("-" * 50)