deepfake / models /image.py
Dharshaneshwaran
Full updated code with finding ai generated images too
ddcedb5
raw
history blame
7.87 kB
import re
import os
import wget
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from models.rawnet import SincConv, Residual_block
from models.classifiers import DeepFakeClassifier
class ImageEncoder(nn.Module):
    """Deepfake image classifier wrapping ``DeepFakeClassifier`` (EfficientNet-B7).

    The wrapped model can optionally be initialised from a local checkpoint and/or
    frozen. ``forward`` applies a sigmoid to the wrapped model's output.

    Args:
        args: configuration namespace. Reads ``device``,
            ``pretrained_image_encoder`` and ``freeze_image_encoder``.
    """

    # Checkpoint location. Built with os.path.join so it resolves on every OS;
    # a hard-coded backslash path only works on Windows.
    _PRETRAINED_CKPT = os.path.join(
        "pretrained", "final_999_DeepFakeClassifier_tf_efficientnet_b7_ns_0_23"
    )

    def __init__(self, args):
        super().__init__()
        self.device = args.device
        self.args = args
        self.flatten = nn.Flatten()  # not used by forward; kept for attribute compatibility
        self.sigmoid = nn.Sigmoid()
        self.pretrained_image_encoder = args.pretrained_image_encoder
        self.freeze_image_encoder = args.freeze_image_encoder

        self.model = DeepFakeClassifier(encoder="tf_efficientnet_b7_ns").to(self.device)
        if self.pretrained_image_encoder:
            self._load_pretrained()
        if self.freeze_image_encoder:
            # Exclude the wrapped classifier's weights from gradient updates.
            for param in self.model.parameters():
                param.requires_grad = False

    def _load_pretrained(self):
        """Load ``_PRETRAINED_CKPT`` into ``self.model`` with strict key matching.

        The checkpoint is kept in local variables only. Assigning it to
        ``self.state_dict`` would shadow ``nn.Module.state_dict()``, and keeping
        it on ``self`` would hold the whole file in memory.
        """
        # NOTE(review): torch.load unpickles arbitrary objects; only point this at
        # trusted checkpoint files.
        ckpt = torch.load(self._PRETRAINED_CKPT, map_location=torch.device(self.args.device))
        # Weights are either nested under "state_dict" or the checkpoint itself.
        weights = ckpt.get("state_dict", ckpt)
        print("Loading pretrained image encoder...")
        self.model.load_state_dict(self._strip_module_prefix(weights), strict=True)
        print("Loaded pretrained image encoder.")

    @staticmethod
    def _strip_module_prefix(state_dict):
        """Return a copy of ``state_dict`` with a leading literal ``module.`` removed from each key.

        Presumably the prefix comes from saving a (Distributed)DataParallel-wrapped
        model. The dot is escaped so keys like ``moduleX...`` are left alone.
        """
        return {re.sub(r"^module\.", "", k): v for k, v in state_dict.items()}

    def forward(self, x):
        """Run the wrapped classifier on ``x`` and squash its output with a sigmoid."""
        return self.sigmoid(self.model(x))
