"""
Dense Passage Retrieval (DPR) query module.
Uses bi-encoder for retrieval and cross-encoder for re-ranking.
"""
import pickle
import logging
from typing import List, Tuple, Optional
import numpy as np
import faiss
from sentence_transformers import SentenceTransformer, CrossEncoder
from openai import OpenAI
from config import *
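# Names expected from config (all are used in this module):
# OPENAI_API_KEY, OPENAI_CHAT_MODEL, SENTENCE_TRANSFORMER_MODEL,
# CROSS_ENCODER_MODEL, DEVICE, DPR_FAISS_INDEX, DPR_METADATA,
# DEFAULT_TOP_K, RERANK_MULTIPLIER, MIN_RELEVANCE_SCORE, DEFAULT_MAX_TOKENS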
logger = logging.getLogger(__name__)
class DPRRetriever:
"""Dense Passage Retrieval with cross-encoder re-ranking."""
def __init__(self):
self.client = OpenAI(api_key=OPENAI_API_KEY)
self.bi_encoder = None
self.cross_encoder = None
self.index = None
self.metadata = None
self._load_models()
self._load_index()
def _load_models(self):
"""Load bi-encoder and cross-encoder models."""
try:
logger.info("Loading DPR models...")
            # CrossEncoder has no .to() in older sentence-transformers
            # releases, so pass the device at construction time instead.
            self.bi_encoder = SentenceTransformer(SENTENCE_TRANSFORMER_MODEL, device=DEVICE)
            self.cross_encoder = CrossEncoder(CROSS_ENCODER_MODEL, device=DEVICE)
logger.info("β DPR models loaded successfully")
except Exception as e:
logger.error(f"Error loading DPR models: {e}")
raise
def _load_index(self):
"""Load FAISS index and metadata."""
try:
if DPR_FAISS_INDEX.exists() and DPR_METADATA.exists():
logger.info("Loading DPR index and metadata...")
# Load FAISS index
self.index = faiss.read_index(str(DPR_FAISS_INDEX))
                # Load chunk metadata
                with open(DPR_METADATA, 'rb') as f:
                    self.metadata = pickle.load(f)
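                # Each metadata entry is expected to be a dict of the form
                # {'text': str, 'metadata': {'source': str, 'type': str, ...}}
                # (inferred from how chunks are consumed below).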
logger.info(f"β Loaded DPR index with {len(self.metadata)} chunks")
else:
logger.warning("DPR index not found. Run preprocess.py first.")
except Exception as e:
logger.error(f"Error loading DPR index: {e}")
raise
def retrieve_candidates(self, question: str, top_k: int = DEFAULT_TOP_K) -> List[Tuple[str, float, dict]]:
"""Retrieve candidate passages using bi-encoder."""
if self.index is None or self.metadata is None:
raise ValueError("DPR index not loaded. Run preprocess.py first.")
try:
# Encode question with bi-encoder
question_embedding = self.bi_encoder.encode([question], convert_to_numpy=True)
# Normalize for cosine similarity
faiss.normalize_L2(question_embedding)
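            # NOTE: normalizing the query assumes the index was built over
            # L2-normalized embeddings with an inner-product index (e.g.
            # IndexFlatIP), so the returned scores are cosine similarities.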
# Search FAISS index
# Retrieve more candidates for re-ranking
retrieve_k = min(top_k * RERANK_MULTIPLIER, len(self.metadata))
scores, indices = self.index.search(question_embedding, retrieve_k)
# Prepare candidates
candidates = []
            for score, idx in zip(scores[0], indices[0]):
                # FAISS pads results with -1 when fewer than retrieve_k
                # vectors exist, so guard against negative indices too.
                if 0 <= idx < len(self.metadata):
                    chunk_data = self.metadata[idx]
                    candidates.append((
                        chunk_data['text'],
                        float(score),
                        chunk_data['metadata']
                    ))
logger.info(f"Retrieved {len(candidates)} candidates for re-ranking")
return candidates
except Exception as e:
logger.error(f"Error in candidate retrieval: {e}")
raise
def rerank_candidates(self, question: str, candidates: List[Tuple[str, float, dict]],
top_k: int = DEFAULT_TOP_K) -> List[Tuple[str, float, dict]]:
"""Re-rank candidates using cross-encoder."""
if not candidates:
return []
try:
# Prepare pairs for cross-encoder
pairs = [(question, candidate[0]) for candidate in candidates]
# Get cross-encoder scores
cross_scores = self.cross_encoder.predict(pairs)
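            # Cross-encoder scores are model-dependent (often raw logits),
            # so MIN_RELEVANCE_SCORE must be calibrated to the chosen model.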
# Combine with candidate data and re-sort
reranked = []
for i, (text, bi_score, metadata) in enumerate(candidates):
cross_score = float(cross_scores[i])
# Filter by minimum relevance score
if cross_score >= MIN_RELEVANCE_SCORE:
reranked.append((text, cross_score, metadata))
# Sort by cross-encoder score (descending)
reranked.sort(key=lambda x: x[1], reverse=True)
# Return top-k
final_results = reranked[:top_k]
logger.info(f"Re-ranked to {len(final_results)} final results")
return final_results
except Exception as e:
logger.error(f"Error in re-ranking: {e}")
            # Fall back to bi-encoder ordering; note these carry bi-encoder
            # similarity scores, which are on a different scale than
            # cross-encoder scores.
            return candidates[:top_k]
def generate_answer(self, question: str, context_chunks: List[Tuple[str, float, dict]]) -> str:
"""Generate answer using GPT with retrieved context."""
if not context_chunks:
return "I couldn't find relevant information to answer your question."
try:
# Prepare context
context_parts = []
for i, (text, score, metadata) in enumerate(context_chunks, 1):
source = metadata.get('source', 'Unknown')
context_parts.append(f"[Context {i}] Source: {source}\n{text}")
context = "\n\n".join(context_parts)
# Create system message
system_message = (
"You are a helpful assistant specialized in occupational safety and health. "
"Answer questions based only on the provided context. "
"If the context doesn't contain enough information, say so clearly. "
"Always cite the source when referencing information."
)
# Create user message
user_message = f"Context:\n{context}\n\nQuestion: {question}"
# Generate response
# For GPT-5, temperature must be default (1.0)
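            # max_completion_tokens is the newer Chat Completions parameter;
            # max_tokens is deprecated for the newest models.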
response = self.client.chat.completions.create(
model=OPENAI_CHAT_MODEL,
messages=[
{"role": "system", "content": system_message},
{"role": "user", "content": user_message}
],
max_completion_tokens=DEFAULT_MAX_TOKENS
)
            content = response.choices[0].message.content
            # message.content can be None in edge cases; guard before strip()
            return content.strip() if content else "I couldn't generate an answer from the retrieved context."
except Exception as e:
logger.error(f"Error generating answer: {e}")
return "I apologize, but I encountered an error while generating the answer."
# Global retriever instance
_retriever = None
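# Note: this lazy singleton is not thread-safe; concurrent first calls
# to get_retriever() could construct two DPRRetriever instances.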
def get_retriever() -> DPRRetriever:
"""Get or create global DPR retriever instance."""
global _retriever
if _retriever is None:
_retriever = DPRRetriever()
return _retriever
def query(question: str, image_path: Optional[str] = None, top_k: int = DEFAULT_TOP_K) -> Tuple[str, List[dict]]:
"""
Main DPR query function with unified signature.
Args:
question: User question
image_path: Optional image path (not used in DPR but kept for consistency)
top_k: Number of top results to retrieve
Returns:
Tuple of (answer, citations)
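
    Example (assumes preprocess.py has already built the index):
        answer, citations = query("What are the general requirements for machine guarding?")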
"""
try:
retriever = get_retriever()
# Step 1: Retrieve candidates with bi-encoder
candidates = retriever.retrieve_candidates(question, top_k)
if not candidates:
return "I couldn't find any relevant information for your question.", []
# Step 2: Re-rank with cross-encoder
reranked_candidates = retriever.rerank_candidates(question, candidates, top_k)
# Step 3: Generate answer
answer = retriever.generate_answer(question, reranked_candidates)
# Step 4: Prepare citations
citations = []
for i, (text, score, metadata) in enumerate(reranked_candidates, 1):
citations.append({
'rank': i,
'text': text,
'score': float(score),
'source': metadata.get('source', 'Unknown'),
'type': metadata.get('type', 'unknown'),
'method': 'dpr'
})
logger.info(f"DPR query completed. Retrieved {len(citations)} citations.")
return answer, citations
except Exception as e:
logger.error(f"Error in DPR query: {e}")
error_message = "I apologize, but I encountered an error while processing your question with DPR."
return error_message, []
def query_with_details(question: str, image_path: Optional[str] = None,
top_k: int = DEFAULT_TOP_K) -> Tuple[str, List[dict], List[Tuple]]:
"""
DPR query function that returns detailed chunk information (for compatibility).
Returns:
Tuple of (answer, citations, chunks)
"""
answer, citations = query(question, image_path, top_k)
# Convert citations to chunk format for backward compatibility
chunks = []
for citation in citations:
chunks.append((
f"Rank {citation['rank']} (Score: {citation['score']:.3f})",
citation['score'],
citation['text'],
citation['source']
))
return answer, citations, chunks
if __name__ == "__main__":
# Test the DPR system
test_question = "What are the general requirements for machine guarding?"
print("Testing DPR retrieval system...")
print(f"Question: {test_question}")
print("-" * 50)
try:
answer, citations = query(test_question)
print("Answer:")
print(answer)
print("\nCitations:")
for citation in citations:
print(f"- {citation['source']} (Score: {citation['score']:.3f})")
except Exception as e:
print(f"Error during testing: {e}") |