""" Dense Passage Retrieval (DPR) query module. Uses bi-encoder for retrieval and cross-encoder for re-ranking. """ import pickle import logging from typing import List, Tuple, Optional import numpy as np import faiss from sentence_transformers import SentenceTransformer, CrossEncoder from openai import OpenAI from config import * logger = logging.getLogger(__name__) class DPRRetriever: """Dense Passage Retrieval with cross-encoder re-ranking.""" def __init__(self): self.client = OpenAI(api_key=OPENAI_API_KEY) self.bi_encoder = None self.cross_encoder = None self.index = None self.metadata = None self._load_models() self._load_index() def _load_models(self): """Load bi-encoder and cross-encoder models.""" try: logger.info("Loading DPR models...") self.bi_encoder = SentenceTransformer(SENTENCE_TRANSFORMER_MODEL) self.cross_encoder = CrossEncoder(CROSS_ENCODER_MODEL) if DEVICE == "cuda": self.bi_encoder = self.bi_encoder.to(DEVICE) self.cross_encoder = self.cross_encoder.to(DEVICE) logger.info("✓ DPR models loaded successfully") except Exception as e: logger.error(f"Error loading DPR models: {e}") raise def _load_index(self): """Load FAISS index and metadata.""" try: if DPR_FAISS_INDEX.exists() and DPR_METADATA.exists(): logger.info("Loading DPR index and metadata...") # Load FAISS index self.index = faiss.read_index(str(DPR_FAISS_INDEX)) # Load metadata with open(DPR_METADATA, 'rb') as f: data = pickle.load(f) self.metadata = data logger.info(f"✓ Loaded DPR index with {len(self.metadata)} chunks") else: logger.warning("DPR index not found. Run preprocess.py first.") except Exception as e: logger.error(f"Error loading DPR index: {e}") raise def retrieve_candidates(self, question: str, top_k: int = DEFAULT_TOP_K) -> List[Tuple[str, float, dict]]: """Retrieve candidate passages using bi-encoder.""" if self.index is None or self.metadata is None: raise ValueError("DPR index not loaded. 
    def retrieve_candidates(self, question: str,
                            top_k: int = DEFAULT_TOP_K) -> List[Tuple[str, float, dict]]:
        """Retrieve candidate passages using the bi-encoder."""
        if self.index is None or self.metadata is None:
            raise ValueError("DPR index not loaded. Run preprocess.py first.")
        try:
            # Encode the question with the bi-encoder
            question_embedding = self.bi_encoder.encode([question], convert_to_numpy=True)

            # Normalize so inner-product search equals cosine similarity
            faiss.normalize_L2(question_embedding)

            # Search the FAISS index, over-retrieving candidates for re-ranking
            retrieve_k = min(top_k * RERANK_MULTIPLIER, len(self.metadata))
            scores, indices = self.index.search(question_embedding, retrieve_k)

            # FAISS pads missing results with -1, so guard both bounds
            candidates = []
            for score, idx in zip(scores[0], indices[0]):
                if 0 <= idx < len(self.metadata):
                    chunk_data = self.metadata[idx]
                    candidates.append((
                        chunk_data['text'],
                        float(score),
                        chunk_data['metadata']
                    ))

            logger.info(f"Retrieved {len(candidates)} candidates for re-ranking")
            return candidates
        except Exception as e:
            logger.error(f"Error in candidate retrieval: {e}")
            raise

    def rerank_candidates(self, question: str,
                          candidates: List[Tuple[str, float, dict]],
                          top_k: int = DEFAULT_TOP_K) -> List[Tuple[str, float, dict]]:
        """Re-rank candidates using the cross-encoder."""
        if not candidates:
            return []
        try:
            # Score each (question, passage) pair with the cross-encoder
            pairs = [(question, candidate[0]) for candidate in candidates]
            cross_scores = self.cross_encoder.predict(pairs)

            # Keep only candidates at or above the minimum relevance score
            reranked = []
            for (text, bi_score, metadata), cross_score in zip(candidates, cross_scores):
                cross_score = float(cross_score)
                if cross_score >= MIN_RELEVANCE_SCORE:
                    reranked.append((text, cross_score, metadata))

            # Sort by cross-encoder score (descending) and keep top-k
            reranked.sort(key=lambda x: x[1], reverse=True)
            final_results = reranked[:top_k]
            logger.info(f"Re-ranked to {len(final_results)} final results")
            return final_results
        except Exception as e:
            logger.error(f"Error in re-ranking: {e}")
            # Fall back to the bi-encoder ordering
            return candidates[:top_k]

    def generate_answer(self, question: str,
                        context_chunks: List[Tuple[str, float, dict]]) -> str:
        """Generate an answer with GPT, grounded in the retrieved context."""
        if not context_chunks:
            return "I couldn't find relevant information to answer your question."
        try:
            # Build the context block, labeling each chunk with its source
            context_parts = []
            for i, (text, score, metadata) in enumerate(context_chunks, 1):
                source = metadata.get('source', 'Unknown')
                context_parts.append(f"[Context {i}] Source: {source}\n{text}")
            context = "\n\n".join(context_parts)

            system_message = (
                "You are a helpful assistant specialized in occupational safety and health. "
                "Answer questions based only on the provided context. "
                "If the context doesn't contain enough information, say so clearly. "
                "Always cite the source when referencing information."
            )
            user_message = f"Context:\n{context}\n\nQuestion: {question}"

            # Generate the response. For GPT-5, temperature must stay at its
            # default (1.0), so it is not set here.
            response = self.client.chat.completions.create(
                model=OPENAI_CHAT_MODEL,
                messages=[
                    {"role": "system", "content": system_message},
                    {"role": "user", "content": user_message}
                ],
                max_completion_tokens=DEFAULT_MAX_TOKENS
            )
            return response.choices[0].message.content.strip()
        except Exception as e:
            logger.error(f"Error generating answer: {e}")
            return "I apologize, but I encountered an error while generating the answer."
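
# Note on MIN_RELEVANCE_SCORE: MS MARCO-style cross-encoders (a common choice
# for CROSS_ENCODER_MODEL) output raw logits rather than probabilities, so the
# threshold must be calibrated to the specific model. If a [0, 1] scale is
# preferred, a sigmoid can be applied first; a sketch of that variant, as an
# assumption rather than what rerank_candidates does above:
#
#     import numpy as np
#     probs = 1.0 / (1.0 + np.exp(-np.asarray(cross_scores)))  # sigmoid
#     reranked = [(t, float(p), m)
#                 for (t, _, m), p in zip(candidates, probs) if p >= 0.5]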

# Global retriever instance
_retriever = None


def get_retriever() -> DPRRetriever:
    """Get or create the global DPR retriever instance."""
    global _retriever
    if _retriever is None:
        _retriever = DPRRetriever()
    return _retriever


def query(question: str, image_path: Optional[str] = None,
          top_k: int = DEFAULT_TOP_K) -> Tuple[str, List[dict]]:
    """
    Main DPR query function with unified signature.

    Args:
        question: User question
        image_path: Optional image path (not used in DPR but kept for
            consistency with other retrieval backends)
        top_k: Number of top results to retrieve

    Returns:
        Tuple of (answer, citations)
    """
    try:
        retriever = get_retriever()

        # Step 1: Retrieve candidates with the bi-encoder
        candidates = retriever.retrieve_candidates(question, top_k)
        if not candidates:
            return "I couldn't find any relevant information for your question.", []

        # Step 2: Re-rank with the cross-encoder
        reranked_candidates = retriever.rerank_candidates(question, candidates, top_k)

        # Step 3: Generate the answer
        answer = retriever.generate_answer(question, reranked_candidates)

        # Step 4: Prepare citations
        citations = []
        for i, (text, score, metadata) in enumerate(reranked_candidates, 1):
            citations.append({
                'rank': i,
                'text': text,
                'score': float(score),
                'source': metadata.get('source', 'Unknown'),
                'type': metadata.get('type', 'unknown'),
                'method': 'dpr'
            })

        logger.info(f"DPR query completed. Retrieved {len(citations)} citations.")
        return answer, citations
    except Exception as e:
        logger.error(f"Error in DPR query: {e}")
        error_message = ("I apologize, but I encountered an error while "
                         "processing your question with DPR.")
        return error_message, []


def query_with_details(question: str, image_path: Optional[str] = None,
                       top_k: int = DEFAULT_TOP_K) -> Tuple[str, List[dict], List[Tuple]]:
    """
    DPR query function that returns detailed chunk information (for compatibility).

    Returns:
        Tuple of (answer, citations, chunks)
    """
    answer, citations = query(question, image_path, top_k)

    # Convert citations to the legacy chunk format
    chunks = []
    for citation in citations:
        chunks.append((
            f"Rank {citation['rank']} (Score: {citation['score']:.3f})",
            citation['score'],
            citation['text'],
            citation['source']
        ))

    return answer, citations, chunks


if __name__ == "__main__":
    # Test the DPR system
    test_question = "What are the general requirements for machine guarding?"

    print("Testing DPR retrieval system...")
    print(f"Question: {test_question}")
    print("-" * 50)

    try:
        answer, citations = query(test_question)
        print("Answer:")
        print(answer)
        print("\nCitations:")
        for citation in citations:
            print(f"- {citation['source']} (Score: {citation['score']:.3f})")
    except Exception as e:
        print(f"Error during testing: {e}")
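
# Example usage from another module (hypothetical: assumes this file is saved
# as dpr_query.py, a name this module itself does not fix):
#
#     from dpr_query import query
#
#     answer, citations = query("What PPE is required for welding?")
#     print(answer)
#     for c in citations:
#         print(f"{c['rank']}. {c['source']} (score: {c['score']:.3f})")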