""" | |
Document graph builder for Epic 2 Week 2. | |
This module provides graph construction capabilities for technical documents, | |
using NetworkX to build knowledge graphs that capture relationships between | |
concepts, protocols, architectures, and extensions in RISC-V documentation. | |
""" | |
import logging | |
import time | |
import hashlib | |
from typing import List, Dict, Any, Optional, Set, Tuple, Union | |
from dataclasses import dataclass, field | |
from collections import defaultdict | |
try: | |
import networkx as nx | |
import numpy as np | |
except ImportError: | |
nx = None | |
np = None | |
from src.core.interfaces import Document | |
from .config.graph_config import GraphBuilderConfig | |
from .entity_extraction import Entity, EntityExtractor | |
logger = logging.getLogger(__name__) | |


@dataclass
class GraphNode:
    """Represents a node in the document graph."""

    node_id: str
    node_type: str  # concept, protocol, architecture, extension
    text: str
    documents: Set[str] = field(default_factory=set)
    frequency: int = 0
    confidence: float = 0.0
    metadata: Dict[str, Any] = field(default_factory=dict)

    def __hash__(self) -> int:
        return hash(self.node_id)

    def __eq__(self, other) -> bool:
        if not isinstance(other, GraphNode):
            return False
        return self.node_id == other.node_id


@dataclass
class GraphEdge:
    """Represents an edge in the document graph."""

    source: str
    target: str
    edge_type: str  # implements, extends, requires, conflicts
    weight: float
    confidence: float
    documents: Set[str] = field(default_factory=set)
    metadata: Dict[str, Any] = field(default_factory=dict)

    def __hash__(self) -> int:
        return hash((self.source, self.target, self.edge_type))


class DocumentGraphBuilderError(Exception):
    """Raised when graph construction operations fail."""

    pass


class DocumentGraphBuilder:
    """
    Builds knowledge graphs from technical documents.

    This class constructs NetworkX graphs that capture semantic relationships
    between technical concepts in RISC-V documentation. It processes entities
    extracted from documents and builds a graph structure that can be used
    for graph-based retrieval and analysis.

    Features:
    - NetworkX-based graph construction
    - Support for multiple node types (concept, protocol, architecture, extension)
    - Multiple relationship types (implements, extends, requires, conflicts)
    - Incremental graph updates
    - Memory optimization with graph pruning
    - Performance monitoring and statistics
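
    Example (illustrative sketch; assumes a configured GraphBuilderConfig, an
    EntityExtractor, and a list of Document objects named "documents"):

        builder = DocumentGraphBuilder(config, entity_extractor)
        graph = builder.build_graph(documents)
        stats = builder.get_graph_statistics()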
""" | |
def __init__(self, config: GraphBuilderConfig, entity_extractor: EntityExtractor): | |
""" | |
Initialize the document graph builder. | |
Args: | |
config: Graph builder configuration | |
entity_extractor: Entity extractor for processing documents | |
""" | |
if nx is None: | |
raise DocumentGraphBuilderError("NetworkX is not installed. Install with: pip install networkx") | |
self.config = config | |
self.entity_extractor = entity_extractor | |
# Initialize graph | |
self.graph = nx.DiGraph() # Directed graph for relationships | |
# Node and edge tracking | |
self.nodes: Dict[str, GraphNode] = {} | |
self.edges: Dict[Tuple[str, str, str], GraphEdge] = {} | |
# Document tracking | |
self.document_entities: Dict[str, List[Entity]] = {} | |
self.processed_documents: Set[str] = set() | |
# Statistics | |
self.stats = { | |
"documents_processed": 0, | |
"total_nodes": 0, | |
"total_edges": 0, | |
"construction_time": 0.0, | |
"last_update_time": 0.0, | |
"memory_usage_mb": 0.0, | |
"pruning_operations": 0 | |
} | |
logger.info(f"DocumentGraphBuilder initialized with {self.config.implementation} backend") | |

    def build_graph(self, documents: List[Document]) -> nx.DiGraph:
        """
        Build knowledge graph from documents.

        Args:
            documents: List of documents to process

        Returns:
            NetworkX directed graph
        """
        if not documents:
            logger.warning("No documents provided for graph construction")
            return self.graph

        start_time = time.time()

        try:
            logger.info(f"Building graph from {len(documents)} documents")

            # Extract entities from all documents
            document_entities = self.entity_extractor.extract_entities(documents)

            # Build nodes from entities
            self._build_nodes(document_entities)

            # Build edges from co-occurrence and semantic relationships
            self._build_edges(documents, document_entities)

            # Prune graph if enabled
            if self.config.enable_pruning:
                self._prune_graph()

            # Update statistics
            construction_time = time.time() - start_time
            self._update_statistics(documents, construction_time)

            logger.info(
                f"Graph construction completed in {construction_time:.3f}s "
                f"({len(self.nodes)} nodes, {len(self.edges)} edges)"
            )

            return self.graph

        except Exception as e:
            logger.error(f"Graph construction failed: {str(e)}")
            raise DocumentGraphBuilderError(f"Failed to build graph: {str(e)}") from e

    def update_graph(self, new_documents: List[Document]) -> nx.DiGraph:
        """
        Incrementally update graph with new documents.

        Args:
            new_documents: List of new documents to add

        Returns:
            Updated NetworkX directed graph
        """
        if not new_documents:
            return self.graph

        # Filter out already processed documents
        unprocessed_docs = [
            doc for doc in new_documents
            if doc.metadata.get("id", "unknown") not in self.processed_documents
        ]

        if not unprocessed_docs:
            logger.info("All documents already processed")
            return self.graph

        start_time = time.time()

        try:
            logger.info(f"Updating graph with {len(unprocessed_docs)} new documents")

            # Extract entities from new documents
            new_entities = self.entity_extractor.extract_entities(unprocessed_docs)

            # Update nodes
            self._update_nodes(new_entities)

            # Update edges
            self._update_edges(unprocessed_docs, new_entities)

            # Prune if necessary
            if self.config.enable_pruning and len(self.nodes) > self.config.max_graph_size:
                self._prune_graph()

            # Update statistics
            update_time = time.time() - start_time
            self.stats["last_update_time"] = update_time
            self.stats["documents_processed"] += len(unprocessed_docs)

            logger.info(f"Graph updated in {update_time:.3f}s")

            return self.graph

        except Exception as e:
            logger.error(f"Graph update failed: {str(e)}")
            raise DocumentGraphBuilderError(f"Failed to update graph: {str(e)}") from e

    def _build_nodes(self, document_entities: Dict[str, List[Entity]]) -> None:
        """Build graph nodes from extracted entities."""
        entity_groups = defaultdict(list)

        # Group entities by normalized text and type
        for doc_id, entities in document_entities.items():
            for entity in entities:
                key = self._normalize_entity_text(entity.text, entity.label)
                entity_groups[key].append((entity, doc_id))

        # Create nodes from entity groups
        for normalized_key, entity_instances in entity_groups.items():
            node_id = self._generate_node_id(normalized_key)

            # Aggregate information from all instances
            entity_types = [e[0].label for e in entity_instances]
            most_common_type = max(set(entity_types), key=entity_types.count)
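            # For example, labels ["TECH", "TECH", "PROTOCOL"] resolve to "TECH",
            # the most frequent label among the grouped entity instances.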

            # Get representative text (longest variant)
            texts = [e[0].text for e in entity_instances]
            representative_text = max(texts, key=len)

            # Calculate aggregate confidence
            confidences = [e[0].confidence for e in entity_instances]
            avg_confidence = sum(confidences) / len(confidences)

            # Get all documents containing this entity
            documents = set(e[1] for e in entity_instances)

            # Create node
            node = GraphNode(
                node_id=node_id,
                node_type=self._map_entity_to_node_type(most_common_type),
                text=representative_text,
                documents=documents,
                frequency=len(entity_instances),
                confidence=avg_confidence,
                metadata={
                    "entity_variants": list(set(texts)),
                    "source_types": list(set(entity_types))
                }
            )

            self.nodes[node_id] = node

            # Add node to NetworkX graph
            self.graph.add_node(
                node_id,
                node_type=node.node_type,
                text=node.text,
                frequency=node.frequency,
                confidence=node.confidence,
                documents=list(node.documents),
                metadata=node.metadata
            )

        logger.info(f"Created {len(self.nodes)} nodes from entities")

    def _build_edges(self, documents: List[Document], document_entities: Dict[str, List[Entity]]) -> None:
        """Build graph edges from entity co-occurrence and relationships."""
        # Store document entities for edge building
        self.document_entities.update(document_entities)

        # Build edges from co-occurrence within documents
        for doc_id, entities in document_entities.items():
            if len(entities) < 2:
                continue

            # Create edges between entities in the same document
            for i, entity1 in enumerate(entities):
                for entity2 in entities[i + 1:]:
                    self._create_co_occurrence_edge(entity1, entity2, doc_id)

        # Build semantic relationship edges
        self._build_semantic_edges(documents, document_entities)

        logger.info(f"Created {len(self.edges)} edges from relationships")

    def _create_co_occurrence_edge(self, entity1: Entity, entity2: Entity, doc_id: str) -> None:
        """Create co-occurrence edge between two entities."""
        node1_id = self._generate_node_id(self._normalize_entity_text(entity1.text, entity1.label))
        node2_id = self._generate_node_id(self._normalize_entity_text(entity2.text, entity2.label))

        if node1_id == node2_id:
            return  # Skip self-loops

        # Determine edge type based on entity types and context
        edge_type = self._determine_edge_type(entity1, entity2)

        # Calculate edge weight and confidence
        weight = self._calculate_edge_weight(entity1, entity2)
        confidence = min(entity1.confidence, entity2.confidence) * 0.8  # Co-occurrence has lower confidence
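        # For example, entity confidences 0.9 and 0.7 give min(0.9, 0.7) * 0.8 = 0.56.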

        # Create or update edge
        edge_key = (node1_id, node2_id, edge_type)

        if edge_key in self.edges:
            # Update existing edge
            edge = self.edges[edge_key]
            edge.weight = max(edge.weight, weight)  # Keep highest weight
            edge.confidence = max(edge.confidence, confidence)
            edge.documents.add(doc_id)
        else:
            # Create new edge
            edge = GraphEdge(
                source=node1_id,
                target=node2_id,
                edge_type=edge_type,
                weight=weight,
                confidence=confidence,
                documents={doc_id},
                metadata={"creation_type": "co_occurrence"}
            )
            self.edges[edge_key] = edge

        # Add edge to NetworkX graph (use the tracked edge values so repeated
        # co-occurrences keep the highest weight/confidence seen so far)
        self.graph.add_edge(
            node1_id,
            node2_id,
            edge_type=edge_type,
            weight=edge.weight,
            confidence=edge.confidence,
            documents=list(edge.documents),
            metadata=edge.metadata
        )

    def _build_semantic_edges(self, documents: List[Document], document_entities: Dict[str, List[Entity]]) -> None:
        """Build semantic relationship edges using text analysis."""
        # This is a simplified implementation - in practice, you might use
        # more sophisticated NLP techniques for relationship extraction
        relationship_patterns = {
            "implements": [
                r"(\w+)\s+implements?\s+(\w+)",
                r"(\w+)\s+implementation\s+of\s+(\w+)",
            ],
            "extends": [
                r"(\w+)\s+extends?\s+(\w+)",
                r"(\w+)\s+extension\s+of\s+(\w+)",
                r"(\w+)\s+based\s+on\s+(\w+)",
            ],
            "requires": [
                r"(\w+)\s+requires?\s+(\w+)",
                r"(\w+)\s+depends?\s+on\s+(\w+)",
                r"(\w+)\s+needs?\s+(\w+)",
            ],
            "conflicts": [
                r"(\w+)\s+conflicts?\s+with\s+(\w+)",
                r"(\w+)\s+incompatible\s+with\s+(\w+)",
            ]
        }
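        # For example, the lowercased sentence "zicsr extends rv32i" matches the
        # first "extends" pattern and yields the pair ("zicsr", "rv32i").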

        import re

        for doc_id, entities in document_entities.items():
            # Find document by metadata id
            document = None
            for doc in documents:
                if doc.metadata.get("id", "unknown") == doc_id:
                    document = doc
                    break

            if document is None:
                continue

            content = document.content.lower()

            # Find explicit relationships in text
            for relationship_type, patterns in relationship_patterns.items():
                for pattern in patterns:
                    matches = re.findall(pattern, content, re.IGNORECASE)
                    for match in matches:
                        if len(match) == 2:
                            source_text, target_text = match
                            self._create_semantic_edge(
                                source_text, target_text, relationship_type,
                                entities, doc_id
                            )

    def _create_semantic_edge(self, source_text: str, target_text: str,
                              relationship_type: str, entities: List[Entity], doc_id: str) -> None:
        """Create semantic relationship edge between entities."""
        # Find matching entities
        source_entity = self._find_matching_entity(source_text, entities)
        target_entity = self._find_matching_entity(target_text, entities)

        if not source_entity or not target_entity:
            return

        source_id = self._generate_node_id(self._normalize_entity_text(source_entity.text, source_entity.label))
        target_id = self._generate_node_id(self._normalize_entity_text(target_entity.text, target_entity.label))

        if source_id == target_id:
            return

        # High confidence for explicit relationships
        confidence = 0.9
        weight = 1.0

        edge_key = (source_id, target_id, relationship_type)

        if edge_key in self.edges:
            edge = self.edges[edge_key]
            edge.confidence = max(edge.confidence, confidence)
            edge.documents.add(doc_id)
        else:
            edge = GraphEdge(
                source=source_id,
                target=target_id,
                edge_type=relationship_type,
                weight=weight,
                confidence=confidence,
                documents={doc_id},
                metadata={"creation_type": "semantic_pattern"}
            )
            self.edges[edge_key] = edge

        # Add edge to NetworkX graph (mirror the tracked edge values)
        self.graph.add_edge(
            source_id,
            target_id,
            edge_type=relationship_type,
            weight=edge.weight,
            confidence=edge.confidence,
            documents=list(edge.documents),
            metadata=edge.metadata
        )

    def _find_matching_entity(self, text: str, entities: List[Entity]) -> Optional[Entity]:
        """Find entity that matches the given text."""
        text_lower = text.lower().strip()

        # Exact match first
        for entity in entities:
            if entity.text.lower().strip() == text_lower:
                return entity

        # Partial match
        for entity in entities:
            if text_lower in entity.text.lower() or entity.text.lower() in text_lower:
                return entity

        return None

    def _update_nodes(self, new_entities: Dict[str, List[Entity]]) -> None:
        """Update graph nodes with new entities."""
        self._build_nodes(new_entities)

    def _update_edges(self, new_documents: List[Document], new_entities: Dict[str, List[Entity]]) -> None:
        """Update graph edges with new documents and entities."""
        self._build_edges(new_documents, new_entities)

        # Mark documents as processed
        for document in new_documents:
            self.processed_documents.add(document.metadata.get("id", "unknown"))

    def _prune_graph(self) -> None:
        """Prune graph to keep it within size limits."""
        if len(self.nodes) <= self.config.max_graph_size:
            return

        start_time = time.time()

        # Calculate node importance scores
        node_scores = {}
        for node_id, node in self.nodes.items():
            # Score based on frequency, confidence, and connectivity
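            # (e.g. frequency 5, confidence 0.8, degree 3 scores
            #  5 * 0.4 + 0.8 * 0.3 + 3 * 0.3 = 3.14)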
            degree = self.graph.degree(node_id) if self.graph.has_node(node_id) else 0
            score = (
                node.frequency * 0.4 +
                node.confidence * 0.3 +
                degree * 0.3
            )
            node_scores[node_id] = score

        # Keep top nodes
        nodes_to_keep = sorted(node_scores.items(), key=lambda x: x[1], reverse=True)
        nodes_to_keep = nodes_to_keep[:self.config.max_graph_size]
        keep_ids = set(node_id for node_id, _ in nodes_to_keep)

        # Remove low-importance nodes
        nodes_to_remove = set(self.nodes.keys()) - keep_ids
        for node_id in nodes_to_remove:
            # Remove from our tracking
            if node_id in self.nodes:
                del self.nodes[node_id]

            # Remove from NetworkX graph
            if self.graph.has_node(node_id):
                self.graph.remove_node(node_id)

        # Clean up edges
        edges_to_remove = []
        for edge_key, edge in self.edges.items():
            if edge.source not in keep_ids or edge.target not in keep_ids:
                edges_to_remove.append(edge_key)

        for edge_key in edges_to_remove:
            del self.edges[edge_key]

        pruning_time = time.time() - start_time
        self.stats["pruning_operations"] += 1

        logger.info(
            f"Pruned graph in {pruning_time:.3f}s "
            f"(removed {len(nodes_to_remove)} nodes, {len(edges_to_remove)} edges)"
        )

    def _normalize_entity_text(self, text: str, entity_type: str) -> str:
        """Normalize entity text for consistent node creation."""
        # Basic normalization
        normalized = text.lower().strip()

        # Remove common prefixes/suffixes for technical terms
        if entity_type in ["TECH", "PROTOCOL"]:
            normalized = normalized.replace("the ", "").replace(" extension", "")
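            # For example, "The AXI Extension" with type "PROTOCOL" normalizes to "axi".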

        return normalized

    def _generate_node_id(self, normalized_text: str) -> str:
        """Generate unique node ID from normalized text."""
        return hashlib.md5(normalized_text.encode()).hexdigest()[:12]

    def _map_entity_to_node_type(self, entity_type: str) -> str:
        """Map entity types to node types."""
        mapping = {
            "TECH": "concept",
            "PROTOCOL": "protocol",
            "ARCH": "architecture",
            "EXTENSION": "extension"
        }
        return mapping.get(entity_type, "concept")

    def _determine_edge_type(self, entity1: Entity, entity2: Entity) -> str:
        """Determine edge type between two entities."""
        # Simple heuristics based on entity types
        if entity1.label == "EXTENSION" or entity2.label == "EXTENSION":
            return "extends"
        elif entity1.label == "PROTOCOL" and entity2.label == "ARCH":
            return "implements"
        elif entity1.label == "ARCH" and entity2.label == "TECH":
            return "requires"
        else:
            return "relates_to"  # Default relationship

    def _calculate_edge_weight(self, entity1: Entity, entity2: Entity) -> float:
        """Calculate edge weight between two entities."""
        # Weight based on entity confidence and proximity
        base_weight = (entity1.confidence + entity2.confidence) / 2

        # Boost weight if entities are close in text
        if hasattr(entity1, 'start_pos') and hasattr(entity2, 'start_pos'):
            distance = abs(entity1.start_pos - entity2.start_pos)
            if distance < 100:  # Close entities
                proximity_bonus = 0.2
            else:
                proximity_bonus = 0.0
        else:
            proximity_bonus = 0.0
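        # For example, confidences 0.9 and 0.7 within 100 characters of each other:
        # (0.9 + 0.7) / 2 + 0.2 = 1.0, which is already at the cap below.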

        return min(base_weight + proximity_bonus, 1.0)

    def _update_statistics(self, documents: List[Document], construction_time: float) -> None:
        """Update graph construction statistics."""
        self.stats["documents_processed"] += len(documents)
        self.stats["total_nodes"] = len(self.nodes)
        self.stats["total_edges"] = len(self.edges)
        self.stats["construction_time"] += construction_time

        # Estimate memory usage (rough approximation)
        node_memory = len(self.nodes) * 200  # Bytes per node
        edge_memory = len(self.edges) * 150  # Bytes per edge
        self.stats["memory_usage_mb"] = (node_memory + edge_memory) / (1024 * 1024)

        # Update processed documents
        for document in documents:
            self.processed_documents.add(document.metadata.get("id", "unknown"))

    def get_graph_statistics(self) -> Dict[str, Any]:
        """
        Get comprehensive graph statistics.

        Returns:
            Dictionary with graph statistics
        """
        stats = self.stats.copy()

        # Add NetworkX graph metrics
        if self.graph is not None:
            stats["networkx_nodes"] = self.graph.number_of_nodes()
            stats["networkx_edges"] = self.graph.number_of_edges()

            if self.graph.number_of_nodes() > 0:
                stats["avg_degree"] = sum(dict(self.graph.degree()).values()) / self.graph.number_of_nodes()
                stats["density"] = nx.density(self.graph)

                # Calculate connected components
                if self.graph.is_directed():
                    stats["strongly_connected_components"] = nx.number_strongly_connected_components(self.graph)
                    stats["weakly_connected_components"] = nx.number_weakly_connected_components(self.graph)
                else:
                    stats["connected_components"] = nx.number_connected_components(self.graph)
            else:
                stats["avg_degree"] = 0.0
                stats["density"] = 0.0
                stats["connected_components"] = 0

        # Node type distribution
        node_type_counts = defaultdict(int)
        for node in self.nodes.values():
            node_type_counts[node.node_type] += 1
        stats["node_type_distribution"] = dict(node_type_counts)

        # Edge type distribution
        edge_type_counts = defaultdict(int)
        for edge in self.edges.values():
            edge_type_counts[edge.edge_type] += 1
        stats["edge_type_distribution"] = dict(edge_type_counts)

        return stats

    def get_graph(self) -> nx.DiGraph:
        """Get the constructed NetworkX graph."""
        return self.graph

    def get_subgraph(self, node_ids: List[str], radius: int = 1) -> nx.DiGraph:
        """
        Get subgraph around specified nodes.

        Args:
            node_ids: List of central node IDs
            radius: Distance from central nodes to include

        Returns:
            NetworkX subgraph
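
        Example (illustrative; the node id below is a made-up 12-character hash):
            subgraph = builder.get_subgraph(["a1b2c3d4e5f6"], radius=2)
            # The result is a copy of the induced subgraph containing every node
            # within two hops (successors or predecessors) of the given node.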
""" | |
if not node_ids or not self.graph: | |
return nx.DiGraph() | |
# Find all nodes within radius | |
subgraph_nodes = set(node_ids) | |
for _ in range(radius): | |
new_nodes = set() | |
for node_id in subgraph_nodes: | |
if self.graph.has_node(node_id): | |
# Add neighbors | |
new_nodes.update(self.graph.neighbors(node_id)) | |
new_nodes.update(self.graph.predecessors(node_id)) | |
subgraph_nodes.update(new_nodes) | |
# Create subgraph | |
return self.graph.subgraph(subgraph_nodes).copy() | |
def reset_graph(self) -> None: | |
"""Reset the graph and all tracking data.""" | |
self.graph.clear() | |
self.nodes.clear() | |
self.edges.clear() | |
self.document_entities.clear() | |
self.processed_documents.clear() | |
# Reset statistics except configuration-dependent ones | |
self.stats = { | |
"documents_processed": 0, | |
"total_nodes": 0, | |
"total_edges": 0, | |
"construction_time": 0.0, | |
"last_update_time": 0.0, | |
"memory_usage_mb": 0.0, | |
"pruning_operations": 0 | |
} | |
logger.info("Graph reset completed") |