"""

BM25 keyword search with cross-encoder re-ranking and hybrid search support.

"""

import numpy as np
import faiss
from typing import Tuple, List, Optional
from openai import OpenAI
from sentence_transformers import CrossEncoder

import config
import utils

# Initialize models
client = OpenAI(api_key=config.OPENAI_API_KEY)
cross_encoder = CrossEncoder(config.CROSS_ENCODER_MODEL)

# Global variables for lazy loading
_bm25_index = None
_texts = None
_metadata = None
_semantic_index = None
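
# Expected on-disk layout for config.BM25_INDEX (inferred from the loader below,
# not a documented contract): either a bare BM25 object, or a dict such as
#   {'index': <BM25 object>,                           # or under the key 'bm25'
#    'texts': [<chunk objects with .text and .metadata>, ...],
#    'embeddings': <np.ndarray, shape (n_docs, dim)>}  # optional; enables hybrid search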

def _load_bm25_index():
    """Lazy load BM25 index and metadata."""
    global _bm25_index, _texts, _metadata, _semantic_index
    
    if _bm25_index is None:
        # Initialize defaults
        _texts = []
        _metadata = []
        _semantic_index = None
        try:
            import pickle
            
            if config.BM25_INDEX.exists():
                with open(config.BM25_INDEX, 'rb') as f:
                    bm25_data = pickle.load(f)
                
                if isinstance(bm25_data, dict):
                    _bm25_index = bm25_data.get('index') or bm25_data.get('bm25')
                    chunks = bm25_data.get('texts', [])
                    # Build both lists in one pass so texts and metadata stay aligned
                    valid_chunks = [c for c in chunks
                                    if hasattr(c, 'text') and hasattr(c, 'metadata')]
                    _texts = [c.text for c in valid_chunks]
                    _metadata = [c.metadata for c in valid_chunks]
                    
                    # Load semantic embeddings if available for hybrid search
                    if 'embeddings' in bm25_data:
                        # FAISS requires a contiguous float32 array
                        semantic_embeddings = np.ascontiguousarray(
                            bm25_data['embeddings'], dtype=np.float32)
                        # Inner product over L2-normalized vectors = cosine similarity
                        dimension = semantic_embeddings.shape[1]
                        _semantic_index = faiss.IndexFlatIP(dimension)
                        faiss.normalize_L2(semantic_embeddings)
                        _semantic_index.add(semantic_embeddings)
                else:
                    _bm25_index = bm25_data
                    _texts = []
                    _metadata = []
                
                print(f"Loaded BM25 index with {len(_texts)} documents")
            else:
                print("BM25 index not found. Run preprocess.py first.")
                
        except Exception as e:
            print(f"Error loading BM25 index: {e}")
            _bm25_index = None
            _texts = []
            _metadata = []


def query(question: str, image_path: Optional[str] = None, top_k: Optional[int] = None) -> Tuple[str, List[dict]]:
    """

    Query using BM25 keyword search with re-ranking.

    

    Args:

        question: User's question

        image_path: Optional path to an image

        top_k: Number of relevant chunks to retrieve

    

    Returns:

        Tuple of (answer, citations)

    """
    if top_k is None:
        top_k = config.DEFAULT_TOP_K
    
    # Load index if not already loaded
    _load_bm25_index()
    
    if _bm25_index is None or len(_texts) == 0:
        return "BM25 index not loaded. Please run preprocess.py first.", []
    
    # Tokenize the query for BM25; this simple lowercase whitespace split must
    # match the tokenization used when the index was built
    tokenized_query = question.lower().split()
    
    # Get BM25 scores
    bm25_scores = _bm25_index.get_scores(tokenized_query)
    
    # Get top candidates (retrieve more for re-ranking)
    top_indices = np.argsort(bm25_scores)[::-1][:top_k * config.RERANK_MULTIPLIER]
    
    # Prepare candidates for re-ranking
    candidates = []
    for idx in top_indices:
        if idx < len(_texts) and bm25_scores[idx] > 0:
            candidates.append({
                'text': _texts[idx],
                'bm25_score': bm25_scores[idx],
                'metadata': _metadata[idx],
                'idx': idx
            })
    
    # Re-rank with cross-encoder
    if candidates:
        pairs = [[question, cand['text']] for cand in candidates]
        cross_scores = cross_encoder.predict(pairs)
        
        # Add cross-encoder scores and sort
        for i, score in enumerate(cross_scores):
            candidates[i]['cross_score'] = score
        
        candidates = sorted(candidates, key=lambda x: x['cross_score'], reverse=True)[:top_k]
    
    # Collect citations
    citations = []
    sources_seen = set()
    
    for chunk in candidates:
        chunk_meta = chunk['metadata']
        
        if chunk_meta['source'] not in sources_seen:
            citation = {
                'source': chunk_meta['source'],
                'type': chunk_meta['type'],
                'bm25_score': round(chunk['bm25_score'], 3),
                'rerank_score': round(chunk['cross_score'], 3)
            }
            
            if chunk_meta['type'] == 'pdf':
                citation['path'] = chunk_meta['path']
            else:
                citation['url'] = chunk_meta.get('url', '')
            
            citations.append(citation)
            sources_seen.add(chunk_meta['source'])
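
    # Example citation shape (illustrative values only):
    #   {'source': 'osha_machine_guarding.pdf', 'type': 'pdf',
    #    'bm25_score': 12.431, 'rerank_score': 4.872, 'path': '...'}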
    
    # Handle image if provided
    image_context = ""
    if image_path:
        try:
            classification = utils.classify_image(image_path)
            # classification is a string, not a dict
            image_context = f"\n\n[Image Analysis: The image appears to show a {classification}.]"
        except Exception as e:
            print(f"Error processing image: {e}")
    
    # Build context from retrieved chunks
    context = "\n\n---\n\n".join([chunk['text'] for chunk in candidates])
    
    if not context:
        return "No relevant documents found for your query.", []
    
    # Generate answer
    prompt = f"""Answer the following question using the retrieved documents:



Retrieved Documents:

{context}{image_context}



Question: {question}



Instructions:

1. Provide a comprehensive answer based on the retrieved documents

2. Mention specific details from the sources

3. If the documents don't fully answer the question, indicate what information is missing"""
    
    # For GPT-5, temperature must be default (1.0)
    response = client.chat.completions.create(
        model=config.OPENAI_CHAT_MODEL,
        messages=[
            {"role": "system", "content": "You are a technical expert on manufacturing safety and regulations. Provide accurate, detailed answers based on the retrieved documents."},
            {"role": "user", "content": prompt}
        ],
        max_completion_tokens=config.DEFAULT_MAX_TOKENS
    )
    
    answer = response.choices[0].message.content
    
    return answer, citations


def query_hybrid(question: str, top_k: Optional[int] = None, alpha: Optional[float] = None) -> Tuple[str, List[dict]]:
    """

    Hybrid search combining BM25 and semantic search.

    

    Args:

        question: User's question

        top_k: Number of relevant chunks to retrieve

        alpha: Weight for BM25 scores (1-alpha for semantic)

    

    Returns:

        Tuple of (answer, citations)

    """
    if top_k is None:
        top_k = config.DEFAULT_TOP_K
    if alpha is None:
        alpha = config.DEFAULT_HYBRID_ALPHA
    
    # Load index if not already loaded
    _load_bm25_index()
    
    if _bm25_index is None or _semantic_index is None:
        return "Hybrid search requires both BM25 and semantic indices. Please run preprocess.py with semantic embeddings.", []
    
    # Get BM25 scores
    tokenized_query = question.lower().split()
    bm25_scores = _bm25_index.get_scores(tokenized_query)
    
    # Normalize BM25 scores
    if bm25_scores.max() > 0:
        bm25_scores = bm25_scores / bm25_scores.max()
    
    # Get semantic scores using FAISS
    embedding_generator = utils.EmbeddingGenerator()
    query_embedding = embedding_generator.embed_text_openai([question]).astype(np.float32)
    query_embedding = query_embedding.reshape(1, -1)  # normalize_L2 and search need a 2D array
    faiss.normalize_L2(query_embedding)

    # Search the semantic index; for IndexFlatIP the returned "distances" are
    # inner-product similarities (higher is better)
    k_search = min(len(_texts), top_k * config.RERANK_MULTIPLIER)
    distances, indices = _semantic_index.search(query_embedding, k_search)
    
    # Create semantic scores array (FAISS pads missing results with index -1)
    semantic_scores = np.zeros(len(_texts))
    for idx, dist in zip(indices[0], distances[0]):
        if 0 <= idx < len(_texts):
            semantic_scores[idx] = dist
    
    # Combine scores
    hybrid_scores = alpha * bm25_scores + (1 - alpha) * semantic_scores
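    # Note on scales: max-normalized BM25 scores lie in [0, 1] while cosine
    # similarities lie in [-1, 1], so alpha also implicitly rebalances ranges;
    # e.g. alpha = 0.7 favors exact keyword overlap over semantic similarity.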
    
    # Get top candidates
    top_indices = np.argsort(hybrid_scores)[::-1][:top_k * config.RERANK_MULTIPLIER]
    
    # Prepare candidates
    candidates = []
    for idx in top_indices:
        if idx < len(_texts) and hybrid_scores[idx] > 0:
            candidates.append({
                'text': _texts[idx],
                'hybrid_score': hybrid_scores[idx],
                'bm25_score': bm25_scores[idx],
                'semantic_score': semantic_scores[idx],
                'metadata': _metadata[idx],
                'idx': idx
            })
    
    # Re-rank with cross-encoder
    if candidates:
        pairs = [[question, cand['text']] for cand in candidates]
        cross_scores = cross_encoder.predict(pairs)
        
        for i, score in enumerate(cross_scores):
            candidates[i]['cross_score'] = score
        
        # Final ranking using cross-encoder scores
        candidates = sorted(candidates, key=lambda x: x['cross_score'], reverse=True)[:top_k]
    
    # Collect citations
    citations = []
    sources_seen = set()
    
    for chunk in candidates:
        chunk_meta = chunk['metadata']
        
        if chunk_meta['source'] not in sources_seen:
            citation = {
                'source': chunk_meta['source'],
                'type': chunk_meta['type'],
                'hybrid_score': round(chunk['hybrid_score'], 3),
                'rerank_score': round(chunk.get('cross_score', 0), 3)
            }
            
            if chunk_meta['type'] == 'pdf':
                citation['path'] = chunk_meta['path']
            else:
                citation['url'] = chunk_meta.get('url', '')
            
            citations.append(citation)
            sources_seen.add(chunk_meta['source'])
    
    # Build context
    context = "\n\n---\n\n".join([chunk['text'] for chunk in candidates])
    
    if not context:
        return "No relevant documents found for your query.", []
    
    # Generate answer
    prompt = f"""Using the following retrieved passages, answer the question:



{context}



Question: {question}



Provide a clear, detailed answer based on the information in the passages."""
    
    # For GPT-5, temperature must be default (1.0)
    response = client.chat.completions.create(
        model=config.OPENAI_CHAT_MODEL,
        messages=[
            {"role": "system", "content": "You are a safety expert. Answer questions accurately using the provided passages."},
            {"role": "user", "content": prompt}
        ],
        max_completion_tokens=config.DEFAULT_MAX_TOKENS
    )
    
    answer = response.choices[0].message.content
    
    return answer, citations


if __name__ == "__main__":
    # Test BM25 query
    test_questions = [
        "lockout tagout procedures",
        "machine guard requirements OSHA",
        "robot safety collaborative workspace"
    ]
    
    for q in test_questions:
        print(f"\nQuestion: {q}")
        answer, citations = query(q)
        print(f"Answer: {answer[:200]}...")
        print(f"Citations: {citations}")
        print("-" * 50)