""" | |
Dense Passage Retrieval (DPR) query module. | |
Uses bi-encoder for retrieval and cross-encoder for re-ranking. | |
""" | |
import pickle | |
import logging | |
from typing import List, Tuple, Optional | |
import numpy as np | |
import faiss | |
from sentence_transformers import SentenceTransformer, CrossEncoder | |
from openai import OpenAI | |
from config import * | |
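# Names this module expects the star import to provide (all referenced below):
# OPENAI_API_KEY, OPENAI_CHAT_MODEL, SENTENCE_TRANSFORMER_MODEL,
# CROSS_ENCODER_MODEL, DEVICE, DPR_FAISS_INDEX, DPR_METADATA,
# DEFAULT_TOP_K, RERANK_MULTIPLIER, MIN_RELEVANCE_SCORE, DEFAULT_MAX_TOKENS.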
logger = logging.getLogger(__name__)


class DPRRetriever:
    """Dense Passage Retrieval with cross-encoder re-ranking."""

    def __init__(self):
        self.client = OpenAI(api_key=OPENAI_API_KEY)
        self.bi_encoder = None
        self.cross_encoder = None
        self.index = None
        self.metadata = None
        self._load_models()
        self._load_index()

    def _load_models(self):
        """Load bi-encoder and cross-encoder models."""
        try:
            logger.info("Loading DPR models...")
            # Pass the device at construction time: both SentenceTransformer
            # and CrossEncoder accept a `device` argument, whereas CrossEncoder
            # does not reliably expose .to() across sentence-transformers versions.
            self.bi_encoder = SentenceTransformer(SENTENCE_TRANSFORMER_MODEL, device=DEVICE)
            self.cross_encoder = CrossEncoder(CROSS_ENCODER_MODEL, device=DEVICE)
            logger.info("✓ DPR models loaded successfully")
        except Exception as e:
            logger.error(f"Error loading DPR models: {e}")
            raise

    def _load_index(self):
        """Load FAISS index and metadata."""
        try:
            if DPR_FAISS_INDEX.exists() and DPR_METADATA.exists():
                logger.info("Loading DPR index and metadata...")
                # Load FAISS index
                self.index = faiss.read_index(str(DPR_FAISS_INDEX))
                # Load metadata: a list of dicts with 'text' and 'metadata'
                # keys, as consumed by retrieve_candidates() below
                with open(DPR_METADATA, 'rb') as f:
                    self.metadata = pickle.load(f)
                logger.info(f"✓ Loaded DPR index with {len(self.metadata)} chunks")
            else:
                logger.warning("DPR index not found. Run preprocess.py first.")
        except Exception as e:
            logger.error(f"Error loading DPR index: {e}")
            raise

    def retrieve_candidates(self, question: str, top_k: int = DEFAULT_TOP_K) -> List[Tuple[str, float, dict]]:
        """Retrieve candidate passages using the bi-encoder."""
        if self.index is None or self.metadata is None:
            raise ValueError("DPR index not loaded. Run preprocess.py first.")
        try:
            # Encode the question with the bi-encoder; FAISS expects float32
            question_embedding = self.bi_encoder.encode([question], convert_to_numpy=True)
            question_embedding = question_embedding.astype(np.float32)
            # Normalize for cosine similarity
            faiss.normalize_L2(question_embedding)
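            # Note: this assumes the index was built over L2-normalized
            # embeddings with an inner-product metric (e.g. IndexFlatIP),
            # so that inner product equals cosine similarity.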
            # Search the FAISS index, retrieving extra candidates so the
            # cross-encoder has a larger pool to re-rank
            retrieve_k = min(top_k * RERANK_MULTIPLIER, len(self.metadata))
            scores, indices = self.index.search(question_embedding, retrieve_k)
            # Prepare candidates; FAISS pads with -1 when fewer than
            # retrieve_k results exist, so filter out-of-range indices
            candidates = []
            for score, idx in zip(scores[0], indices[0]):
                if 0 <= idx < len(self.metadata):
                    chunk_data = self.metadata[idx]
                    candidates.append((
                        chunk_data['text'],
                        float(score),
                        chunk_data['metadata']
                    ))
            logger.info(f"Retrieved {len(candidates)} candidates for re-ranking")
            return candidates
        except Exception as e:
            logger.error(f"Error in candidate retrieval: {e}")
            raise

    def rerank_candidates(self, question: str, candidates: List[Tuple[str, float, dict]],
                          top_k: int = DEFAULT_TOP_K) -> List[Tuple[str, float, dict]]:
        """Re-rank candidates using the cross-encoder."""
        if not candidates:
            return []
        try:
            # Score each (question, passage) pair with the cross-encoder
            pairs = [(question, candidate[0]) for candidate in candidates]
            cross_scores = self.cross_encoder.predict(pairs)
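            # Caveat: cross-encoder scores are model-dependent (MS MARCO-style
            # re-rankers emit unbounded logits rather than probabilities), so
            # MIN_RELEVANCE_SCORE should be calibrated to CROSS_ENCODER_MODEL.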
            # Combine with candidate data, filter, and re-sort
            reranked = []
            for i, (text, bi_score, metadata) in enumerate(candidates):
                cross_score = float(cross_scores[i])
                # Filter by minimum relevance score
                if cross_score >= MIN_RELEVANCE_SCORE:
                    reranked.append((text, cross_score, metadata))
            # Sort by cross-encoder score (descending) and keep the top-k
            reranked.sort(key=lambda x: x[1], reverse=True)
            final_results = reranked[:top_k]
            logger.info(f"Re-ranked to {len(final_results)} final results")
            return final_results
        except Exception as e:
            logger.error(f"Error in re-ranking: {e}")
            # Fall back to the bi-encoder ordering
            return candidates[:top_k]

    def generate_answer(self, question: str, context_chunks: List[Tuple[str, float, dict]]) -> str:
        """Generate an answer using GPT with the retrieved context."""
        if not context_chunks:
            return "I couldn't find relevant information to answer your question."
        try:
            # Assemble the numbered context block
            context_parts = []
            for i, (text, score, metadata) in enumerate(context_chunks, 1):
                source = metadata.get('source', 'Unknown')
                context_parts.append(f"[Context {i}] Source: {source}\n{text}")
            context = "\n\n".join(context_parts)
            # Create system message
            system_message = (
                "You are a helpful assistant specialized in occupational safety and health. "
                "Answer questions based only on the provided context. "
                "If the context doesn't contain enough information, say so clearly. "
                "Always cite the source when referencing information."
            )
            # Create user message
            user_message = f"Context:\n{context}\n\nQuestion: {question}"
            # Generate response
            # For GPT-5, temperature must be left at its default (1.0)
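            # Note: max_completion_tokens is the newer Chat Completions
            # parameter that replaces max_tokens for recent OpenAI models.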
            response = self.client.chat.completions.create(
                model=OPENAI_CHAT_MODEL,
                messages=[
                    {"role": "system", "content": system_message},
                    {"role": "user", "content": user_message}
                ],
                max_completion_tokens=DEFAULT_MAX_TOKENS
            )
            # content can be None (e.g. on a refusal), so guard before stripping
            return (response.choices[0].message.content or "").strip()
        except Exception as e:
            logger.error(f"Error generating answer: {e}")
            return "I apologize, but I encountered an error while generating the answer."


# Global retriever instance
_retriever = None


def get_retriever() -> DPRRetriever:
    """Get or create the global DPR retriever instance (lazy; not thread-safe)."""
    global _retriever
    if _retriever is None:
        _retriever = DPRRetriever()
    return _retriever


def query(question: str, image_path: Optional[str] = None, top_k: int = DEFAULT_TOP_K) -> Tuple[str, List[dict]]:
    """
    Main DPR query function with unified signature.

    Args:
        question: User question
        image_path: Optional image path (not used in DPR but kept for consistency)
        top_k: Number of top results to retrieve

    Returns:
        Tuple of (answer, citations)
    """
    try:
        retriever = get_retriever()
        # Step 1: Retrieve candidates with the bi-encoder
        candidates = retriever.retrieve_candidates(question, top_k)
        if not candidates:
            return "I couldn't find any relevant information for your question.", []
        # Step 2: Re-rank with the cross-encoder
        reranked_candidates = retriever.rerank_candidates(question, candidates, top_k)
        # Step 3: Generate the answer
        answer = retriever.generate_answer(question, reranked_candidates)
        # Step 4: Prepare citations
        citations = []
        for i, (text, score, metadata) in enumerate(reranked_candidates, 1):
            citations.append({
                'rank': i,
                'text': text,
                'score': float(score),
                'source': metadata.get('source', 'Unknown'),
                'type': metadata.get('type', 'unknown'),
                'method': 'dpr'
            })
        logger.info(f"DPR query completed. Retrieved {len(citations)} citations.")
        return answer, citations
    except Exception as e:
        logger.error(f"Error in DPR query: {e}")
        error_message = "I apologize, but I encountered an error while processing your question with DPR."
        return error_message, []


def query_with_details(question: str, image_path: Optional[str] = None,
                       top_k: int = DEFAULT_TOP_K) -> Tuple[str, List[dict], List[Tuple]]:
    """
    DPR query function that returns detailed chunk information (for compatibility).

    Returns:
        Tuple of (answer, citations, chunks)
    """
    answer, citations = query(question, image_path, top_k)
    # Convert citations to the chunk format for backward compatibility
    chunks = []
    for citation in citations:
        chunks.append((
            f"Rank {citation['rank']} (Score: {citation['score']:.3f})",
            citation['score'],
            citation['text'],
            citation['source']
        ))
    return answer, citations, chunks


if __name__ == "__main__":
    # Test the DPR system
    test_question = "What are the general requirements for machine guarding?"
    print("Testing DPR retrieval system...")
    print(f"Question: {test_question}")
    print("-" * 50)
    try:
        answer, citations = query(test_question)
        print("Answer:")
        print(answer)
        print("\nCitations:")
        for citation in citations:
            print(f"- {citation['source']} (Score: {citation['score']:.3f})")
    except Exception as e:
        print(f"Error during testing: {e}")