File size: 509 Bytes
c746c39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import torch


### Based on https://arxiv.org/pdf/2205.03169
def nt_xent_loss(sim, temperature):
    """Compute a one-sided NT-Xent-style contrastive loss from similarities.

    With ``n = sim.shape[0] // 2``, the loss is the sum of two terms, each
    averaged over the first ``n`` rows of the temperature-scaled matrix:

    * alignment: the negated similarity between row ``i`` and column
      ``i + n`` (the entry treated as the positive pair of row ``i``);
    * distribution: the log-sum-exp of row ``i`` over every column except
      column ``i`` itself (the self-similarity is excluded).

    Their sum equals the mean over those rows of
    ``-log(exp(s[i, i+n]) / sum_{k != i} exp(s[i, k]))``.

    Args:
        sim: Square ``(2n, 2n)`` tensor of pairwise similarities.
        temperature: Value the similarities are divided by before the loss
            terms are computed.

    Returns:
        Zero-dimensional tensor holding the loss, on the same device as
        ``sim``.

    Raises:
        ValueError: If ``sim`` is not a square 2-D tensor with an even
            number of rows.
    """
    if sim.dim() != 2 or sim.shape[0] != sim.shape[1] or sim.shape[0] % 2:
        raise ValueError(
            "sim must be a square (2n, 2n) matrix, got shape "
            f"{tuple(sim.shape)}"
        )

    n = sim.shape[0] // 2  # n = |user_batch|

    # Only the first n rows enter either term, so slice before doing any
    # further work instead of masking the full 2n x 2n matrix.
    rows = sim[:n] / temperature

    # Positives rows[i, i + n] are exactly the n-th super-diagonal.
    alignment_loss = -torch.mean(torch.diagonal(rows, offset=n))

    # Exclude each row's self-similarity rows[i, i] from the log-sum-exp.
    # The mask is created directly on sim's device (no host -> device copy).
    self_mask = torch.eye(n, 2 * n, dtype=torch.bool, device=sim.device)
    rows = rows.masked_fill(self_mask, -torch.inf)
    distribution_loss = torch.mean(torch.logsumexp(rows, dim=1))

    return alignment_loss + distribution_loss