graph-rec / exp /gnn /loss.py
erermeev-d
Refactored experiments code
c746c39
raw
history blame contribute delete
509 Bytes
import torch
### Based on https://arxiv.org/pdf/2205.03169
def nt_xent_loss(sim, temperature):
    """Compute a one-directional NT-Xent contrastive loss from a similarity matrix.

    Args:
        sim: Square similarity matrix of shape ``(2n, 2n)``. Rows/columns
            ``0..n-1`` and ``n..2n-1`` index two aligned groups of
            embeddings (``n`` corresponds to the user batch); entry
            ``(i, i + n)`` is the positive pair for anchor row ``i``.
        temperature: Scalar that all similarities are divided by.

    Returns:
        0-dim tensor ``alignment + distribution`` where, with ``S = sim / T``
        and ``i`` ranging over the first ``n`` rows only::

            alignment    = -mean_i S[i, i + n]
            distribution =  mean_i logsumexp_{j != i} S[i, j]

        i.e. the mean cross-entropy of picking the positive column among all
        columns except the anchor itself. Rows ``n..2n-1`` are never used as
        anchors (the loss is not symmetrised).

    Raises:
        ValueError: if ``sim`` is not a non-empty square 2-D matrix with an
            even number of rows.
    """
    if (
        sim.dim() != 2
        or sim.shape[0] != sim.shape[1]
        or sim.shape[0] == 0
        or sim.shape[0] % 2 != 0
    ):
        raise ValueError(
            "sim must be a non-empty square matrix with an even number of "
            f"rows, got shape {tuple(sim.shape)}"
        )

    sim = sim / temperature
    n = sim.shape[0] // 2  # n = |user_batch|

    # Positive logits sim[i, i + n] for i < n form the n-th super-diagonal;
    # torch.diagonal reads them without building (CPU-side) index tensors.
    alignment_loss = -torch.diagonal(sim, offset=n).mean()

    # Only the first n rows are anchors. Exclude each anchor's
    # self-similarity sim[i, i] from its logsumexp denominator. The mask is
    # created directly on sim's device to avoid a host->device copy per call.
    anchors = sim[:n]
    self_mask = torch.eye(n, 2 * n, dtype=torch.bool, device=sim.device)
    anchors = anchors.masked_fill(self_mask, float("-inf"))
    distribution_loss = torch.logsumexp(anchors, dim=1).mean()

    return alignment_loss + distribution_loss