""" Context stuffing query module. Loads full documents and uses heuristics to select relevant content. """ import pickle import logging import re from typing import List, Tuple, Optional, Dict, Any from openai import OpenAI import tiktoken from config import * logger = logging.getLogger(__name__) class ContextStuffingRetriever: """Context stuffing with heuristic document selection.""" def __init__(self): self.client = OpenAI(api_key=OPENAI_API_KEY) self.encoding = tiktoken.get_encoding("cl100k_base") self.documents = None self._load_documents() def _load_documents(self): """Load full documents for context stuffing.""" try: if CONTEXT_DOCS.exists(): logger.info("Loading documents for context stuffing...") with open(CONTEXT_DOCS, 'rb') as f: data = pickle.load(f) if isinstance(data, list) and len(data) > 0: # Handle both old format (list of chunks) and new format (list of DocumentChunk objects) if hasattr(data[0], 'text'): # New format with DocumentChunk objects self.documents = [] for chunk in data: self.documents.append({ 'text': chunk.text, 'metadata': chunk.metadata, 'chunk_id': chunk.chunk_id }) else: # Old format with dict objects self.documents = data logger.info(f"✓ Loaded {len(self.documents)} documents for context stuffing") else: logger.warning("No documents found in context stuffing file") self.documents = [] else: logger.warning("Context stuffing documents not found. Run preprocess.py first.") self.documents = [] except Exception as e: logger.error(f"Error loading context stuffing documents: {e}") self.documents = [] def _calculate_keyword_score(self, text: str, question: str) -> float: """Calculate keyword overlap score between text and question.""" # Simple keyword matching heuristic question_words = set(re.findall(r'\w+', question.lower())) text_words = set(re.findall(r'\w+', text.lower())) if not question_words: return 0.0 overlap = len(question_words & text_words) return overlap / len(question_words) def _calculate_section_relevance(self, text: str, question: str) -> float: """Calculate section relevance using multiple heuristics.""" score = 0.0 # Keyword overlap score (weight: 0.5) keyword_score = self._calculate_keyword_score(text, question) score += 0.5 * keyword_score # Length penalty (prefer medium-length sections) text_length = len(text.split()) optimal_length = 200 # words length_score = min(1.0, text_length / optimal_length) if text_length < optimal_length else max(0.1, optimal_length / text_length) score += 0.2 * length_score # Header/title bonus (if text starts with common header patterns) if re.match(r'^#+\s|^\d+\.\s|^[A-Z\s]{3,20}:', text.strip()): score += 0.1 # Question type specific bonuses question_lower = question.lower() text_lower = text.lower() if any(word in question_lower for word in ['what', 'define', 'definition']): if any(phrase in text_lower for phrase in ['means', 'defined as', 'definition', 'refers to']): score += 0.2 if any(word in question_lower for word in ['how', 'procedure', 'steps']): if any(phrase in text_lower for phrase in ['step', 'procedure', 'process', 'method']): score += 0.2 if any(word in question_lower for word in ['requirement', 'shall', 'must']): if any(phrase in text_lower for phrase in ['shall', 'must', 'required', 'requirement']): score += 0.2 return min(1.0, score) # Cap at 1.0 def select_relevant_documents(self, question: str, max_tokens: int = None) -> List[Dict[str, Any]]: """Select most relevant documents using heuristics.""" if not self.documents: return [] if max_tokens is None: max_tokens = MAX_CONTEXT_TOKENS # Score 

        # Score all documents
        scored_docs = []
        for doc in self.documents:
            text = doc.get('text', '')
            if text.strip():
                relevance_score = self._calculate_section_relevance(text, question)
                doc_info = {
                    'text': text,
                    'metadata': doc.get('metadata', {}),
                    'score': relevance_score,
                    'token_count': len(self.encoding.encode(text))
                }
                scored_docs.append(doc_info)

        # Sort by relevance score
        scored_docs.sort(key=lambda x: x['score'], reverse=True)

        # Select documents within token limit
        selected_docs = []
        total_tokens = 0

        for doc in scored_docs:
            if doc['score'] > 0.1:  # Minimum relevance threshold
                if total_tokens + doc['token_count'] <= max_tokens:
                    selected_docs.append(doc)
                    total_tokens += doc['token_count']
                else:
                    # Try to include a truncated version
                    remaining_tokens = max_tokens - total_tokens
                    if remaining_tokens > 100:  # Only if meaningful content can fit
                        truncated_text = self._truncate_text(doc['text'], remaining_tokens)
                        if truncated_text:
                            doc['text'] = truncated_text
                            doc['token_count'] = len(self.encoding.encode(truncated_text))
                            selected_docs.append(doc)
                            total_tokens += doc['token_count']
                    break

        logger.info(f"Selected {len(selected_docs)} documents with {total_tokens} total tokens")
        return selected_docs

    def _truncate_text(self, text: str, max_tokens: int) -> str:
        """Truncate text to fit within token limit while preserving meaning."""
        tokens = self.encoding.encode(text)

        if len(tokens) <= max_tokens:
            return text

        # Truncate and try to end at a sentence boundary
        truncated_tokens = tokens[:max_tokens]
        truncated_text = self.encoding.decode(truncated_tokens)

        sentences = re.split(r'[.!?]+', truncated_text)
        if len(sentences) > 1:
            # Remove the last incomplete sentence
            truncated_text = '.'.join(sentences[:-1]) + '.'

        return truncated_text

    def generate_answer(self, question: str, context_docs: List[Dict[str, Any]]) -> str:
        """Generate answer using full context stuffing approach."""
        if not context_docs:
            return "I couldn't find any relevant documents to answer your question."

        try:
            # Assemble context from selected documents
            context_parts = []
            sources = []

            for i, doc in enumerate(context_docs, 1):
                text = doc['text']
                metadata = doc['metadata']
                source = metadata.get('source', f'Document {i}')

                context_parts.append(f"=== {source} ===\n{text}")
                if source not in sources:
                    sources.append(source)

            full_context = "\n\n".join(context_parts)

            # Create system message for context stuffing
            system_message = (
                "You are an expert in occupational safety and health regulations. "
                "Answer the user's question using the provided regulatory documents and technical materials. "
                "Provide comprehensive, accurate answers that directly address the question. "
                "Reference specific sections or requirements when applicable. "
                "If the provided context doesn't fully answer the question, clearly state what information is missing."
            )

            # Create user message
            user_message = f"""Based on the following regulatory and technical documents, please answer this question:

QUESTION: {question}

DOCUMENTS:
{full_context}

Please provide a thorough answer based on the information in these documents.
If any important details are missing from the provided context, please indicate that as well."""

            # For GPT-5, temperature must be default (1.0)
            response = self.client.chat.completions.create(
                model=OPENAI_CHAT_MODEL,
                messages=[
                    {"role": "system", "content": system_message},
                    {"role": "user", "content": user_message}
                ],
                max_completion_tokens=DEFAULT_MAX_TOKENS
            )

            answer = response.choices[0].message.content.strip()

            # Add source information
            if len(sources) > 1:
                answer += f"\n\n*Sources consulted: {', '.join(sources)}*"
            elif sources:
                answer += f"\n\n*Source: {sources[0]}*"

            return answer

        except Exception as e:
            logger.error(f"Error generating context stuffing answer: {e}")
            return "I apologize, but I encountered an error while generating the answer using context stuffing."


# Global retriever instance
_retriever = None


def get_retriever() -> ContextStuffingRetriever:
    """Get or create global context stuffing retriever instance."""
    global _retriever
    if _retriever is None:
        _retriever = ContextStuffingRetriever()
    return _retriever


def query(question: str, image_path: Optional[str] = None, top_k: int = DEFAULT_TOP_K) -> Tuple[str, List[Dict]]:
    """
    Main context stuffing query function with unified signature.

    Args:
        question: User question
        image_path: Optional image path (not used in context stuffing but kept for consistency)
        top_k: Not used in context stuffing (uses heuristic selection instead)

    Returns:
        Tuple of (answer, citations)
    """
    try:
        retriever = get_retriever()

        # Select relevant documents using heuristics
        relevant_docs = retriever.select_relevant_documents(question)

        if not relevant_docs:
            return "I couldn't find any relevant documents to answer your question.", []

        # Generate comprehensive answer
        answer = retriever.generate_answer(question, relevant_docs)

        # Prepare citations
        citations = []
        for i, doc in enumerate(relevant_docs, 1):
            metadata = doc['metadata']
            citations.append({
                'rank': i,
                'score': float(doc['score']),
                'source': metadata.get('source', 'Unknown'),
                'type': metadata.get('type', 'unknown'),
                'method': 'context_stuffing',
                'tokens_used': doc['token_count']
            })

        logger.info(f"Context stuffing query completed. Used {len(citations)} documents.")
        return answer, citations

    except Exception as e:
        logger.error(f"Error in context stuffing query: {e}")
        error_message = "I apologize, but I encountered an error while processing your question with context stuffing."
        return error_message, []


def query_with_details(question: str, image_path: Optional[str] = None, top_k: int = DEFAULT_TOP_K) -> Tuple[str, List[Dict], List[Tuple]]:
    """
    Context stuffing query function that returns detailed chunk information (for compatibility).

    Returns:
        Tuple of (answer, citations, chunks)
    """
    answer, citations = query(question, image_path, top_k)

    # Convert citations to chunk format for backward compatibility
    chunks = []
    for citation in citations:
        chunks.append((
            f"Document {citation['rank']} (Score: {citation['score']:.3f})",
            citation['score'],
            f"Context from {citation['source']} ({citation['tokens_used']} tokens)",
            citation['source']
        ))

    return answer, citations, chunks


if __name__ == "__main__":
    # Test the context stuffing system
    test_question = "What are the general requirements for machine guarding?"
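
    # Optional: surface the module's INFO-level log messages (document loading,
    # document selection counts) on the console while running this manual test.
    logging.basicConfig(level=logging.INFO)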
print("Testing context stuffing retrieval system...") print(f"Question: {test_question}") print("-" * 50) try: answer, citations = query(test_question) print("Answer:") print(answer) print(f"\nCitations ({len(citations)} documents used):") for citation in citations: print(f"- {citation['source']} (Relevance: {citation['score']:.3f}, Tokens: {citation['tokens_used']})") except Exception as e: print(f"Error during testing: {e}")