import torch
import torch.nn as nn
import torch.nn.functional as F

from src.text_embedding import TextEmbeddingModel

class ClassificationHead(nn.Module):
    """Head for sentence-level classification tasks."""

    def __init__(self, in_dim, out_dim):
        super().__init__()
        # Two tanh-activated bottleneck layers (in_dim -> in_dim/4 -> in_dim/16)
        # followed by the output projection.
        self.dense1 = nn.Linear(in_dim, in_dim // 4)
        self.dense2 = nn.Linear(in_dim // 4, in_dim // 16)
        self.out_proj = nn.Linear(in_dim // 16, out_dim)
        nn.init.xavier_uniform_(self.dense1.weight)
        nn.init.xavier_uniform_(self.dense2.weight)
        nn.init.xavier_uniform_(self.out_proj.weight)
        nn.init.normal_(self.dense1.bias, std=1e-6)
        nn.init.normal_(self.dense2.bias, std=1e-6)
        nn.init.normal_(self.out_proj.bias, std=1e-6)

    def forward(self, features):
        x = self.dense1(features)
        x = torch.tanh(x)
        x = self.dense2(x)
        x = torch.tanh(x)
        x = self.out_proj(x)
        return x
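
# A minimal shape check for ClassificationHead — an illustrative sketch only.
# The embedding width (1024) and the 2-way output are assumed here, not read
# from any config in this repo.
def _demo_classification_head():
    head = ClassificationHead(in_dim=1024, out_dim=2)
    dummy = torch.randn(8, 1024)  # a batch of 8 pooled sentence embeddings
    logits = head(dummy)          # -> (8, 2)
    assert logits.shape == (8, 2)
    return logits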

class SimCLR_Classifier_SCL(nn.Module):
    """Label-level contrastive training plus a classification head."""

    def __init__(self, opt, fabric):
        super().__init__()
        self.temperature = opt.temperature
        self.opt = opt
        self.fabric = fabric
        self.model = TextEmbeddingModel(opt.model_name)
        self.device = self.model.model.device
        if opt.resum:  # resume the encoder from a checkpoint
            state_dict = torch.load(opt.pth_path, map_location=self.device)
            self.model.load_state_dict(state_dict)
        self.eps = torch.tensor(1e-6, device=self.device)  # guards divide-by-zero
        self.classifier = ClassificationHead(opt.projection_size, opt.classifier_dim)
        self.a = torch.tensor(opt.a, device=self.device)  # contrastive loss weight
        self.d = torch.tensor(opt.d, device=self.device)  # classification loss weight
        self.only_classifier = opt.only_classifier

    def get_encoder(self):
        return self.model
    def _compute_logits(self, q, q_index1, q_index2, q_label,
                        k, k_index1, k_index2, k_label):
        def cosine_similarity_matrix(q, k):
            q_norm = F.normalize(q, dim=-1)
            k_norm = F.normalize(k, dim=-1)
            return q_norm @ k_norm.T

        logits = cosine_similarity_matrix(q, k) / self.temperature
        q_labels = q_label.view(-1, 1)      # (N, 1)
        k_labels = k_label.view(1, -1)      # (1, N+K)
        same_label = q_labels == k_labels   # (N, N+K) positive mask
        # Column 0: mean similarity over same-label keys (the positives);
        # remaining columns: similarities to different-label keys (negatives).
        pos_logits_model = torch.sum(logits * same_label, dim=1) / torch.maximum(
            torch.sum(same_label, dim=1), self.eps)
        neg_logits_model = logits * torch.logical_not(same_label)
        logits_model = torch.cat((pos_logits_model.unsqueeze(1), neg_logits_model), dim=1)
        return logits_model
    def forward(self, batch, indices1, indices2, label):
        bsz = batch['input_ids'].size(0)
        q = self.model(batch)
        # Build the key bank from detached embeddings gathered across devices.
        k = q.clone().detach()
        k = self.fabric.all_gather(k).view(-1, k.size(1))
        k_label = self.fabric.all_gather(label).view(-1)
        k_index1 = self.fabric.all_gather(indices1).view(-1)
        k_index2 = self.fabric.all_gather(indices2).view(-1)
        # q: (N, d); k: (world_size * N, d)
        logits_label = self._compute_logits(q, indices1, indices2, label,
                                            k, k_index1, k_index2, k_label)
        out = self.classifier(q)
        if self.opt.AA:  # authorship attribution: classify the model index
            loss_classify = F.cross_entropy(out, indices1)
        else:
            loss_classify = F.cross_entropy(out, label)
        # Target 0: the averaged positive sits in column 0 of the logits.
        gt = torch.zeros(bsz, dtype=torch.long, device=logits_label.device)
        if self.only_classifier:
            loss_label = torch.tensor(0.0, device=self.device)
        else:
            loss_label = F.cross_entropy(logits_label, gt)
        loss = self.a * loss_label + self.d * loss_classify
        if self.training:
            return loss, loss_label, loss_classify, k, k_label
        else:
            out = self.fabric.all_gather(out).view(-1, out.size(1))
            return loss, out, k, k_label
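
# Hypothetical single-device training step for SimCLR_Classifier_SCL — a
# sketch only: `optimizer`, the tokenized `batch`, and the index/label tensors
# are assumed to come from the surrounding training script.
def _demo_scl_step(model, optimizer, batch, indices1, indices2, label):
    model.train()
    loss, loss_label, loss_classify, k, k_label = model(batch, indices1, indices2, label)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.detach()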

class SimCLR_Classifier_test(nn.Module):
    """Inference-only wrapper: encoder followed by the classification head."""

    def __init__(self, opt, fabric):
        super().__init__()
        self.fabric = fabric
        self.model = TextEmbeddingModel(opt.model_name)
        self.classifier = ClassificationHead(opt.projection_size, opt.classifier_dim)
        self.device = self.model.model.device

    def forward(self, batch):
        q = self.model(batch)
        return self.classifier(q)
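
# Sketch of intended inference usage for SimCLR_Classifier_test (assumptions:
# `opt` carries model_name / projection_size / classifier_dim, `encoded_batch`
# is a dict of tokenizer tensors accepted by TextEmbeddingModel, and the
# fabric argument is only needed to satisfy the constructor).
def _demo_inference(opt, fabric, encoded_batch):
    model = SimCLR_Classifier_test(opt, fabric)
    model.eval()
    with torch.no_grad():
        logits = model(encoded_batch)  # (batch, opt.classifier_dim)
    return logits.argmax(dim=-1)       # predicted class ids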

class SimCLR_Classifier(nn.Module):
    """Full training wrapper with the five-term contrastive objective."""

    def __init__(self, opt, fabric):
        super().__init__()
        self.temperature = opt.temperature
        self.opt = opt
        self.fabric = fabric
        self.model = TextEmbeddingModel(opt.model_name)
        # Resolve the device before the optional resume load (the original
        # read self.model.device, which the wrapped model does not expose).
        self.device = self.model.model.device
        if opt.resum:  # resume the encoder from a checkpoint
            state_dict = torch.load(opt.pth_path, map_location=self.device)
            self.model.load_state_dict(state_dict)
        self.eps = torch.tensor(1e-6, device=self.device)  # guards divide-by-zero
        self.a = torch.tensor(opt.a, device=self.device)
        self.b = torch.tensor(opt.b, device=self.device)
        self.c = torch.tensor(opt.c, device=self.device)
        self.classifier = ClassificationHead(opt.projection_size, opt.classifier_dim)
        self.only_classifier = opt.only_classifier

    def get_encoder(self):
        return self.model
    def _compute_logits(self,
                        q, q_index1, q_index2, q_label,
                        k, k_index1, k_index2, k_label):
        def cosine_similarity_matrix(q, k):
            q_norm = F.normalize(q, dim=-1)
            k_norm = F.normalize(k, dim=-1)
            return q_norm @ k_norm.T

        logits = cosine_similarity_matrix(q, k) / self.temperature
        q_index1 = q_index1.view(-1, 1)  # (N, 1)
        q_index2 = q_index2.view(-1, 1)  # (N, 1)
        q_labels = q_label.view(-1, 1)   # (N, 1)
        k_index1 = k_index1.view(1, -1)  # (1, N+K)
        k_index2 = k_index2.view(1, -1)  # (1, N+K)
        k_labels = k_label.view(1, -1)   # (1, N+K)
        same_mixed = q_index1 == k_index1  # (N, N+K)
        same_set = q_index2 == k_index2    # (N, N+K)
        same_label = q_labels == k_labels  # (N, N+K)
        is_human = (q_label == 1).view(-1)
        is_machine = (q_label == 0).view(-1)
        is_mixed = (q_index1 == 1).view(-1)

        # Each block below packs the mean positive similarity into column 0
        # and the masked negatives into the remaining columns, then keeps
        # only the query rows the term applies to.

        # human : human
        pos_logits_human = torch.sum(logits * same_label, dim=1) / torch.maximum(
            torch.sum(same_label, dim=1), self.eps)
        neg_logits_human = logits * torch.logical_not(same_label)
        logits_human = torch.cat((pos_logits_human.unsqueeze(1), neg_logits_human), dim=1)
        logits_human = logits_human[is_human]

        # human+ai : general
        pos_logits_mixed = torch.sum(logits * same_mixed, dim=1) / torch.maximum(
            torch.sum(same_mixed, dim=1), self.eps)
        neg_logits_mixed = logits * torch.logical_not(same_mixed)
        logits_mixed = torch.cat((pos_logits_mixed.unsqueeze(1), neg_logits_mixed), dim=1)
        logits_mixed = logits_mixed[is_mixed]

        # human+ai : model
        mixed_set = torch.logical_and(same_mixed, same_set)
        pos_logits_mixed_set = torch.sum(logits * mixed_set, dim=1) / torch.maximum(
            torch.sum(mixed_set, dim=1), self.eps)
        neg_logits_mixed_set = logits * torch.logical_not(mixed_set)
        logits_mixed_set = torch.cat((pos_logits_mixed_set.unsqueeze(1), neg_logits_mixed_set), dim=1)
        logits_mixed_set = logits_mixed_set[is_mixed]

        # model set : label
        pos_logits_set = torch.sum(logits * same_set, dim=1) / torch.maximum(
            torch.sum(same_set, dim=1), self.eps)
        neg_logits_set = logits * torch.logical_not(same_set)
        logits_set = torch.cat((pos_logits_set.unsqueeze(1), neg_logits_set), dim=1)
        logits_set = logits_set[is_machine]

        # label : label
        pos_logits_label = torch.sum(logits * same_label, dim=1) / torch.maximum(
            torch.sum(same_label, dim=1), self.eps)
        neg_logits_label = logits * torch.logical_not(same_label)
        logits_label = torch.cat((pos_logits_label.unsqueeze(1), neg_logits_label), dim=1)
        logits_label = logits_label[is_machine]

        return logits_human, logits_mixed, logits_mixed_set, logits_set, logits_label
    def forward(self, encoded_batch, label, indices1, indices2):
        q = self.model(encoded_batch)
        # Build the key bank from detached embeddings gathered across devices.
        k = q.clone().detach()
        k = self.fabric.all_gather(k).view(-1, k.size(1))
        k_label = self.fabric.all_gather(label).view(-1)
        k_index1 = self.fabric.all_gather(indices1).view(-1)
        k_index2 = self.fabric.all_gather(indices2).view(-1)
        # q: (N, d); k: (world_size * N, d)
        logits_human, logits_mixed, logits_mixed_set, logits_set, logits_label = \
            self._compute_logits(q, indices1, indices2, label,
                                 k, k_index1, k_index2, k_label)
        out = self.classifier(q)
        if self.opt.AA:  # authorship attribution: classify the model index
            loss_classify = F.cross_entropy(out, indices1)
        else:
            loss_classify = F.cross_entropy(out, label)
        # Target 0 everywhere: the averaged positive sits in column 0.
        gt_mixed = torch.zeros(logits_mixed.size(0), dtype=torch.long, device=logits_mixed.device)
        gt_mixed_set = torch.zeros(logits_mixed_set.size(0), dtype=torch.long, device=logits_mixed_set.device)
        gt_set = torch.zeros(logits_set.size(0), dtype=torch.long, device=logits_set.device)
        gt_label = torch.zeros(logits_label.size(0), dtype=torch.long, device=logits_label.device)
        gt_human = torch.zeros(logits_human.size(0), dtype=torch.long, device=logits_human.device)
        loss_mixed = F.cross_entropy(logits_mixed, gt_mixed)
        loss_mixed_set = F.cross_entropy(logits_mixed_set, gt_mixed_set)
        loss_set = F.cross_entropy(logits_set, gt_set)
        loss_label = F.cross_entropy(logits_label, gt_label)
        if logits_human.numel() != 0:
            loss_human = F.cross_entropy(logits_human.to(torch.float64), gt_human)
        else:
            # No human queries in this batch: skip the term.
            loss_human = torch.tensor(0.0, device=self.device)
        loss = (self.a * loss_set + (4 * self.b - self.a) * loss_label
                + self.b * loss_human + self.b * loss_mixed
                + 2 * self.b * loss_mixed_set + self.c * loss_classify)
        if self.training:
            if self.opt.AA:
                return loss, loss_mixed, loss_mixed_set, loss_set, loss_label, loss_human, loss_classify, k, k_index1
            else:
                return loss, loss_mixed, loss_mixed_set, loss_set, loss_label, loss_classify, loss_human, k, k_label
        else:
            out = self.fabric.all_gather(out).view(-1, out.size(1))
            if self.opt.AA:
                return loss, out, k, k_index1
            else:
                return loss, out, k, k_label
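
# Worked toy example of the contrastive logit layout shared by both
# _compute_logits implementations (a standalone sketch with assumed shapes and
# temperature, not part of the training pipeline): column 0 holds each query's
# mean similarity over its positives, the remaining columns hold the masked
# negatives, so cross-entropy against an all-zero target pulls positives
# together and pushes negatives apart.
def _demo_logit_layout(temperature=0.07):
    torch.manual_seed(0)
    q = F.normalize(torch.randn(4, 16), dim=-1)  # 4 queries
    k = F.normalize(torch.randn(8, 16), dim=-1)  # 8 gathered keys
    q_labels = torch.tensor([0, 0, 1, 1]).view(-1, 1)
    k_labels = torch.tensor([0, 1, 0, 1, 0, 1, 0, 1]).view(1, -1)
    same = q_labels == k_labels                  # (4, 8) positive mask
    sim = (q @ k.T) / temperature
    eps = torch.tensor(1e-6)
    pos = torch.sum(sim * same, dim=1) / torch.maximum(torch.sum(same, dim=1), eps)
    neg = sim * torch.logical_not(same)
    logits = torch.cat((pos.unsqueeze(1), neg), dim=1)  # (4, 9)
    target = torch.zeros(4, dtype=torch.long)           # positives at column 0
    return F.cross_entropy(logits, target)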