import argparse
import os
import tempfile

import numpy as np
import torch
import wandb
from tqdm.auto import tqdm

from exp.prepare_recsys import prepare_recsys
from exp.evaluate import evaluate_recsys
from exp.gnn.model import GNNModel
from exp.gnn.loss import nt_xent_loss
from exp.gnn.utils import (
    prepare_graphs, fix_random,
    sample_item_batch, inference_model)


def prepare_gnn_embeddings(config):
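    """Train contrastive GNN item embeddings and save them along with the model checkpoint."""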
    fix_random(config["seed"])

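    # Build the user-item bipartite graph from the raw items and train ratings.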
    bipartite_graph, _ = prepare_graphs(config["items_path"], config["train_ratings_path"])
    bipartite_graph = bipartite_graph.to(config["device"])

    if config["use_wandb"]:
        wandb.init(project="graph-rec-gnn", name=config["wandb_name"], config=config)

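    # Item nodes are initialized with precomputed text embeddings.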
    text_embeddings = torch.tensor(np.load(config["text_embeddings_path"])).to(config["device"])
    model = GNNModel(
        bipartite_graph=bipartite_graph,
        text_embeddings=text_embeddings,
        num_layers=config["num_layers"],
        hidden_dim=config["hidden_dim"],
        aggregator_type=config["aggregator_type"],
        skip_connection=config["skip_connection"],
        bidirectional=config["bidirectional"],
        num_traversals=config["num_traversals"],
        termination_prob=config["termination_prob"],
        num_random_walks=config["num_random_walks"],
        num_neighbor=config["num_neighbor"]
    )
    model = model.to(config["device"])

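    # Keep only users with at least two rated items, so each sampled user
    # can contribute a positive item pair to the contrastive loss.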
    all_users = torch.arange(bipartite_graph.num_nodes("User")).to(config["device"])
    all_users = all_users[bipartite_graph.in_degrees(all_users, etype="ItemUser") > 1]
    dataloader = torch.utils.data.DataLoader(
        all_users, batch_size=config["batch_size"], shuffle=True, drop_last=True)

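    # Constant learning rate: the LambdaLR factor is fixed at 1.0.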
    optimizer = torch.optim.Adam(model.parameters(), lr=config["lr"])
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda _: 1.0)

    for epoch in range(config["num_epochs"]):
        model.train()
        for user_batch in tqdm(dataloader):
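            # sample_item_batch draws items per user; NT-Xent treats items
            # co-rated by the same user as positives against the rest of the batch.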
            item_batch = sample_item_batch(user_batch, bipartite_graph)
            item_batch = item_batch.reshape(-1)
            features = model(item_batch)
            sim = features @ features.T
            loss = nt_xent_loss(sim, config["temperature"])
            if config["use_wandb"]:
                wandb.log({"loss": loss.item()})
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            lr_scheduler.step()

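        # Periodically build a temporary FAISS index from the current embeddings
        # and compute retrieval metrics on the validation ratings.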
        if (config["validate_every_n_epoch"] is not None) and ((epoch + 1) % config["validate_every_n_epoch"] == 0):
            item_embeddings = inference_model(
                model, bipartite_graph, config["batch_size"], config["hidden_dim"], config["device"])
            with tempfile.TemporaryDirectory() as tmp_dir_name:
                tmp_embeddings_path = os.path.join(tmp_dir_name, "embeddings.npy")
                np.save(tmp_embeddings_path, item_embeddings)
                prepare_recsys(config["items_path"], tmp_embeddings_path, tmp_dir_name)
                metrics = evaluate_recsys(
                    config["val_ratings_path"],
                    os.path.join(tmp_dir_name, "index.faiss"),
                    os.path.join(tmp_dir_name, "items.db"))
            print(f"Epoch {epoch + 1} / {config['num_epochs']}. {metrics}")
            if config["use_wandb"]:
                wandb.log(metrics)

    if config["use_wandb"]:
        wandb.finish()

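    # Export the final item embeddings and the trained model weights.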
    item_embeddings = inference_model(
        model, bipartite_graph, config["batch_size"], config["hidden_dim"], config["device"])
    np.save(config["embeddings_savepath"], item_embeddings)

    torch.save(model.to("cpu").state_dict(), config["model_savepath"])


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Prepare GNN Embeddings")

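    # Required input and output paths.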
    parser.add_argument("--items_path", type=str, required=True, help="Path to the items file")
    parser.add_argument("--train_ratings_path", type=str, required=True, help="Path to the train ratings file")
    parser.add_argument("--val_ratings_path", type=str, required=True, help="Path to the validation ratings file")
    parser.add_argument("--text_embeddings_path", type=str, required=True, help="Path to the text embeddings file")
    parser.add_argument("--embeddings_savepath", type=str, required=True, help="Path to the file where GNN embeddings will be saved")
    parser.add_argument("--model_savepath", type=str, required=True, help="Path to save the final model checkpoint")

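    # Optimization hyperparameters.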
    parser.add_argument("--temperature", type=float, default=0.1, help="Temperature for NT-Xent loss")
    parser.add_argument("--batch_size", type=int, default=512, help="Batch size for training")
    parser.add_argument("--lr", type=float, default=0.01, help="Learning rate")
    parser.add_argument("--num_epochs", type=int, default=4, help="Number of epochs")

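    # Model architecture and PinSAGE-like sampler parameters.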
    parser.add_argument("--num_layers", type=int, default=2, help="Number of layers in the model")
    parser.add_argument("--hidden_dim", type=int, default=64, help="Hidden dimension size")
    parser.add_argument("--aggregator_type", type=str, default="mean", help="Type of aggregator in SAGEConv")
    parser.add_argument("--no_skip_connection", action="store_false", dest="skip_connection", help="Disable skip connections")
    parser.add_argument("--no_bidirectional", action="store_false", dest="bidirectional", help="Do not use reversed edges in convolution")
    parser.add_argument("--num_traversals", type=int, default=4, help="Number of traversals in PinSAGE-like sampler")
    parser.add_argument("--termination_prob", type=float, default=0.5, help="Termination probability in PinSAGE-like sampler")
    parser.add_argument("--num_random_walks", type=int, default=200, help="Number of random walks in PinSAGE-like sampler")
    parser.add_argument("--num_neighbor", type=int, default=10, help="Number of neighbors in PinSAGE-like sampler")

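    # Reproducibility, validation cadence, device, and logging.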
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("--validate_every_n_epoch", type=int, default=4, help="Perform RecSys validation every n train epochs")
    parser.add_argument("--device", type=str, default="cpu", help="Device to run the model on (cpu or cuda)")
    parser.add_argument("--wandb_name", type=str, help="WandB run name")
    parser.add_argument("--no_wandb", action="store_false", dest="use_wandb", help="Disable WandB logging")

    args = parser.parse_args()

    prepare_gnn_embeddings(vars(args))