enhanced-rag-demo/src/components/retrievers/graph/document_graph_builder.py
"""
Document graph builder for Epic 2 Week 2.
This module provides graph construction capabilities for technical documents,
using NetworkX to build knowledge graphs that capture relationships between
concepts, protocols, architectures, and extensions in RISC-V documentation.
"""
# Defer annotation evaluation so the nx.DiGraph type hints below do not fail
# at import time when networkx is absent (nx is set to None in that case).
from __future__ import annotations
import logging
import time
import hashlib
import re
from typing import List, Dict, Any, Optional, Set, Tuple, Union
from dataclasses import dataclass, field
from collections import defaultdict
try:
import networkx as nx
import numpy as np
except ImportError:
nx = None
np = None
from src.core.interfaces import Document
from .config.graph_config import GraphBuilderConfig
from .entity_extraction import Entity, EntityExtractor
logger = logging.getLogger(__name__)
@dataclass
class GraphNode:
"""Represents a node in the document graph."""
node_id: str
node_type: str # concept, protocol, architecture, extension
text: str
documents: Set[str] = field(default_factory=set)
frequency: int = 0
confidence: float = 0.0
metadata: Dict[str, Any] = field(default_factory=dict)
def __hash__(self) -> int:
return hash(self.node_id)
def __eq__(self, other) -> bool:
if not isinstance(other, GraphNode):
return False
return self.node_id == other.node_id
@dataclass
class GraphEdge:
"""Represents an edge in the document graph."""
source: str
target: str
edge_type: str # implements, extends, requires, conflicts
weight: float
confidence: float
documents: Set[str] = field(default_factory=set)
metadata: Dict[str, Any] = field(default_factory=dict)
def __hash__(self) -> int:
return hash((self.source, self.target, self.edge_type))
class DocumentGraphBuilderError(Exception):
"""Raised when graph construction operations fail."""
pass
class DocumentGraphBuilder:
"""
Builds knowledge graphs from technical documents.
This class constructs NetworkX graphs that capture semantic relationships
between technical concepts in RISC-V documentation. It processes entities
extracted from documents and builds a graph structure that can be used
for graph-based retrieval and analysis.
Features:
- NetworkX-based graph construction
- Support for multiple node types (concept, protocol, architecture, extension)
- Multiple relationship types (implements, extends, requires, conflicts)
- Incremental graph updates
- Memory optimization with graph pruning
- Performance monitoring and statistics
"""
def __init__(self, config: GraphBuilderConfig, entity_extractor: EntityExtractor):
"""
Initialize the document graph builder.
Args:
config: Graph builder configuration
entity_extractor: Entity extractor for processing documents
"""
if nx is None:
raise DocumentGraphBuilderError("NetworkX is not installed. Install with: pip install networkx")
self.config = config
self.entity_extractor = entity_extractor
# Initialize graph
self.graph = nx.DiGraph() # Directed graph for relationships
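        # Note: nx.DiGraph keeps at most one edge per (source, target) pair, so a later
        # add_edge call with a different edge_type overwrites the stored attributes;
        # self.edges below tracks every (source, target, edge_type) variant separately.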
# Node and edge tracking
self.nodes: Dict[str, GraphNode] = {}
self.edges: Dict[Tuple[str, str, str], GraphEdge] = {}
# Document tracking
self.document_entities: Dict[str, List[Entity]] = {}
self.processed_documents: Set[str] = set()
# Statistics
self.stats = {
"documents_processed": 0,
"total_nodes": 0,
"total_edges": 0,
"construction_time": 0.0,
"last_update_time": 0.0,
"memory_usage_mb": 0.0,
"pruning_operations": 0
}
logger.info(f"DocumentGraphBuilder initialized with {self.config.implementation} backend")
def build_graph(self, documents: List[Document]) -> nx.DiGraph:
"""
Build knowledge graph from documents.
Args:
documents: List of documents to process
Returns:
NetworkX directed graph
"""
if not documents:
logger.warning("No documents provided for graph construction")
return self.graph
start_time = time.time()
try:
logger.info(f"Building graph from {len(documents)} documents")
# Extract entities from all documents
document_entities = self.entity_extractor.extract_entities(documents)
# Build nodes from entities
self._build_nodes(document_entities)
# Build edges from co-occurrence and semantic relationships
self._build_edges(documents, document_entities)
# Prune graph if enabled
if self.config.enable_pruning:
self._prune_graph()
# Update statistics
construction_time = time.time() - start_time
self._update_statistics(documents, construction_time)
logger.info(
f"Graph construction completed in {construction_time:.3f}s "
f"({len(self.nodes)} nodes, {len(self.edges)} edges)"
)
return self.graph
except Exception as e:
logger.error(f"Graph construction failed: {str(e)}")
raise DocumentGraphBuilderError(f"Failed to build graph: {str(e)}") from e
def update_graph(self, new_documents: List[Document]) -> nx.DiGraph:
"""
Incrementally update graph with new documents.
Args:
new_documents: List of new documents to add
Returns:
Updated NetworkX directed graph
"""
if not new_documents:
return self.graph
# Filter out already processed documents
unprocessed_docs = [
doc for doc in new_documents
if doc.metadata.get("id", "unknown") not in self.processed_documents
]
if not unprocessed_docs:
logger.info("All documents already processed")
return self.graph
start_time = time.time()
try:
logger.info(f"Updating graph with {len(unprocessed_docs)} new documents")
# Extract entities from new documents
new_entities = self.entity_extractor.extract_entities(unprocessed_docs)
# Update nodes
self._update_nodes(new_entities)
# Update edges
self._update_edges(unprocessed_docs, new_entities)
# Prune if necessary
if self.config.enable_pruning and len(self.nodes) > self.config.max_graph_size:
self._prune_graph()
# Update statistics
update_time = time.time() - start_time
self.stats["last_update_time"] = update_time
self.stats["documents_processed"] += len(unprocessed_docs)
logger.info(f"Graph updated in {update_time:.3f}s")
return self.graph
except Exception as e:
logger.error(f"Graph update failed: {str(e)}")
raise DocumentGraphBuilderError(f"Failed to update graph: {str(e)}") from e
def _build_nodes(self, document_entities: Dict[str, List[Entity]]) -> None:
"""Build graph nodes from extracted entities."""
entity_groups = defaultdict(list)
# Group entities by normalized text and type
for doc_id, entities in document_entities.items():
for entity in entities:
key = self._normalize_entity_text(entity.text, entity.label)
entity_groups[key].append((entity, doc_id))
# Create nodes from entity groups
for normalized_key, entity_instances in entity_groups.items():
node_id = self._generate_node_id(normalized_key)
# Aggregate information from all instances
entity_types = [e[0].label for e in entity_instances]
most_common_type = max(set(entity_types), key=entity_types.count)
# Get representative text (longest variant)
texts = [e[0].text for e in entity_instances]
representative_text = max(texts, key=len)
# Calculate aggregate confidence
confidences = [e[0].confidence for e in entity_instances]
avg_confidence = sum(confidences) / len(confidences)
# Get all documents containing this entity
documents = set(e[1] for e in entity_instances)
# Create node
node = GraphNode(
node_id=node_id,
node_type=self._map_entity_to_node_type(most_common_type),
text=representative_text,
documents=documents,
frequency=len(entity_instances),
confidence=avg_confidence,
metadata={
"entity_variants": list(set(texts)),
"source_types": list(set(entity_types))
}
)
self.nodes[node_id] = node
# Add node to NetworkX graph
self.graph.add_node(
node_id,
node_type=node.node_type,
text=node.text,
frequency=node.frequency,
confidence=node.confidence,
documents=list(node.documents),
metadata=node.metadata
)
logger.info(f"Created {len(self.nodes)} nodes from entities")
def _build_edges(self, documents: List[Document], document_entities: Dict[str, List[Entity]]) -> None:
"""Build graph edges from entity co-occurrence and relationships."""
# Store document entities for edge building
self.document_entities.update(document_entities)
# Build edges from co-occurrence within documents
for doc_id, entities in document_entities.items():
if len(entities) < 2:
continue
# Create edges between entities in the same document
for i, entity1 in enumerate(entities):
for entity2 in entities[i+1:]:
self._create_co_occurrence_edge(entity1, entity2, doc_id)
# Build semantic relationship edges
self._build_semantic_edges(documents, document_entities)
logger.info(f"Created {len(self.edges)} edges from relationships")
def _create_co_occurrence_edge(self, entity1: Entity, entity2: Entity, doc_id: str) -> None:
"""Create co-occurrence edge between two entities."""
node1_id = self._generate_node_id(self._normalize_entity_text(entity1.text, entity1.label))
node2_id = self._generate_node_id(self._normalize_entity_text(entity2.text, entity2.label))
if node1_id == node2_id:
return # Skip self-loops
# Determine edge type based on entity types and context
edge_type = self._determine_edge_type(entity1, entity2)
# Calculate edge weight and confidence
weight = self._calculate_edge_weight(entity1, entity2)
confidence = min(entity1.confidence, entity2.confidence) * 0.8 # Co-occurrence has lower confidence
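        # Worked example: entity confidences of 0.9 and 0.7 give min(0.9, 0.7) * 0.8 = 0.56.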
# Create or update edge
edge_key = (node1_id, node2_id, edge_type)
if edge_key in self.edges:
# Update existing edge
edge = self.edges[edge_key]
edge.weight = max(edge.weight, weight) # Keep highest weight
edge.confidence = max(edge.confidence, confidence)
edge.documents.add(doc_id)
else:
# Create new edge
edge = GraphEdge(
source=node1_id,
target=node2_id,
edge_type=edge_type,
weight=weight,
confidence=confidence,
documents={doc_id},
metadata={"creation_type": "co_occurrence"}
)
self.edges[edge_key] = edge
# Add edge to NetworkX graph
self.graph.add_edge(
node1_id,
node2_id,
edge_type=edge_type,
            weight=edge.weight,
            confidence=edge.confidence,
documents=list(edge.documents),
metadata=edge.metadata
)
def _build_semantic_edges(self, documents: List[Document], document_entities: Dict[str, List[Entity]]) -> None:
"""Build semantic relationship edges using text analysis."""
# This is a simplified implementation - in practice, you might use
# more sophisticated NLP techniques for relationship extraction
relationship_patterns = {
"implements": [
r"(\w+)\s+implements?\s+(\w+)",
r"(\w+)\s+implementation\s+of\s+(\w+)",
],
"extends": [
r"(\w+)\s+extends?\s+(\w+)",
r"(\w+)\s+extension\s+of\s+(\w+)",
r"(\w+)\s+based\s+on\s+(\w+)",
],
"requires": [
r"(\w+)\s+requires?\s+(\w+)",
r"(\w+)\s+depends?\s+on\s+(\w+)",
r"(\w+)\s+needs?\s+(\w+)",
],
"conflicts": [
r"(\w+)\s+conflicts?\s+with\s+(\w+)",
r"(\w+)\s+incompatible\s+with\s+(\w+)",
]
}
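        # Illustrative match (hypothetical sentence, not taken from the corpus): the text
        # "rocket implements rv64gc" matches the first "implements" pattern and yields the
        # pair ("rocket", "rv64gc"), which _create_semantic_edge then turns into an edge.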
        # Index documents by metadata id for direct lookup
        doc_by_id = {doc.metadata.get("id", "unknown"): doc for doc in documents}
        for doc_id, entities in document_entities.items():
            document = doc_by_id.get(doc_id)
            if document is None:
                continue
content = document.content.lower()
# Find explicit relationships in text
for relationship_type, patterns in relationship_patterns.items():
for pattern in patterns:
matches = re.findall(pattern, content, re.IGNORECASE)
for match in matches:
if len(match) == 2:
source_text, target_text = match
self._create_semantic_edge(
source_text, target_text, relationship_type,
entities, doc_id
)
def _create_semantic_edge(self, source_text: str, target_text: str,
relationship_type: str, entities: List[Entity], doc_id: str) -> None:
"""Create semantic relationship edge between entities."""
# Find matching entities
source_entity = self._find_matching_entity(source_text, entities)
target_entity = self._find_matching_entity(target_text, entities)
if not source_entity or not target_entity:
return
source_id = self._generate_node_id(self._normalize_entity_text(source_entity.text, source_entity.label))
target_id = self._generate_node_id(self._normalize_entity_text(target_entity.text, target_entity.label))
if source_id == target_id:
return
# High confidence for explicit relationships
confidence = 0.9
weight = 1.0
edge_key = (source_id, target_id, relationship_type)
if edge_key in self.edges:
            edge = self.edges[edge_key]
            edge.weight = max(edge.weight, weight)
            edge.confidence = max(edge.confidence, confidence)
            edge.documents.add(doc_id)
else:
edge = GraphEdge(
source=source_id,
target=target_id,
edge_type=relationship_type,
weight=weight,
confidence=confidence,
documents={doc_id},
metadata={"creation_type": "semantic_pattern"}
)
self.edges[edge_key] = edge
# Add edge to NetworkX graph
self.graph.add_edge(
source_id,
target_id,
edge_type=relationship_type,
            weight=edge.weight,
            confidence=edge.confidence,
documents=list(edge.documents),
metadata=edge.metadata
)
def _find_matching_entity(self, text: str, entities: List[Entity]) -> Optional[Entity]:
"""Find entity that matches the given text."""
text_lower = text.lower().strip()
# Exact match first
for entity in entities:
if entity.text.lower().strip() == text_lower:
return entity
# Partial match
for entity in entities:
if text_lower in entity.text.lower() or entity.text.lower() in text_lower:
return entity
return None
def _update_nodes(self, new_entities: Dict[str, List[Entity]]) -> None:
"""Update graph nodes with new entities."""
self._build_nodes(new_entities)
def _update_edges(self, new_documents: List[Document], new_entities: Dict[str, List[Entity]]) -> None:
"""Update graph edges with new documents and entities."""
self._build_edges(new_documents, new_entities)
# Mark documents as processed
for document in new_documents:
self.processed_documents.add(document.metadata.get("id", "unknown"))
def _prune_graph(self) -> None:
"""Prune graph to keep it within size limits."""
if len(self.nodes) <= self.config.max_graph_size:
return
start_time = time.time()
# Calculate node importance scores
node_scores = {}
for node_id, node in self.nodes.items():
# Score based on frequency, confidence, and connectivity
degree = self.graph.degree(node_id) if self.graph.has_node(node_id) else 0
score = (
node.frequency * 0.4 +
node.confidence * 0.3 +
degree * 0.3
)
node_scores[node_id] = score
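            # e.g. frequency=5, confidence=0.8, degree=3 -> 5*0.4 + 0.8*0.3 + 3*0.3 = 3.14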
# Keep top nodes
nodes_to_keep = sorted(node_scores.items(), key=lambda x: x[1], reverse=True)
nodes_to_keep = nodes_to_keep[:self.config.max_graph_size]
keep_ids = set(node_id for node_id, _ in nodes_to_keep)
# Remove low-importance nodes
nodes_to_remove = set(self.nodes.keys()) - keep_ids
for node_id in nodes_to_remove:
# Remove from our tracking
if node_id in self.nodes:
del self.nodes[node_id]
# Remove from NetworkX graph
if self.graph.has_node(node_id):
self.graph.remove_node(node_id)
# Clean up edges
edges_to_remove = []
for edge_key, edge in self.edges.items():
if edge.source not in keep_ids or edge.target not in keep_ids:
edges_to_remove.append(edge_key)
for edge_key in edges_to_remove:
del self.edges[edge_key]
pruning_time = time.time() - start_time
self.stats["pruning_operations"] += 1
logger.info(
f"Pruned graph in {pruning_time:.3f}s "
f"(removed {len(nodes_to_remove)} nodes, {len(edges_to_remove)} edges)"
)
def _normalize_entity_text(self, text: str, entity_type: str) -> str:
"""Normalize entity text for consistent node creation."""
# Basic normalization
normalized = text.lower().strip()
# Remove common prefixes/suffixes for technical terms
if entity_type in ["TECH", "PROTOCOL"]:
normalized = normalized.replace("the ", "").replace(" extension", "")
return normalized
def _generate_node_id(self, normalized_text: str) -> str:
"""Generate unique node ID from normalized text."""
return hashlib.md5(normalized_text.encode()).hexdigest()[:12]
def _map_entity_to_node_type(self, entity_type: str) -> str:
"""Map entity types to node types."""
mapping = {
"TECH": "concept",
"PROTOCOL": "protocol",
"ARCH": "architecture",
"EXTENSION": "extension"
}
return mapping.get(entity_type, "concept")
def _determine_edge_type(self, entity1: Entity, entity2: Entity) -> str:
"""Determine edge type between two entities."""
# Simple heuristics based on entity types
if entity1.label == "EXTENSION" or entity2.label == "EXTENSION":
return "extends"
elif entity1.label == "PROTOCOL" and entity2.label == "ARCH":
return "implements"
elif entity1.label == "ARCH" and entity2.label == "TECH":
return "requires"
else:
return "relates_to" # Default relationship
def _calculate_edge_weight(self, entity1: Entity, entity2: Entity) -> float:
"""Calculate edge weight between two entities."""
# Weight based on entity confidence and proximity
base_weight = (entity1.confidence + entity2.confidence) / 2
# Boost weight if entities are close in text
if hasattr(entity1, 'start_pos') and hasattr(entity2, 'start_pos'):
distance = abs(entity1.start_pos - entity2.start_pos)
if distance < 100: # Close entities
proximity_bonus = 0.2
else:
proximity_bonus = 0.0
else:
proximity_bonus = 0.0
return min(base_weight + proximity_bonus, 1.0)
def _update_statistics(self, documents: List[Document], construction_time: float) -> None:
"""Update graph construction statistics."""
self.stats["documents_processed"] += len(documents)
self.stats["total_nodes"] = len(self.nodes)
self.stats["total_edges"] = len(self.edges)
self.stats["construction_time"] += construction_time
# Estimate memory usage (rough approximation)
node_memory = len(self.nodes) * 200 # Bytes per node
edge_memory = len(self.edges) * 150 # Bytes per edge
self.stats["memory_usage_mb"] = (node_memory + edge_memory) / (1024 * 1024)
# Update processed documents
for document in documents:
self.processed_documents.add(document.metadata.get("id", "unknown"))
def get_graph_statistics(self) -> Dict[str, Any]:
"""
Get comprehensive graph statistics.
Returns:
Dictionary with graph statistics
"""
stats = self.stats.copy()
# Add NetworkX graph metrics
        if self.graph is not None:
stats["networkx_nodes"] = self.graph.number_of_nodes()
stats["networkx_edges"] = self.graph.number_of_edges()
if self.graph.number_of_nodes() > 0:
stats["avg_degree"] = sum(dict(self.graph.degree()).values()) / self.graph.number_of_nodes()
stats["density"] = nx.density(self.graph)
# Calculate connected components
if self.graph.is_directed():
stats["strongly_connected_components"] = nx.number_strongly_connected_components(self.graph)
stats["weakly_connected_components"] = nx.number_weakly_connected_components(self.graph)
else:
stats["connected_components"] = nx.number_connected_components(self.graph)
else:
stats["avg_degree"] = 0.0
stats["density"] = 0.0
stats["connected_components"] = 0
# Node type distribution
node_type_counts = defaultdict(int)
for node in self.nodes.values():
node_type_counts[node.node_type] += 1
stats["node_type_distribution"] = dict(node_type_counts)
# Edge type distribution
edge_type_counts = defaultdict(int)
for edge in self.edges.values():
edge_type_counts[edge.edge_type] += 1
stats["edge_type_distribution"] = dict(edge_type_counts)
return stats
def get_graph(self) -> nx.DiGraph:
"""Get the constructed NetworkX graph."""
return self.graph
def get_subgraph(self, node_ids: List[str], radius: int = 1) -> nx.DiGraph:
"""
Get subgraph around specified nodes.
Args:
node_ids: List of central node IDs
radius: Distance from central nodes to include
Returns:
NetworkX subgraph
"""
if not node_ids or not self.graph:
return nx.DiGraph()
# Find all nodes within radius
subgraph_nodes = set(node_ids)
for _ in range(radius):
new_nodes = set()
for node_id in subgraph_nodes:
if self.graph.has_node(node_id):
# Add neighbors
new_nodes.update(self.graph.neighbors(node_id))
new_nodes.update(self.graph.predecessors(node_id))
subgraph_nodes.update(new_nodes)
# Create subgraph
return self.graph.subgraph(subgraph_nodes).copy()
def reset_graph(self) -> None:
"""Reset the graph and all tracking data."""
self.graph.clear()
self.nodes.clear()
self.edges.clear()
self.document_entities.clear()
self.processed_documents.clear()
# Reset statistics except configuration-dependent ones
self.stats = {
"documents_processed": 0,
"total_nodes": 0,
"total_edges": 0,
"construction_time": 0.0,
"last_update_time": 0.0,
"memory_usage_mb": 0.0,
"pruning_operations": 0
}
logger.info("Graph reset completed")