""" Graph-based RAG using NetworkX. Updated to match the common query signature used by other methods. """ import numpy as np import logging from typing import Tuple, List, Optional from openai import OpenAI import networkx as nx from sklearn.metrics.pairwise import cosine_similarity from config import * from utils import classify_image logger = logging.getLogger(__name__) # Initialize OpenAI client client = OpenAI(api_key=OPENAI_API_KEY) # Global variables for lazy loading _graph = None _enodes = None _embeddings = None def _load_graph(): """Lazy load graph database.""" global _graph, _enodes, _embeddings if _graph is None: try: if GRAPH_FILE.exists(): logger.info("Loading graph database...") _graph = nx.read_gml(str(GRAPH_FILE)) _enodes = list(_graph.nodes) # Convert embeddings from lists back to numpy arrays embeddings_list = [] for n in _enodes: embedding = _graph.nodes[n]['embedding'] if isinstance(embedding, list): embeddings_list.append(np.array(embedding)) else: embeddings_list.append(embedding) _embeddings = np.array(embeddings_list) logger.info(f"✓ Loaded graph with {len(_enodes)} nodes") else: logger.warning("Graph database not found. Run preprocess.py first.") _graph = nx.Graph() _enodes = [] _embeddings = np.array([]) except Exception as e: logger.error(f"Error loading graph: {e}") _graph = nx.Graph() _enodes = [] _embeddings = np.array([]) def query(question: str, image_path: Optional[str] = None, top_k: int = DEFAULT_TOP_K) -> Tuple[str, List[dict]]: """ Query using graph-based retrieval. Args: question: User's question image_path: Optional path to an image (for multimodal queries) top_k: Number of relevant chunks to retrieve Returns: Tuple of (answer, citations) """ # Load graph if not already loaded _load_graph() if len(_enodes) == 0: return "Graph database is empty. Please run preprocess.py first.", [] # Embed question using OpenAI emb_resp = client.embeddings.create( model=OPENAI_EMBEDDING_MODEL, input=question ) q_vec = np.array(emb_resp.data[0].embedding) # Compute cosine similarities sims = cosine_similarity([q_vec], _embeddings)[0] idxs = sims.argsort()[::-1][:top_k] # Collect chunk-level info chunks = [] citations = [] sources_seen = set() for rank, i in enumerate(idxs, start=1): node = _enodes[i] node_data = _graph.nodes[node] text = node_data['text'] # Extract header from text header = text.split('\n', 1)[0].lstrip('#').strip() score = sims[i] # Extract citation format - get source from metadata or node_data metadata = node_data.get('metadata', {}) source = metadata.get('source') or node_data.get('source') if not source: continue if 'url' in metadata: # HTML source citation_ref = metadata['url'] cite_type = 'html' elif 'path' in metadata: # PDF source citation_ref = metadata['path'] cite_type = 'pdf' elif 'url' in node_data: # Legacy format citation_ref = node_data['url'] cite_type = 'html' elif 'path' in node_data: # Legacy format citation_ref = node_data['path'] cite_type = 'pdf' else: citation_ref = source cite_type = 'unknown' chunks.append({ 'header': header, 'score': score, 'text': text, 'citation': citation_ref }) # Add unique citation if source not in sources_seen: citation_entry = { 'source': source, 'type': cite_type, 'relevance_score': round(float(score), 3) } if cite_type == 'html': citation_entry['url'] = citation_ref elif cite_type == 'pdf': citation_entry['path'] = citation_ref citations.append(citation_entry) sources_seen.add(source) # Handle image if provided image_context = "" if image_path: try: # Classify the image classification = classify_image(image_path) image_context = f"\n\n[Image Context: The provided image appears to be a {classification}.]" # Optionally, find related nodes in graph based on image classification # This would require storing image-related metadata in the graph except Exception as e: print(f"Error processing image: {e}") # Assemble context for prompt context = "\n\n---\n\n".join([c['text'] for c in chunks]) prompt = f"""Use the following context to answer the question: {context}{image_context} Question: {question} Please provide a comprehensive answer based on the context provided. Cite specific sources when providing information.""" # For GPT-5, temperature must be default (1.0) chat_resp = client.chat.completions.create( model=OPENAI_CHAT_MODEL, messages=[ {"role": "system", "content": "You are a helpful assistant for manufacturing equipment safety. Always provide accurate information based on the given context."}, {"role": "user", "content": prompt} ], max_completion_tokens=DEFAULT_MAX_TOKENS ) answer = chat_resp.choices[0].message.content return answer, citations def query_with_graph_traversal(question: str, top_k: int = 5, max_hops: int = 2) -> Tuple[str, List[dict]]: """ Enhanced graph query that can traverse edges to find related information. Args: question: User's question top_k: Number of initial nodes to retrieve max_hops: Maximum graph traversal depth Returns: Tuple of (answer, citations) """ # Load graph if not already loaded _load_graph() if len(_enodes) == 0: return "Graph database is empty. Please run preprocess.py first.", [] # Get initial nodes using standard query initial_answer, initial_citations = query(question, top_k=top_k) # For a more sophisticated implementation, you would: # 1. Add edges between related nodes during preprocessing # 2. Traverse from initial nodes to find related content # 3. Score the related nodes based on path distance and relevance # For now, return the standard query results return initial_answer, initial_citations def query_subgraph(question: str, source_filter: str = None, top_k: int = 5) -> Tuple[str, List[dict]]: """ Query a specific subgraph filtered by source. Args: question: User's question source_filter: Filter nodes by source (e.g., specific PDF name) top_k: Number of relevant chunks to retrieve Returns: Tuple of (answer, citations) """ # Load graph if not already loaded _load_graph() # Filter nodes if source specified if source_filter: filtered_nodes = [] for n in _enodes: node_data = _graph.nodes[n] metadata = node_data.get('metadata', {}) source = metadata.get('source') or node_data.get('source', '') source_from_meta = metadata.get('source', '') # Check both direct source and metadata source if (source_filter.lower() in source.lower() or source_filter.lower() in source_from_meta.lower()): filtered_nodes.append(n) if not filtered_nodes: return f"No nodes found for source: {source_filter}", [] else: filtered_nodes = _enodes # Get embeddings for filtered nodes filtered_embeddings = np.array([_graph.nodes[n]['embedding'] for n in filtered_nodes]) # Embed question emb_resp = client.embeddings.create( model=OPENAI_EMBEDDING_MODEL, input=question ) q_vec = np.array(emb_resp.data[0].embedding) # Compute similarities sims = cosine_similarity([q_vec], filtered_embeddings)[0] idxs = sims.argsort()[::-1][:top_k] # Collect results chunks = [] citations = [] sources_seen = set() for i in idxs: if i < len(filtered_nodes): node = filtered_nodes[i] node_data = _graph.nodes[node] chunks.append(node_data['text']) # Skip if source information missing metadata = node_data.get('metadata', {}) source = metadata.get('source') or node_data.get('source') if not source: continue if source not in sources_seen: citation = { 'source': source, 'type': 'pdf' if ('path' in metadata or 'path' in node_data) else 'html', 'relevance_score': round(float(sims[i]), 3) } # Check metadata first, then node_data for legacy support if 'url' in metadata: citation['url'] = metadata['url'] elif 'path' in metadata: citation['path'] = metadata['path'] elif 'url' in node_data: citation['url'] = node_data['url'] elif 'path' in node_data: citation['path'] = node_data['path'] citations.append(citation) sources_seen.add(source) # Build context and generate answer context = "\n\n---\n\n".join(chunks) prompt = f"""Answer the following question using the provided context: Context from {source_filter if source_filter else 'all sources'}: {context} Question: {question} Provide a detailed answer based on the context.""" # For GPT-5, temperature must be default (1.0) response = client.chat.completions.create( model=OPENAI_CHAT_MODEL, messages=[ {"role": "system", "content": "You are an expert on manufacturing safety. Answer based on the provided context."}, {"role": "user", "content": prompt} ], max_completion_tokens=DEFAULT_MAX_TOKENS ) answer = response.choices[0].message.content return answer, citations # Maintain backward compatibility with original function signature def query_graph(question: str, top_k: int = 5) -> Tuple[str, List[str], List[tuple]]: """ Original query_graph function signature for backward compatibility. Args: question: User's question top_k: Number of relevant chunks to retrieve Returns: Tuple of (answer, sources, chunks) """ # Call the new query function answer, citations = query(question, top_k=top_k) # Convert citations to old format sources = [c['source'] for c in citations] # Get chunks in old format (header, score, text, citation) _load_graph() if len(_enodes) == 0: return answer, sources, [] # Regenerate chunks for backward compatibility emb_resp = client.embeddings.create( model=OPENAI_EMBEDDING_MODEL, input=question ) q_vec = np.array(emb_resp.data[0].embedding) sims = cosine_similarity([q_vec], _embeddings)[0] idxs = sims.argsort()[::-1][:top_k] chunks = [] for i in idxs: node = _enodes[i] node_data = _graph.nodes[node] text = node_data['text'] header = text.split('\n', 1)[0].lstrip('#').strip() score = sims[i] # Skip if source information missing metadata = node_data.get('metadata', {}) source = metadata.get('source') or node_data.get('source') if not source: continue if 'url' in metadata: citation = metadata['url'] elif 'path' in metadata: citation = metadata['path'] elif 'url' in node_data: citation = node_data['url'] elif 'path' in node_data: citation = node_data['path'] else: citation = source chunks.append((header, score, text, citation)) return answer, sources, chunks if __name__ == "__main__": # Test the updated graph query test_questions = [ "What are general machine guarding requirements?", "How do I perform lockout/tagout procedures?", "What safety measures are needed for robotic systems?" ] for q in test_questions: print(f"\nQuestion: {q}") answer, citations = query(q) print(f"Answer: {answer[:200]}...") print(f"Citations: {[c['source'] for c in citations]}") print("-" * 50)