class RawNet(nn.Module):
    """Raw-waveform audio classifier (SincConv front end + gated residual blocks + GRU).

    Pipeline in ``forward``:
      1. SincConv, then abs, max-pool and BN + SELU.
      2. Six residual blocks, each followed by a sigmoid channel gate.
      3. BN + SELU, then a GRU whose last time step is kept.
      4. Two linear layers and a log-softmax.

    Args:
        args: configuration namespace. Reads ``device``, ``in_channels``,
            ``gru_node``, ``nb_gru_layer``, ``nb_fc_node``, ``nb_classes``,
            ``pretrained_audio_encoder`` and ``freeze_audio_encoder``.
    """

    # Checkpoint location. Built with os.path.join so it resolves on every OS;
    # a hard-coded backslash path only works on Windows.
    _PRETRAINED_CKPT = os.path.join("pretrained", "RawNet.pth")

    def __init__(self, args):
        super().__init__()
        self.device = args.device
        # [sinc channels, [in, out] for block0/1, [in, out] for block2, [128, 128]].
        # The last entry is not read in this class.
        self.filts = [20, [20, 20], [20, 128], [128, 128]]

        self.Sinc_conv = SincConv(device=self.device,
                                  out_channels=self.filts[0],
                                  kernel_size=1024,
                                  in_channels=args.in_channels)
        self.first_bn = nn.BatchNorm1d(num_features=self.filts[0])
        self.selu = nn.SELU(inplace=True)

        self.block0 = nn.Sequential(Residual_block(nb_filts=self.filts[1], first=True))
        self.block1 = nn.Sequential(Residual_block(nb_filts=self.filts[1]))
        self.block2 = nn.Sequential(Residual_block(nb_filts=self.filts[2]))
        # block2 widens 20 -> 128, so blocks 3-5 take 128 input channels.
        # NOTE(review): this mutates the same list object that was passed to block2.
        self.filts[2][0] = self.filts[2][1]
        self.block3 = nn.Sequential(Residual_block(nb_filts=self.filts[2]))
        self.block4 = nn.Sequential(Residual_block(nb_filts=self.filts[2]))
        self.block5 = nn.Sequential(Residual_block(nb_filts=self.filts[2]))

        self.avgpool = nn.AdaptiveAvgPool1d(1)
        # One channel-gate projection per residual block, sized to that block's output.
        self.fc_attention0 = self._make_attention_fc(in_features=self.filts[1][-1],
                                                     l_out_features=self.filts[1][-1])
        self.fc_attention1 = self._make_attention_fc(in_features=self.filts[1][-1],
                                                     l_out_features=self.filts[1][-1])
        self.fc_attention2 = self._make_attention_fc(in_features=self.filts[2][-1],
                                                     l_out_features=self.filts[2][-1])
        self.fc_attention3 = self._make_attention_fc(in_features=self.filts[2][-1],
                                                     l_out_features=self.filts[2][-1])
        self.fc_attention4 = self._make_attention_fc(in_features=self.filts[2][-1],
                                                     l_out_features=self.filts[2][-1])
        self.fc_attention5 = self._make_attention_fc(in_features=self.filts[2][-1],
                                                     l_out_features=self.filts[2][-1])

        self.bn_before_gru = nn.BatchNorm1d(num_features=self.filts[2][-1])
        self.gru = nn.GRU(input_size=self.filts[2][-1],
                          hidden_size=args.gru_node,
                          num_layers=args.nb_gru_layer,
                          batch_first=True)
        self.fc1_gru = nn.Linear(in_features=args.gru_node,
                                 out_features=args.nb_fc_node)
        self.fc2_gru = nn.Linear(in_features=args.nb_fc_node,
                                 out_features=args.nb_classes, bias=True)
        self.sig = nn.Sigmoid()
        self.logsoftmax = nn.LogSoftmax(dim=1)

        self.pretrained_audio_encoder = args.pretrained_audio_encoder
        self.freeze_audio_encoder = args.freeze_audio_encoder
        if self.pretrained_audio_encoder:
            print("Loading pretrained audio encoder")
            # NOTE(review): torch.load unpickles arbitrary objects; only use trusted files.
            ckpt = torch.load(self._PRETRAINED_CKPT, map_location=torch.device(self.device))
            self.load_state_dict(ckpt, strict=True)
            # Report success only after the weights were actually applied.
            print("Loaded pretrained audio encoder")
        if self.freeze_audio_encoder:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, x, y=None):
        """Classify a batch of waveforms.

        Args:
            x: waveform batch, reshaped to ``(batch, 1, x.shape[1])``. Presumably
                ``(batch, samples)`` (TODO confirm with callers).
            y: unused; accepted for call-site compatibility.

        Returns:
            Log-probabilities of shape ``(batch, nb_classes)``.
        """
        nb_samp, len_seq = x.shape[0], x.shape[1]
        x = x.view(nb_samp, 1, len_seq)
        x = self.Sinc_conv(x)
        x = F.max_pool1d(torch.abs(x), 3)
        x = self.selu(self.first_bn(x))

        stages = (
            (self.block0, self.fc_attention0),
            (self.block1, self.fc_attention1),
            (self.block2, self.fc_attention2),
            (self.block3, self.fc_attention3),
            (self.block4, self.fc_attention4),
            (self.block5, self.fc_attention5),
        )
        for block, attention in stages:
            x = block(x)
            # Average over time -> (batch, filter), project, sigmoid -> (batch, filter, 1).
            gate = self.sig(attention(self.avgpool(x).view(x.size(0), -1)))
            gate = gate.view(gate.size(0), gate.size(1), -1)
            # Gated features plus the gate itself. Kept as-is; the pretrained weights
            # presumably depend on this exact formulation.
            x = x * gate + gate

        x = self.selu(self.bn_before_gru(x))
        x = x.permute(0, 2, 1)  # (batch, filt, time) -> (batch, time, filt)
        self.gru.flatten_parameters()
        x, _ = self.gru(x)
        x = x[:, -1, :]  # last time step only
        x = self.fc2_gru(self.fc1_gru(x))
        return self.logsoftmax(x)

    def _make_attention_fc(self, in_features, l_out_features):
        """Build a channel-gate projection: a single Linear inside a Sequential.

        The Sequential wrapper keeps parameter names as ``fc_attentionN.0.*``,
        which the strict ``load_state_dict`` in ``__init__`` depends on.
        """
        return nn.Sequential(nn.Linear(in_features=in_features,
                                       out_features=l_out_features))

    def _make_layer(self, nb_blocks, nb_filts, first=False):
        """Stack ``nb_blocks`` Residual_blocks into a Sequential.

        Only the first block receives ``first``. After it, the input width is set
        to the output width (``nb_filts[0] = nb_filts[1]``). The caller's
        ``nb_filts`` list is copied, not mutated.
        """
        nb_filts = list(nb_filts)
        layers = []
        for i in range(nb_blocks):
            layers.append(Residual_block(nb_filts=nb_filts,
                                         first=first if i == 0 else False))
            if i == 0:
                nb_filts[0] = nb_filts[1]
        return nn.Sequential(*layers)