# sight_chat / query_dpr.py
"""
Dense Passage Retrieval (DPR) query module.
Uses bi-encoder for retrieval and cross-encoder for re-ranking.
"""
import pickle
import logging
from typing import List, Tuple, Optional
import numpy as np
import faiss
from sentence_transformers import SentenceTransformer, CrossEncoder
from openai import OpenAI
from config import *
logger = logging.getLogger(__name__)
class DPRRetriever:
"""Dense Passage Retrieval with cross-encoder re-ranking."""
def __init__(self):
self.client = OpenAI(api_key=OPENAI_API_KEY)
self.bi_encoder = None
self.cross_encoder = None
self.index = None
self.metadata = None
self._load_models()
self._load_index()
    def _load_models(self):
        """Load bi-encoder and cross-encoder models."""
        try:
            logger.info("Loading DPR models...")
            # Pass device= so each model loads directly on the configured
            # device (assumes DEVICE is a torch device string such as "cpu"
            # or "cuda"); CrossEncoder is not an nn.Module in every
            # sentence-transformers release, so calling .to(DEVICE) on it
            # is not reliable.
            self.bi_encoder = SentenceTransformer(SENTENCE_TRANSFORMER_MODEL, device=DEVICE)
            self.cross_encoder = CrossEncoder(CROSS_ENCODER_MODEL, device=DEVICE)
            logger.info("✓ DPR models loaded successfully")
        except Exception as e:
            logger.error(f"Error loading DPR models: {e}")
            raise

    def _load_index(self):
        """Load FAISS index and metadata."""
        try:
            if DPR_FAISS_INDEX.exists() and DPR_METADATA.exists():
                logger.info("Loading DPR index and metadata...")
                # Load FAISS index
                self.index = faiss.read_index(str(DPR_FAISS_INDEX))
                # Load metadata
                with open(DPR_METADATA, 'rb') as f:
                    data = pickle.load(f)
                self.metadata = data
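                # Each metadata entry is expected to be a dict with 'text'
                # and 'metadata' keys, as produced by preprocess.py and
                # consumed in retrieve_candidates() below.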
logger.info(f"βœ“ Loaded DPR index with {len(self.metadata)} chunks")
else:
logger.warning("DPR index not found. Run preprocess.py first.")
except Exception as e:
logger.error(f"Error loading DPR index: {e}")
raise
    def retrieve_candidates(self, question: str, top_k: int = DEFAULT_TOP_K) -> List[Tuple[str, float, dict]]:
        """Retrieve candidate passages using bi-encoder."""
        if self.index is None or self.metadata is None:
            raise ValueError("DPR index not loaded. Run preprocess.py first.")
        try:
            # Encode question with bi-encoder
            question_embedding = self.bi_encoder.encode([question], convert_to_numpy=True)
            # Normalize for cosine similarity
            faiss.normalize_L2(question_embedding)
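            # Normalizing the query only yields cosine similarity if the
            # index is an inner-product index (e.g. IndexFlatIP) built over
            # L2-normalized passage embeddings; that is an assumption about
            # what preprocess.py produced, not something enforced here.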
            # Search the FAISS index, retrieving extra candidates for re-ranking
            retrieve_k = min(top_k * RERANK_MULTIPLIER, len(self.metadata))
            scores, indices = self.index.search(question_embedding, retrieve_k)
            # Prepare candidates; FAISS pads missing results with -1, which
            # would silently index from the end of the list, so guard idx >= 0
            candidates = []
            for score, idx in zip(scores[0], indices[0]):
                if 0 <= idx < len(self.metadata):
                    chunk_data = self.metadata[idx]
                    candidates.append((
                        chunk_data['text'],
                        float(score),
                        chunk_data['metadata']
                    ))
            logger.info(f"Retrieved {len(candidates)} candidates for re-ranking")
            return candidates
        except Exception as e:
            logger.error(f"Error in candidate retrieval: {e}")
            raise

    def rerank_candidates(self, question: str, candidates: List[Tuple[str, float, dict]],
                          top_k: int = DEFAULT_TOP_K) -> List[Tuple[str, float, dict]]:
        """Re-rank candidates using cross-encoder."""
        if not candidates:
            return []
        try:
            # Prepare (question, passage) pairs for the cross-encoder
            pairs = [(question, candidate[0]) for candidate in candidates]
            # Get cross-encoder scores
            cross_scores = self.cross_encoder.predict(pairs)
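            # Note: depending on the model's configured activation, these
            # scores may be raw logits or sigmoid probabilities;
            # MIN_RELEVANCE_SCORE must be chosen on the same scale.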
            # Combine with candidate data and re-sort
            reranked = []
            for i, (text, bi_score, metadata) in enumerate(candidates):
                cross_score = float(cross_scores[i])
                # Filter by minimum relevance score
                if cross_score >= MIN_RELEVANCE_SCORE:
                    reranked.append((text, cross_score, metadata))
            # Sort by cross-encoder score (descending)
            reranked.sort(key=lambda x: x[1], reverse=True)
            # Return top-k
            final_results = reranked[:top_k]
            logger.info(f"Re-ranked to {len(final_results)} final results")
            return final_results
        except Exception as e:
            logger.error(f"Error in re-ranking: {e}")
            # Fall back to bi-encoder results
            return candidates[:top_k]

    def generate_answer(self, question: str, context_chunks: List[Tuple[str, float, dict]]) -> str:
        """Generate answer using GPT with retrieved context."""
        if not context_chunks:
            return "I couldn't find relevant information to answer your question."
        try:
            # Prepare context
            context_parts = []
            for i, (text, score, metadata) in enumerate(context_chunks, 1):
                source = metadata.get('source', 'Unknown')
                context_parts.append(f"[Context {i}] Source: {source}\n{text}")
            context = "\n\n".join(context_parts)
            # Create system message
            system_message = (
                "You are a helpful assistant specialized in occupational safety and health. "
                "Answer questions based only on the provided context. "
                "If the context doesn't contain enough information, say so clearly. "
                "Always cite the source when referencing information."
            )
            # Create user message
            user_message = f"Context:\n{context}\n\nQuestion: {question}"
            # Generate response
            # For GPT-5, temperature must be default (1.0)
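            # max_completion_tokens is the newer OpenAI parameter replacing
            # max_tokens on recent models; DEFAULT_MAX_TOKENS is assumed to
            # be defined in config.py.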
            response = self.client.chat.completions.create(
                model=OPENAI_CHAT_MODEL,
                messages=[
                    {"role": "system", "content": system_message},
                    {"role": "user", "content": user_message}
                ],
                max_completion_tokens=DEFAULT_MAX_TOKENS
            )
            # message.content can be None (e.g. on a refusal), so guard before strip()
            return (response.choices[0].message.content or "").strip()
        except Exception as e:
            logger.error(f"Error generating answer: {e}")
            return "I apologize, but I encountered an error while generating the answer."

# Global retriever instance
_retriever = None
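# Loading the retriever lazily means the models and FAISS index are built
# once per process rather than on every query.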
def get_retriever() -> DPRRetriever:
"""Get or create global DPR retriever instance."""
global _retriever
if _retriever is None:
_retriever = DPRRetriever()
return _retriever
def query(question: str, image_path: Optional[str] = None, top_k: int = DEFAULT_TOP_K) -> Tuple[str, List[dict]]:
"""
Main DPR query function with unified signature.
Args:
question: User question
image_path: Optional image path (not used in DPR but kept for consistency)
top_k: Number of top results to retrieve
Returns:
Tuple of (answer, citations)
"""
try:
retriever = get_retriever()
# Step 1: Retrieve candidates with bi-encoder
candidates = retriever.retrieve_candidates(question, top_k)
if not candidates:
return "I couldn't find any relevant information for your question.", []
# Step 2: Re-rank with cross-encoder
reranked_candidates = retriever.rerank_candidates(question, candidates, top_k)
# Step 3: Generate answer
answer = retriever.generate_answer(question, reranked_candidates)
# Step 4: Prepare citations
citations = []
for i, (text, score, metadata) in enumerate(reranked_candidates, 1):
citations.append({
'rank': i,
'text': text,
'score': float(score),
'source': metadata.get('source', 'Unknown'),
'type': metadata.get('type', 'unknown'),
'method': 'dpr'
})
logger.info(f"DPR query completed. Retrieved {len(citations)} citations.")
return answer, citations
except Exception as e:
logger.error(f"Error in DPR query: {e}")
error_message = "I apologize, but I encountered an error while processing your question with DPR."
return error_message, []
def query_with_details(question: str, image_path: Optional[str] = None,
                       top_k: int = DEFAULT_TOP_K) -> Tuple[str, List[dict], List[Tuple]]:
    """
    DPR query function that returns detailed chunk information (for compatibility).

    Returns:
        Tuple of (answer, citations, chunks)
    """
    answer, citations = query(question, image_path, top_k)
    # Convert citations to chunk format for backward compatibility
    chunks = []
    for citation in citations:
        chunks.append((
            f"Rank {citation['rank']} (Score: {citation['score']:.3f})",
            citation['score'],
            citation['text'],
            citation['source']
        ))
    return answer, citations, chunks

if __name__ == "__main__":
    # Test the DPR system
    test_question = "What are the general requirements for machine guarding?"
    print("Testing DPR retrieval system...")
    print(f"Question: {test_question}")
    print("-" * 50)
    try:
        answer, citations = query(test_question)
        print("Answer:")
        print(answer)
        print("\nCitations:")
        for citation in citations:
            print(f"- {citation['source']} (Score: {citation['score']:.3f})")
    except Exception as e:
        print(f"Error during testing: {e}")