# falcon-api/src/simclr.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from src.text_embedding import TextEmbeddingModel

class ClassificationHead(nn.Module):
    """Head for sentence-level classification tasks."""

    def __init__(self, in_dim, out_dim):
        super(ClassificationHead, self).__init__()
        # Two-layer bottleneck MLP: in_dim -> in_dim/4 -> in_dim/16 -> out_dim.
        self.dense1 = nn.Linear(in_dim, in_dim // 4)
        self.dense2 = nn.Linear(in_dim // 4, in_dim // 16)
        self.out_proj = nn.Linear(in_dim // 16, out_dim)
        # Xavier-initialized weights with near-zero biases.
        nn.init.xavier_uniform_(self.dense1.weight)
        nn.init.xavier_uniform_(self.dense2.weight)
        nn.init.xavier_uniform_(self.out_proj.weight)
        nn.init.normal_(self.dense1.bias, std=1e-6)
        nn.init.normal_(self.dense2.bias, std=1e-6)
        nn.init.normal_(self.out_proj.bias, std=1e-6)

    def forward(self, features):
        x = self.dense1(features)
        x = torch.tanh(x)
        x = self.dense2(x)
        x = torch.tanh(x)
        x = self.out_proj(x)
        return x
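
# Hedged shape sketch for the head (768 and 2 below are illustrative
# assumptions, not values fixed by this file):
#
#   head = ClassificationHead(in_dim=768, out_dim=2)
#   logits = head(torch.randn(4, 768))  # -> shape (4, 2)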

class SimCLR_Classifier_SCL(nn.Module):
    def __init__(self, opt, fabric):
        super(SimCLR_Classifier_SCL, self).__init__()
        self.temperature = opt.temperature
        self.opt = opt
        self.fabric = fabric
        self.model = TextEmbeddingModel(opt.model_name)
        self.device = self.model.model.device
        # Optionally resume the encoder from a checkpoint.
        if opt.resum:
            state_dict = torch.load(opt.pth_path, map_location=self.device)
            self.model.load_state_dict(state_dict)
        # Small epsilon guarding against division by an empty positive set.
        self.eps = torch.tensor(1e-6, device=self.device)
        self.classifier = ClassificationHead(opt.projection_size, opt.classifier_dim)
        # Loss weights: `a` scales the contrastive term, `d` the classifier term.
        self.a = torch.tensor(opt.a, device=self.device)
        self.d = torch.tensor(opt.d, device=self.device)
        self.only_classifier = opt.only_classifier

def get_encoder(self):
return self.model

    def _compute_logits(self, q, q_index1, q_index2, q_label, k, k_index1, k_index2, k_label):
        # The index arguments are unused here; the signature mirrors
        # SimCLR_Classifier._compute_logits below.
        def cosine_similarity_matrix(q, k):
            q_norm = F.normalize(q, dim=-1)
            k_norm = F.normalize(k, dim=-1)
            return q_norm @ k_norm.T

        logits = cosine_similarity_matrix(q, k) / self.temperature
        q_labels = q_label.view(-1, 1)       # (N, 1)
        k_labels = k_label.view(1, -1)       # (1, N+K)
        same_label = (q_labels == k_labels)  # (N, N+K)
        # Mean logit over same-label keys (eps-clamped count), placed in front
        # of the masked negative logits so the cross-entropy target is class 0.
        pos_logits_model = torch.sum(logits * same_label, dim=1) / torch.maximum(
            torch.sum(same_label, dim=1), self.eps)
        neg_logits_model = logits * torch.logical_not(same_label)
        logits_model = torch.cat((pos_logits_model.unsqueeze(1), neg_logits_model), dim=1)
        return logits_model
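
    # Sketch of how the layout above meets F.cross_entropy (illustrative
    # numbers, not repo data): every row is [mean positive logit | masked
    # negatives], so a target of class 0 pulls same-label pairs together.
    #
    #   logits_model = torch.tensor([[0.9, 0.1, 0.2]])
    #   F.cross_entropy(logits_model, torch.zeros(1, dtype=torch.long))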

    def forward(self, batch, indices1, indices2, label):
        bsz = batch['input_ids'].size(0)
        q = self.model(batch)
        # Keys are detached copies of the queries gathered from all ranks,
        # so k holds world_size * N embeddings.
        k = q.clone().detach()
        k = self.fabric.all_gather(k).view(-1, k.size(1))
        k_label = self.fabric.all_gather(label).view(-1)
        k_index1 = self.fabric.all_gather(indices1).view(-1)
        k_index2 = self.fabric.all_gather(indices2).view(-1)
        logits_label = self._compute_logits(q, indices1, indices2, label,
                                            k, k_index1, k_index2, k_label)
        out = self.classifier(q)
        # AA mode trains the head on indices1 instead of the binary label.
        if self.opt.AA:
            loss_classify = F.cross_entropy(out, indices1)
        else:
            loss_classify = F.cross_entropy(out, label)
        gt = torch.zeros(bsz, dtype=torch.long, device=logits_label.device)
        if self.only_classifier:
            loss_label = torch.tensor(0.0, device=self.device)
        else:
            loss_label = F.cross_entropy(logits_label, gt)
        loss = self.a * loss_label + self.d * loss_classify
        if self.training:
            return loss, loss_label, loss_classify, k, k_label
        else:
            out = self.fabric.all_gather(out).view(-1, out.size(1))
            return loss, out, k, k_label
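
# Hedged training sketch for this class (assumes a Lightning Fabric launch and
# an `opt` namespace with the fields read in __init__; all names illustrative):
#
#   fabric = lightning.Fabric(devices=2)
#   fabric.launch()
#   model = SimCLR_Classifier_SCL(opt, fabric)
#   loss, loss_label, loss_classify, k, k_label = model(batch, i1, i2, labels)
#   fabric.backward(loss)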

class SimCLR_Classifier_test(nn.Module):
    """Inference-only wrapper: encoder plus classification head, no losses."""

    def __init__(self, opt, fabric):
        super(SimCLR_Classifier_test, self).__init__()
        self.fabric = fabric
        self.model = TextEmbeddingModel(opt.model_name)
        self.classifier = ClassificationHead(opt.projection_size, opt.classifier_dim)
        self.device = self.model.model.device

    def forward(self, batch):
        q = self.model(batch)
        out = self.classifier(q)
        return out
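
# Hedged inference sketch (the checkpoint path and key layout are assumptions;
# adapt to however the training run saved its weights):
#
#   model = SimCLR_Classifier_test(opt, fabric)
#   model.load_state_dict(torch.load("ckpt.pth", map_location="cpu"))
#   probs = model(encoded_batch).softmax(dim=-1)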

class SimCLR_Classifier(nn.Module):
    def __init__(self, opt, fabric):
        super(SimCLR_Classifier, self).__init__()
        self.temperature = opt.temperature
        self.opt = opt
        self.fabric = fabric
        self.model = TextEmbeddingModel(opt.model_name)
        # Resolve the device of the wrapped encoder before it is needed as a
        # map_location below.
        self.device = self.model.model.device
        if opt.resum:
            state_dict = torch.load(opt.pth_path, map_location=self.device)
            self.model.load_state_dict(state_dict)
        self.eps = torch.tensor(1e-6, device=self.device)
        # Loss weights a, b, c combine the contrastive terms and the
        # classification term in forward().
        self.a = torch.tensor(opt.a, device=self.device)
        self.b = torch.tensor(opt.b, device=self.device)
        self.c = torch.tensor(opt.c, device=self.device)
        self.classifier = ClassificationHead(opt.projection_size, opt.classifier_dim)
        self.only_classifier = opt.only_classifier

    def get_encoder(self):
        return self.model

    def _compute_logits(self, q, q_index1, q_index2, q_label,
                        k, k_index1, k_index2, k_label):
        def cosine_similarity_matrix(q, k):
            q_norm = F.normalize(q, dim=-1)
            k_norm = F.normalize(k, dim=-1)
            return q_norm @ k_norm.T

        logits = cosine_similarity_matrix(q, k) / self.temperature
        q_index1 = q_index1.view(-1, 1)      # (N, 1)
        q_index2 = q_index2.view(-1, 1)      # (N, 1)
        q_labels = q_label.view(-1, 1)       # (N, 1)
        k_index1 = k_index1.view(1, -1)      # (1, N+K)
        k_index2 = k_index2.view(1, -1)      # (1, N+K)
        k_labels = k_label.view(1, -1)       # (1, N+K)
        same_mixed = (q_index1 == k_index1)  # (N, N+K)
        same_set = (q_index2 == k_index2)    # (N, N+K)
        same_label = (q_labels == k_labels)  # (N, N+K)
        # Row selectors over queries: label 1 = human, label 0 = machine,
        # index1 == 1 marks mixed human+AI texts.
        is_human = (q_label == 1).view(-1)
        is_machine = (q_label == 0).view(-1)
        is_mixed = (q_index1 == 1).view(-1)
        def supcon_logits(pos_mask):
            # Mean logit over the positives (count clamped by eps to avoid
            # 0/0), concatenated in front of the masked negative logits so
            # that class 0 is always the cross-entropy target.
            pos = torch.sum(logits * pos_mask, dim=1) / torch.maximum(
                torch.sum(pos_mask, dim=1), self.eps)
            neg = logits * torch.logical_not(pos_mask)
            return torch.cat((pos.unsqueeze(1), neg), dim=1)

        # human : human
        logits_human = supcon_logits(same_label)[is_human]
        # human+ai : general
        logits_mixed = supcon_logits(same_mixed)[is_mixed]
        # human+ai : model set
        logits_mixed_set = supcon_logits(torch.logical_and(same_mixed, same_set))[is_mixed]
        # model set : label
        logits_set = supcon_logits(same_set)[is_machine]
        # label : label
        logits_label = supcon_logits(same_label)[is_machine]
        return logits_human, logits_mixed, logits_mixed_set, logits_set, logits_label
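
    # Reading the five heads above (intent inferred from the masks, not
    # stated in this file):
    #   logits_human     - human queries vs. other human texts
    #   logits_mixed     - mixed human+AI queries vs. other mixed texts
    #   logits_mixed_set - mixed queries vs. mixed texts of the same model set
    #   logits_set       - machine queries vs. the same model set
    #   logits_label     - machine queries vs. the same label
    # All share the [positive-mean | negatives] layout, so class 0 is always
    # the cross-entropy target.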

    def forward(self, encoded_batch, label, indices1, indices2):
        q = self.model(encoded_batch)
        # Detached keys gathered from all ranks: k is (world_size * N, dim).
        k = q.clone().detach()
        k = self.fabric.all_gather(k).view(-1, k.size(1))
        k_label = self.fabric.all_gather(label).view(-1)
        k_index1 = self.fabric.all_gather(indices1).view(-1)
        k_index2 = self.fabric.all_gather(indices2).view(-1)
        logits_human, logits_mixed, logits_mixed_set, logits_set, logits_label = \
            self._compute_logits(q, indices1, indices2, label,
                                 k, k_index1, k_index2, k_label)
        out = self.classifier(q)
        # AA mode trains the head on indices1 instead of the binary label.
        if self.opt.AA:
            loss_classify = F.cross_entropy(out, indices1)
        else:
            loss_classify = F.cross_entropy(out, label)
        # Targets are all zeros because each logit row places its positive
        # at column 0.
        gt_mixed = torch.zeros(logits_mixed.size(0), dtype=torch.long, device=logits_mixed.device)
        gt_mixed_set = torch.zeros(logits_mixed_set.size(0), dtype=torch.long, device=logits_mixed_set.device)
        gt_set = torch.zeros(logits_set.size(0), dtype=torch.long, device=logits_set.device)
        gt_label = torch.zeros(logits_label.size(0), dtype=torch.long, device=logits_label.device)
        gt_human = torch.zeros(logits_human.size(0), dtype=torch.long, device=logits_human.device)
        loss_mixed = F.cross_entropy(logits_mixed, gt_mixed)
        loss_mixed_set = F.cross_entropy(logits_mixed_set, gt_mixed_set)
        loss_set = F.cross_entropy(logits_set, gt_set)
        loss_label = F.cross_entropy(logits_label, gt_label)
        # A batch may contain no human queries; guard against an empty
        # cross-entropy input (cast to float64, presumably for stability).
        if logits_human.numel() != 0:
            loss_human = F.cross_entropy(logits_human.to(torch.float64), gt_human)
        else:
            loss_human = torch.tensor(0.0, device=self.device)
        # Weighted sum of the five contrastive terms and the classifier term,
        # with coefficients derived from opt.a, opt.b, opt.c.
        loss = (self.a * loss_set + (4 * self.b - self.a) * loss_label
                + self.b * loss_human + self.b * loss_mixed
                + 2 * self.b * loss_mixed_set + self.c * loss_classify)
        if self.training:
            # Note: the two branches order loss_human and loss_classify
            # differently; callers must unpack accordingly.
            if self.opt.AA:
                return loss, loss_mixed, loss_mixed_set, loss_set, loss_label, loss_human, loss_classify, k, k_index1
            else:
                return loss, loss_mixed, loss_mixed_set, loss_set, loss_label, loss_classify, loss_human, k, k_label
        else:
            # At eval time, gather classifier outputs from every rank.
            out = self.fabric.all_gather(out).view(-1, out.size(1))
            if self.opt.AA:
                return loss, out, k, k_index1
            else:
                return loss, out, k, k_label
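
# In eval mode the classifier outputs are all-gathered, so every rank sees
# predictions for the full global batch. Hedged sketch:
#
#   model.eval()
#   with torch.no_grad():
#       loss, out, k, k_label = model(encoded_batch, label, indices1, indices2)
#   # out has shape (world_size * N, opt.classifier_dim)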