import torch
### Based on https://arxiv.org/pdf/2205.03169
def nt_xent_loss(sim, temperature):
    """Compute a one-directional NT-Xent style loss from a similarity matrix.

    With ``n = sim.shape[0] // 2``, only the first ``n`` rows of ``sim`` act
    as anchors. The result is the sum of two terms:

    * alignment: the negative mean of the temperature-scaled entries
      ``sim[i, i + n]`` for ``i`` in ``0..n-1``;
    * distribution: the mean over those ``n`` rows of the log-sum-exp of the
      temperature-scaled row, with the diagonal entry ``sim[i, i]`` excluded.

    Args:
        sim: 2-D square tensor of shape ``(2n, 2n)`` holding pairwise
            similarities.
            NOTE(review): presumably rows/columns ``i`` and ``i + n`` are the
            two views of the same sample -- confirm against the caller.
        temperature: Value every similarity is divided by before the loss
            terms are computed.

    Returns:
        A zero-dimensional tensor: ``alignment_loss + distribution_loss``.

    Raises:
        ValueError: If ``sim`` is not a square 2-D tensor with an even side.
    """
    if sim.dim() != 2 or sim.shape[0] != sim.shape[1] or sim.shape[0] % 2:
        raise ValueError(
            "sim must be a square 2-D tensor with an even number of rows, "
            f"got shape {tuple(sim.shape)}"
        )

    total = sim.shape[0]
    n = total // 2  # n = |user_batch|

    # Scale first (keeps any broadcastable temperature working), then keep
    # only the anchor rows; the remaining rows never contribute to the loss.
    logits = (sim / temperature)[:n]

    # Positive pair of anchor i sits in column i + n.
    anchors = torch.arange(n, device=sim.device)
    alignment_loss = -logits[anchors, anchors + n].mean()

    # Exclude each anchor's similarity with itself. The mask is created
    # directly on the input's device and only for the rows that are used.
    self_mask = torch.eye(n, total, dtype=torch.bool, device=sim.device)
    logits = logits.masked_fill(self_mask, float("-inf"))

    distribution_loss = torch.logsumexp(logits, dim=1).mean()
    return alignment_loss + distribution_loss