Spaces:
Running
on
Zero
Running
on
Zero
File size: 3,426 Bytes
78e32cc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 |
import torch
import torch.nn as nn
import numpy as np
class MultiFrequencyDiscriminator(nn.Module):
    """Bank of spectrogram discriminators, one per STFT window size.

    For every entry of ``window`` a separate ``FrequencyDiscriminator`` is
    created. ``forward`` power-normalizes the input, computes one STFT per
    window size, stacks real and imaginary parts as channels and feeds the
    result to the matching discriminator.

    Args:
        nch: Number of channels expected on axis 1 of the ``forward`` input.
        window: Sequence of STFT sizes (``n_fft``); the hop is ``n_fft // 2``.
        hidden_channels: Base channel width handed to every
            ``FrequencyDiscriminator``. Defaults to 8, the value that was
            previously hard-coded.
    """

    def __init__(self, nch, window, hidden_channels=8):
        super().__init__()
        self.nch = nch
        self.window = window
        self.hidden_channels = hidden_channels
        self.eps = torch.finfo(torch.float32).eps
        # Each discriminator sees 2 * nch channels: (real, imag) per input channel.
        self.discriminators = nn.ModuleList(
            [FrequencyDiscriminator(2 * nch, self.hidden_channels) for _ in self.window]
        )

    def forward(self, est, sample_rate=44100):
        """Run every discriminator on the STFT of ``est``.

        Args:
            est: 3-D tensor ``(B, nch, T)``; the last axis is the signal that
                is passed to ``torch.stft``. It may be non-contiguous.
            sample_rate: Only the first ``int(F * sample_rate / 44100)``
                frequency bins (``F`` = number of STFT bins) are passed on.
                NOTE(review): presumably this drops bins above the Nyquist
                frequency of lower-rate material -- confirm with the caller.

        Returns:
            ``(est_outputs, est_feature_maps)``: two lists with one entry per
            window size, holding the two values returned by the matching
            discriminator.
        """
        B, nch, _ = est.shape
        assert nch == self.nch, f"expected {self.nch} channels, got {nch}"

        # normalize power (per batch item, over all channels and samples)
        energy = est.pow(2).sum((1, 2)) + self.eps
        est = est / energy.sqrt().reshape(B, 1, 1)
        # reshape (not view): the division keeps the strides of a
        # non-contiguous input, on which view() would raise.
        est = est.reshape(-1, est.shape[-1])
        signal = est.float()  # the STFT is always computed in float32

        est_outputs = []
        est_feature_maps = []
        for n_fft, discriminator in zip(self.window, self.discriminators):
            # Build the window directly on the target device instead of
            # creating it on CPU and copying it over on every call.
            hann = torch.hann_window(n_fft, device=est.device, dtype=torch.float32)
            est_spec = torch.stft(signal, n_fft, n_fft // 2,
                                  window=hann, return_complex=True)
            # (B*nch, 2, F, frames) -> (B, 2*nch, F, frames); the channel
            # order is [ch0.real, ch0.imag, ch1.real, ch1.imag, ...].
            est_RI = torch.stack([est_spec.real, est_spec.imag], dim=1)
            est_RI = est_RI.reshape(B, nch * 2, est_RI.shape[-2], est_RI.shape[-1])
            est_RI = est_RI.to(est.dtype)  # back to the dtype of the input
            valid_enc = int(est_RI.shape[2] * sample_rate / 44100)
            est_out, est_feat_map = discriminator(est_RI[:, :, :valid_enc].contiguous())
            est_outputs.append(est_out)
            est_feature_maps.append(est_feat_map)
        return est_outputs, est_feature_maps
class FrequencyDiscriminator(nn.Module):
    """Stack of 3x3 2-D convolutions that scores a multi-channel input.

    Six spectrally-normalized conv + LeakyReLU(0.2) stages widen the channels
    from ``hidden_channels`` up to ``32 * hidden_channels``; the odd-numbered
    stages (2nd, 4th, 6th) use stride 2, the others stride 1. A final plain
    convolution maps to a single output channel.

    Args:
        in_channels: Number of channels of the input tensor.
        hidden_channels: Output width of the first stage.
    """

    def __init__(self, in_channels, hidden_channels=512):
        super().__init__()
        self.eps = torch.finfo(torch.float32).eps

        # Channel widths at the boundaries of the six normalized stages.
        widths = [in_channels] + [hidden_channels * 2 ** k for k in range(6)]

        stages = []
        for depth, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            step = 2 if depth % 2 else 1  # strides alternate 1, 2, 1, 2, 1, 2
            conv = nn.Conv2d(c_in, c_out, kernel_size=(3, 3),
                             padding=(1, 1), stride=(step, step))
            stages.append(
                nn.Sequential(nn.utils.spectral_norm(conv), nn.LeakyReLU(0.2, True))
            )
        # Output head: no spectral norm and no activation.
        stages.append(
            nn.Conv2d(widths[-1], 1, kernel_size=(3, 3), padding=(1, 1), stride=(1, 1))
        )
        self.discriminator = nn.ModuleList(stages)

    def forward(self, x):
        """Return ``(score_map, feature_maps)`` for input ``x``.

        ``feature_maps`` lists the outputs of the six activated stages in
        order; the output of the final convolution is returned separately as
        ``score_map`` and is not part of the list.
        """
        *body, head = self.discriminator
        feature_maps = []
        for stage in body:
            x = stage(x)
            feature_maps.append(x)
        return head(x), feature_maps