# graph-rec/exp/gnn/model.py
import torch
import dgl
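

# One message-passing layer: a GraphSAGE convolution with ReLU, optionally a
# second convolution over the reversed graph ("bidirectional") whose output is
# added in, and an optional residual ("skip") connection.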
class GNNLayer(torch.nn.Module):
    def __init__(self, hidden_dim, aggregator_type, skip_connection, bidirectional):
        super().__init__()
        self._skip_connection = skip_connection
        self._bidirectional = bidirectional
        self._conv = dgl.nn.SAGEConv(hidden_dim, hidden_dim, aggregator_type)
        self._activation = torch.nn.ReLU()
        if bidirectional:
            self._conv_rev = dgl.nn.SAGEConv(hidden_dim, hidden_dim, aggregator_type)
            self._activation_rev = torch.nn.ReLU()

    def forward(self, graph, x):
        edge_weights = graph.edata["weights"]
        y = self._activation(self._conv(graph, x, edge_weights))
        if self._bidirectional:
            reversed_graph = dgl.reverse(graph, copy_edata=True)
            edge_weights = reversed_graph.edata["weights"]
            y = y + self._activation_rev(self._conv_rev(reversed_graph, x, edge_weights))
        if self._skip_connection:
            return x + y
        else:
            return y
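

# PinSAGE-style item encoder: samples an item-item neighborhood graph with
# dgl.sampling.PinSAGESampler over the Item-User bipartite graph, projects item
# text embeddings to hidden_dim, runs them through the GNN layers, and returns
# L2-normalized item vectors.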
class GNNModel(torch.nn.Module):
    def __init__(
        self,
        bipartite_graph,
        text_embeddings,
        num_layers,
        hidden_dim,
        aggregator_type,
        skip_connection,
        bidirectional,
        num_traversals,
        termination_prob,
        num_random_walks,
        num_neighbor,
    ):
        super().__init__()
        self._bipartite_graph = bipartite_graph
        self._text_embeddings = text_embeddings
        self._sampler = dgl.sampling.PinSAGESampler(
            bipartite_graph, "Item", "User", num_traversals,
            termination_prob, num_random_walks, num_neighbor)
        self._text_encoder = torch.nn.Linear(text_embeddings.shape[-1], hidden_dim)
        self._layers = torch.nn.ModuleList()
        for _ in range(num_layers):
            self._layers.append(GNNLayer(
                hidden_dim, aggregator_type, skip_connection, bidirectional))

    def _sample_subgraph(self, frontier_ids):
        # Build an item-item subgraph by running the PinSAGE sampler once per
        # GNN layer, expanding the frontier with newly reached items each time.
        num_layers = len(self._layers)
        device = self._bipartite_graph.device
        subgraph = dgl.graph(([], []), num_nodes=self._bipartite_graph.num_nodes("Item")).to(device)
        prev_ids = set()
        weights = []
        for _ in range(num_layers):
            frontier_ids = torch.tensor(frontier_ids, dtype=torch.int64).to(device)
            new_sample = self._sampler(frontier_ids)
            new_weights = new_sample.edata["weights"]
            new_edges = new_sample.edges()
            subgraph.add_edges(*new_edges)
            weights.append(new_weights)
            # Next frontier: nodes touched by the sampler that were not seen yet.
            prev_ids |= set(frontier_ids.cpu().tolist())
            frontier_ids = set(dgl.compact_graphs(subgraph).ndata[dgl.NID].cpu().tolist())
            frontier_ids = list(frontier_ids - prev_ids)
        subgraph.edata["weights"] = torch.cat(weights, dim=0).to(torch.float32)
        return subgraph

    def forward(self, ids):
        ### Sample subgraph
        sampled_subgraph = self._sample_subgraph(ids)
        sampled_subgraph = dgl.compact_graphs(sampled_subgraph, always_preserve=ids)
        ### Encode text embeddings
        text_embeddings = self._text_embeddings[
            sampled_subgraph.ndata[dgl.NID]]
        features = self._text_encoder(text_embeddings)
        ### GNN goes brr...
        for layer in self._layers:
            features = layer(sampled_subgraph, features)
        ### Select features for initial ids
        # TODO: write it more efficiently?
        matches = sampled_subgraph.ndata[dgl.NID].unsqueeze(0) == ids.unsqueeze(1)
        ids_in_subgraph = matches.nonzero(as_tuple=True)[1]
        features = features[ids_in_subgraph]
        ### Normalize and return
        features = features / torch.linalg.norm(features, dim=1, keepdim=True)
        return features
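

# A minimal smoke-test sketch (not part of the original experiments): it builds
# a toy Item-User bipartite graph with random edges and random "text"
# embeddings, then encodes a few items. All sizes, edge-type names, and
# hyperparameters below are illustrative assumptions.
if __name__ == "__main__":
    num_items, num_users, num_edges = 100, 50, 500
    src = torch.randint(0, num_items, (num_edges,))
    dst = torch.randint(0, num_users, (num_edges,))
    bipartite_graph = dgl.heterograph(
        {
            ("Item", "rated-by", "User"): (src, dst),
            ("User", "rates", "Item"): (dst, src),
        },
        num_nodes_dict={"Item": num_items, "User": num_users},
    )
    text_embeddings = torch.randn(num_items, 32)
    model = GNNModel(
        bipartite_graph,
        text_embeddings,
        num_layers=2,
        hidden_dim=32,
        aggregator_type="mean",
        skip_connection=True,
        bidirectional=True,
        num_traversals=3,
        termination_prob=0.5,
        num_random_walks=10,
        num_neighbor=5,
    )
    item_vectors = model(torch.tensor([0, 1, 2]))
    print(item_vectors.shape)  # expected: torch.Size([3, 32])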