Spaces:
Running
Running
import inspect | |
import torch | |
import numpy as np | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import math | |
from .base_model import BaseModel | |
from ..layers import activations, normalizations | |
def GlobLN(nOut):
    """Build a "global layer norm" for a (batch, channels, ...) tensor.

    Implemented as ``nn.GroupNorm`` with a single group: statistics are taken
    jointly over all channels and remaining positions of each sample, followed
    by a learnable per-channel affine transform.

    :param nOut: number of channels to normalize
    :return: ``nn.GroupNorm(1, nOut, eps=1e-8)``
    """
    single_group = 1
    return nn.GroupNorm(num_groups=single_group, num_channels=nOut, eps=1e-8)
class ConvNormAct(nn.Module):
    """1-D convolution, then global layer norm (``GlobLN``), then PReLU.

    Padding is ``int((kSize - 1) / 2)``, so with an odd kernel and ``stride=1``
    the output length equals the input length.
    """

    def __init__(self, nIn, nOut, kSize, stride=1, groups=1):
        """
        :param nIn: number of input channels
        :param nOut: number of output channels
        :param kSize: kernel size
        :param stride: stride rate for down-sampling. Default is 1
        :param groups: number of blocked connections of the convolution. Default is 1
        """
        super().__init__()
        half_kernel = int((kSize - 1) / 2)
        self.conv = nn.Conv1d(
            in_channels=nIn,
            out_channels=nOut,
            kernel_size=kSize,
            stride=stride,
            padding=half_kernel,
            bias=True,
            groups=groups,
        )
        self.norm = GlobLN(nOut)
        self.act = nn.PReLU()

    def forward(self, input):
        return self.act(self.norm(self.conv(input)))
class ConvNorm(nn.Module):
    """1-D convolution followed by global layer norm (``GlobLN``); no activation."""

    def __init__(self, nIn, nOut, kSize, stride=1, groups=1, bias=True):
        """
        :param nIn: number of input channels
        :param nOut: number of output channels
        :param kSize: kernel size
        :param stride: stride rate for down-sampling. Default is 1
        :param groups: number of blocked connections of the convolution. Default is 1
        :param bias: whether the convolution has a learnable bias. Default is True
        """
        super().__init__()
        pad = int((kSize - 1) / 2)
        self.conv = nn.Conv1d(
            in_channels=nIn,
            out_channels=nOut,
            kernel_size=kSize,
            stride=stride,
            padding=pad,
            bias=bias,
            groups=groups,
        )
        self.norm = GlobLN(nOut)

    def forward(self, input):
        convolved = self.conv(input)
        return self.norm(convolved)
class ATTConvActNorm(nn.Module):
    """Optional convolution, then an activation, then a normalization layer.

    When ``kernel_size <= 0`` the convolution is replaced by ``nn.Identity``.
    The activation and normalization classes are looked up by name through the
    project ``activations`` / ``normalizations`` registries.
    """

    def __init__(
        self,
        in_chan: int = 1,
        out_chan: int = 1,
        kernel_size: int = -1,
        stride: int = 1,
        groups: int = 1,
        dilation: int = 1,
        padding: int = None,
        norm_type: str = None,
        act_type: str = None,
        n_freqs: int = -1,
        xavier_init: bool = False,
        bias: bool = True,
        is2d: bool = False,
        *args,
        **kwargs,
    ):
        """
        :param in_chan: input channels of the convolution
        :param out_chan: output channels of the convolution and of the norm
        :param kernel_size: convolution kernel size; <= 0 disables the convolution
        :param stride: convolution stride
        :param groups: convolution groups
        :param dilation: convolution dilation
        :param padding: convolution padding; None means "same" for stride <= 1, else 0
        :param norm_type: name looked up in ``normalizations``
        :param act_type: name looked up in ``activations``
        :param n_freqs: extra norm dimension, used only for "LayerNormalization4D"
        :param xavier_init: re-initialize the conv weight with Xavier-uniform
        :param bias: whether the convolution has a bias
        :param is2d: use ``nn.Conv2d`` instead of ``nn.Conv1d``
        """
        super().__init__()
        # These public attributes are what get_config() reports, in this order.
        self.in_chan = in_chan
        self.out_chan = out_chan
        self.kernel_size = kernel_size
        self.stride = stride
        self.groups = groups
        self.dilation = dilation
        self.padding = padding
        self.norm_type = norm_type
        self.act_type = act_type
        self.n_freqs = n_freqs
        self.xavier_init = xavier_init
        self.bias = bias

        if self.padding is None:
            # PyTorch's padding="same" does not support strides other than 1.
            self.padding = "same" if not self.stride > 1 else 0

        if kernel_size <= 0:
            self.conv = nn.Identity()
        else:
            conv_cls = nn.Conv2d if is2d else nn.Conv1d
            self.conv = conv_cls(
                in_channels=self.in_chan,
                out_channels=self.out_chan,
                kernel_size=self.kernel_size,
                stride=self.stride,
                padding=self.padding,
                dilation=self.dilation,
                groups=self.groups,
                bias=self.bias,
            )
            if self.xavier_init:
                nn.init.xavier_uniform_(self.conv.weight)

        self.act = activations.get(self.act_type)()
        if self.norm_type == "LayerNormalization4D":
            norm_shape = (self.out_chan, self.n_freqs)
        else:
            norm_shape = self.out_chan
        self.norm = normalizations.get(self.norm_type)(norm_shape)

    def forward(self, x: torch.Tensor):
        for layer in (self.conv, self.act, self.norm):
            x = layer(x)
        return x

    def get_config(self):
        """Return the public, non-method instance attributes (excluding ``training``)."""
        return {
            name: value
            for name, value in self.__dict__.items()
            if not name.startswith("_")
            and name != "training"
            and not inspect.ismethod(value)
        }
class DilatedConvNorm(nn.Module):
    """Dilated 1-D convolution (default bias) followed by global layer norm."""

    def __init__(self, nIn, nOut, kSize, stride=1, d=1, groups=1):
        """
        :param nIn: number of input channels
        :param nOut: number of output channels
        :param kSize: kernel size
        :param stride: optional stride rate for down-sampling
        :param d: optional dilation rate
        :param groups: number of blocked connections of the convolution
        """
        super().__init__()
        # Half-kernel padding scaled by the dilation: with stride 1 and an odd
        # kernel the output length equals the input length.
        dilated_pad = d * ((kSize - 1) // 2)
        self.conv = nn.Conv1d(
            nIn,
            nOut,
            kSize,
            stride=stride,
            dilation=d,
            padding=dilated_pad,
            groups=groups,
        )
        self.norm = GlobLN(nOut)

    def forward(self, input):
        features = self.conv(input)
        return self.norm(features)
class Mlp(nn.Module):
    """Convolutional feed-forward block.

    1x1 ConvNorm -> depthwise conv (kernel 5) -> ReLU -> dropout ->
    1x1 ConvNorm -> dropout (the same dropout module is applied twice).
    """

    def __init__(self, in_features, hidden_size, drop=0.1):
        super().__init__()
        self.fc1 = ConvNorm(in_features, hidden_size, 1, bias=False)
        # Depthwise (groups == channels); kernel 5 with padding 2 keeps the length.
        self.dwconv = nn.Conv1d(
            in_channels=hidden_size,
            out_channels=hidden_size,
            kernel_size=5,
            stride=1,
            padding=2,
            bias=True,
            groups=hidden_size,
        )
        self.act = nn.ReLU()
        self.fc2 = ConvNorm(hidden_size, in_features, 1, bias=False)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.dwconv(self.fc1(x))))
        return self.drop(self.fc2(hidden))
class InjectionMultiSum(nn.Module):
    """Fuse local features with global features via a sigmoid gate plus a sum.

    out = local_embedding(x_l) * sigmoid(global_act(x_g)) + global_embedding(x_g),
    where both global terms are resized to x_l's length with nearest
    interpolation.
    """

    def __init__(self, inp: int, oup: int, kernel: int = 1) -> None:
        super().__init__()
        # Depthwise convolutions when the channel count is unchanged.
        groups = inp if inp == oup else 1
        self.local_embedding = ConvNorm(inp, oup, kernel, groups=groups, bias=False)
        self.global_embedding = ConvNorm(inp, oup, kernel, groups=groups, bias=False)
        self.global_act = ConvNorm(inp, oup, kernel, groups=groups, bias=False)
        self.act = nn.Sigmoid()

    def forward(self, x_l, x_g):
        """
        :param x_l: local features, 3-D (batch, channels, T)
        :param x_g: global features, possibly with a different length
        :return: fused features with the length T of x_l
        """
        _, _, target_len = x_l.shape
        local_feat = self.local_embedding(x_l)
        gate = F.interpolate(
            self.act(self.global_act(x_g)), size=target_len, mode="nearest"
        )
        global_feat = F.interpolate(
            self.global_embedding(x_g), size=target_len, mode="nearest"
        )
        return local_feat * gate + global_feat
class InjectionMulti(nn.Module):
    """Gate local features with global features (no additive global term).

    out = local_embedding(x_l) * sigmoid(global_act(x_g)), with the gate resized
    to x_l's length by nearest interpolation.
    """

    def __init__(self, inp: int, oup: int, kernel: int = 1) -> None:
        super().__init__()
        # Depthwise convolutions when the channel count is unchanged.
        groups = inp if inp == oup else 1
        self.local_embedding = ConvNorm(inp, oup, kernel, groups=groups, bias=False)
        self.global_act = ConvNorm(inp, oup, kernel, groups=groups, bias=False)
        self.act = nn.Sigmoid()

    def forward(self, x_l, x_g):
        """
        :param x_l: local features, 3-D (batch, channels, T)
        :param x_g: global features, possibly with a different length
        :return: gated local features with the length T of x_l
        """
        _, _, target_len = x_l.shape
        local_feat = self.local_embedding(x_l)
        gate = F.interpolate(
            self.act(self.global_act(x_g)), size=target_len, mode="nearest"
        )
        return local_feat * gate
class UConvBlock(nn.Module):
    """
    This class defines the block which performs successive downsampling and
    upsampling in order to be able to analyze the input features in multiple
    resolutions.

    Naming note: ``out_channels`` is the channel count of the block's input and
    output (``proj_1x1`` maps out_channels -> in_channels and ``res_conv`` maps
    back), while ``in_channels`` is the internal working width.
    """
    def __init__(self, out_channels=128, in_channels=512, upsampling_depth=4, model_T=True):
        """
        :param out_channels: channels of the block's input and output tensor
        :param in_channels: hidden channels used inside the block
        :param upsampling_depth: number of resolution levels; level 0 keeps the
            input length and each further level comes from a stride-2 conv.
            forward() needs upsampling_depth >= 2 (otherwise ``expanded`` stays None).
        :param model_T: not used by this block
        """
        super().__init__()
        self.proj_1x1 = ConvNormAct(out_channels, in_channels, 1, stride=1, groups=1)
        self.depth = upsampling_depth
        self.spp_dw = nn.ModuleList()
        # Level 0: depthwise conv at full resolution.
        self.spp_dw.append(
            DilatedConvNorm(
                in_channels, in_channels, kSize=5, stride=1, groups=in_channels, d=1
            )
        )
        # Levels 1..depth-1: each depthwise conv downsamples by stride 2.
        for i in range(1, upsampling_depth):
            self.spp_dw.append(
                DilatedConvNorm(
                    in_channels,
                    in_channels,
                    kSize=5,
                    stride=2,
                    groups=in_channels,
                    d=1,
                )
            )
        # One local/global fusion module per resolution level.
        self.loc_glo_fus = nn.ModuleList([])
        for i in range(upsampling_depth):
            self.loc_glo_fus.append(InjectionMultiSum(in_channels, in_channels))
        # Projects the hidden width back to out_channels before the residual add.
        self.res_conv = nn.Conv1d(in_channels, out_channels, 1)
        self.globalatt = Mlp(in_channels, in_channels, drop=0.1)
        # depth-1 fusion modules (kernel 5) used in the coarse-to-fine merge.
        self.last_layer = nn.ModuleList([])
        for i in range(self.depth - 1):
            self.last_layer.append(InjectionMultiSum(in_channels, in_channels, 5))
    def forward(self, x):
        """
        :param x: input feature map, 3-D (batch, out_channels, T) since the
            layers are Conv1d-based
        :return: transformed feature map with the same shape as x (residual added)
        """
        residual = x.clone()
        # Reduce --> project high-dimensional feature maps to low-dimensional space
        output1 = self.proj_1x1(x)
        output = [self.spp_dw[0](output1)]
        # Do the downsampling process from the previous level
        for k in range(1, self.depth):
            out_k = self.spp_dw[k](output[-1])
            output.append(out_k)
        # global features: every level is average-pooled to the coarsest length
        # and summed. NOTE(review): requires_grad=True on this constant zero
        # accumulator looks unnecessary, and it uses torch's default dtype
        # regardless of the input dtype — confirm before relying on it.
        global_f = torch.zeros(
            output[-1].shape, requires_grad=True, device=output1.device
        )
        for fea in output:
            global_f = global_f + F.adaptive_avg_pool1d(
                fea, output_size=output[-1].shape[-1]
            )
            # global_f = global_f + fea
        global_f = self.globalatt(global_f)  # [B, N, T]
        x_fused = []
        # Fuse each level (finest first) with the shared global features.
        for idx in range(self.depth):
            local = output[idx]
            x_fused.append(self.loc_glo_fus[idx](local, global_f))
        # Merge from coarse to fine; each step's output has the length of x_fused[i].
        # NOTE(review): the first step pairs x_fused[depth-2] with x_fused[i - 1]
        # (= x_fused[depth-3], the finer neighbour), not x_fused[depth-1]. For
        # depth > 2 the coarsest fused level x_fused[depth-1] (and therefore
        # loc_glo_fus[depth-1]) never contributes to the output; for depth == 2,
        # i - 1 == -1 happens to select x_fused[1]. A coarse-to-fine merge would
        # suggest x_fused[i + 1]; left unchanged because trained weights depend
        # on the current graph — confirm intent.
        expanded = None
        for i in range(self.depth - 2, -1, -1):
            if i == self.depth - 2:
                expanded = self.last_layer[i](x_fused[i], x_fused[i - 1])
            else:
                expanded = self.last_layer[i](x_fused[i], expanded)
        return self.res_conv(expanded) + residual
class MultiHeadSelfAttention2D(nn.Module):
    """Multi-head self-attention along one axis of a 4-D feature map.

    Every head has its own 1x1 Conv2d query/key/value projection
    (``ATTConvActNorm`` with ``is2d=True``). Attention is computed between
    positions of the third axis (called T below); each position's embedding is
    the flattened (channels x last axis) slice. Head outputs are concatenated
    on the channel axis, projected by ``attn_concat_proj`` and added to the input.
    """
    def __init__(
        self,
        in_chan: int,
        n_freqs: int,
        n_head: int = 4,
        hid_chan: int = 4,
        act_type: str = "prelu",
        norm_type: str = "LayerNormalization4D",
        dim: int = 3,
        *args,
        **kwargs,
    ):
        """
        :param in_chan: input/output channels C; must be divisible by n_head
        :param n_freqs: forwarded as ``n_freqs`` to every ATTConvActNorm
        :param n_head: number of attention heads
        :param hid_chan: channels E of each head's query/key projection
        :param act_type: activation name for the projections
        :param norm_type: normalization name for the projections
        :param dim: if 4, the last two axes are swapped on entry and swapped back
            on exit, so attention runs along the input's last axis instead
        """
        super(MultiHeadSelfAttention2D, self).__init__()
        self.in_chan = in_chan
        self.n_freqs = n_freqs
        self.n_head = n_head
        self.hid_chan = hid_chan
        self.act_type = act_type
        self.norm_type = norm_type
        self.dim = dim
        assert self.in_chan % self.n_head == 0
        self.Queries = nn.ModuleList()
        self.Keys = nn.ModuleList()
        self.Values = nn.ModuleList()
        for _ in range(self.n_head):
            self.Queries.append(
                ATTConvActNorm(
                    in_chan=self.in_chan,
                    out_chan=self.hid_chan,
                    kernel_size=1,
                    act_type=self.act_type,
                    norm_type=self.norm_type,
                    n_freqs=self.n_freqs,
                    is2d=True,
                )
            )
            self.Keys.append(
                ATTConvActNorm(
                    in_chan=self.in_chan,
                    out_chan=self.hid_chan,
                    kernel_size=1,
                    act_type=self.act_type,
                    norm_type=self.norm_type,
                    n_freqs=self.n_freqs,
                    is2d=True,
                )
            )
            self.Values.append(
                ATTConvActNorm(
                    in_chan=self.in_chan,
                    out_chan=self.in_chan // self.n_head,
                    kernel_size=1,
                    act_type=self.act_type,
                    norm_type=self.norm_type,
                    n_freqs=self.n_freqs,
                    is2d=True,
                )
            )
        self.attn_concat_proj = ATTConvActNorm(
            in_chan=self.in_chan,
            out_chan=self.in_chan,
            kernel_size=1,
            act_type=self.act_type,
            norm_type=self.norm_type,
            n_freqs=self.n_freqs,
            is2d=True,
        )
    def forward(self, x: torch.Tensor):
        # x is unpacked as 4-D below: [B, C, T, F] (after the optional swap).
        if self.dim == 4:
            x = x.transpose(-2, -1).contiguous()
        batch_size, _, time, freq = x.size()
        residual = x
        all_Q = [q(x) for q in self.Queries]  # [B, E, T, F], E = hid_chan
        all_K = [k(x) for k in self.Keys]  # [B, E, T, F]
        all_V = [v(x) for v in self.Values]  # [B, C/n_head, T, F]
        # Heads are stacked along the batch axis, head-major: row h*B + b.
        Q = torch.cat(all_Q, dim=0)  # [B', E, T, F] B' = B*n_head
        K = torch.cat(all_K, dim=0)  # [B', E, T, F]
        V = torch.cat(all_V, dim=0)  # [B', C/n_head, T, F]
        Q = Q.transpose(1, 2).flatten(start_dim=2)  # [B', T, E*F]
        K = K.transpose(1, 2).flatten(start_dim=2)  # [B', T, E*F]
        V = V.transpose(1, 2)  # [B', T, C/n_head, F]
        old_shape = V.shape
        V = V.flatten(start_dim=2)  # [B', T, C*F/n_head]
        emb_dim = Q.shape[-1]  # E*F (query/key embedding size)
        # Scaled dot-product attention between the T positions.
        attn_mat = torch.matmul(Q, K.transpose(1, 2)) / (emb_dim**0.5)  # [B', T, T]
        attn_mat = F.softmax(attn_mat, dim=2)  # [B', T, T]
        V = torch.matmul(attn_mat, V)  # [B', T, C*F/n_head]
        V = V.reshape(old_shape)  # [B', T, C/n_head, F]
        V = V.transpose(1, 2)  # [B', C/n_head, T, F]
        emb_dim = V.shape[1]  # C/n_head
        # Undo the head-major batch stacking and concatenate heads on channels.
        x = V.view([self.n_head, batch_size, emb_dim, time, freq])  # [n_head, B, C/n_head, T, F]
        x = x.transpose(0, 1).contiguous()  # [B, n_head, C/n_head, T, F]
        x = x.view([batch_size, self.n_head * emb_dim, time, freq])  # [B, C, T, F]
        x = self.attn_concat_proj(x)  # [B, C, T, F]
        x = x + residual
        if self.dim == 4:
            x = x.transpose(-2, -1).contiguous()
        return x
class Recurrent(nn.Module):
    """Iterative separator alternating a frequency path and a frame path.

    Every iteration reuses the same modules (weights are shared across
    iterations). The frequency path runs a UConvBlock across sub-bands,
    self-attention and a normalization layer. The frame path does the same
    across time frames. From the second iteration on, the input is
    ``concat_block(mixture + x)``, where ``mixture`` is the original input.
    """
    def __init__(
        self,
        out_channels=128,
        in_channels=512,
        nband=8,
        upsampling_depth=3,
        n_head=4,
        att_hid_chan=4,
        kernel_size: int = 8,
        stride: int = 1,
        _iter=4
    ):
        """
        :param out_channels: feature channels N of the input/output
        :param in_channels: hidden width passed to each UConvBlock
        :param nband: number of sub-bands (stored; forward reads sizes from the input)
        :param upsampling_depth: depth of each UConvBlock
        :param n_head: attention heads
        :param att_hid_chan: per-head query/key channels
        :param kernel_size: not used by this class
        :param stride: not used by this class
        :param _iter: number of refinement iterations
        """
        super().__init__()
        self.nband = nband
        self.freq_path = nn.ModuleList([
            UConvBlock(out_channels, in_channels, upsampling_depth),
            MultiHeadSelfAttention2D(out_channels, 1, n_head=n_head, hid_chan=att_hid_chan, act_type="prelu", norm_type="LayerNormalization4D", dim=4),
            normalizations.get("LayerNormalization4D")((out_channels, 1))
        ])
        self.frame_path = nn.ModuleList([
            UConvBlock(out_channels, in_channels, upsampling_depth),
            MultiHeadSelfAttention2D(out_channels, 1, n_head=n_head, hid_chan=att_hid_chan, act_type="prelu", norm_type="LayerNormalization4D", dim=4),
            normalizations.get("LayerNormalization4D")((out_channels, 1))
        ])
        self.iter = _iter
        # Depthwise 1x1 Conv2d + PReLU applied to (mixture + x) between iterations.
        self.concat_block = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, 1, 1, groups=out_channels), nn.PReLU()
        )
    def forward(self, x):
        """
        :param x: tensor of shape (B, nband, N, T)
        :return: tensor of shape (B, nband, N, T)
        """
        # B, nband, N, T
        B, nband, N, T = x.shape
        x = x.permute(0, 2, 1, 3).contiguous()  # B, N, nband, T
        mixture = x.clone()
        for i in range(self.iter):
            if i == 0:
                x = self.freq_time_process(x, B, nband, N, T)  # B, N, nband, T
            else:
                x = self.freq_time_process(self.concat_block(mixture + x), B, nband, N, T)  # B, N, nband, T
        return x.permute(0, 2, 1, 3).contiguous()  # B, nband, N, T
    def freq_time_process(self, x, B, nband, N, T):
        """One frequency-path pass then one frame-path pass, each with a residual.

        :param x: tensor of shape (B, N, nband, T)
        :return: tensor of shape (B, N, nband, T)
        """
        # Process Frequency Path: UConvBlock runs along the band axis for every
        # (batch, frame) pair; attention (dim=4) runs along the last axis, nband.
        residual_1 = x.clone()
        x = x.permute(0, 3, 1, 2).contiguous()  # B, T, N, nband
        freq_fea = self.freq_path[0](x.view(B*T, N, nband))  # B*T, N, nband
        freq_fea = freq_fea.view(B, T, N, nband).permute(0, 2, 1, 3).contiguous()  # B, N, T, nband
        freq_fea = self.freq_path[1](freq_fea)  # B, N, T, nband
        freq_fea = self.freq_path[2](freq_fea)  # B, N, T, nband
        freq_fea = freq_fea.permute(0, 1, 3, 2).contiguous()  # B, N, nband, T
        x = freq_fea + residual_1  # B, N, nband, T
        # Process Frame Path: UConvBlock runs along time for every (batch, band)
        # pair; attention (dim=4) runs along the last axis, T.
        residual_2 = x.clone()
        x2 = x.permute(0, 2, 1, 3).contiguous()  # B, nband, N, T
        frame_fea = self.frame_path[0](x2.view(B*nband, N, T))  # B*nband, N, T
        frame_fea = frame_fea.view(B, nband, N, T).permute(0, 2, 1, 3).contiguous()  # B, N, nband, T
        frame_fea = self.frame_path[1](frame_fea)  # B, N, nband, T
        frame_fea = self.frame_path[2](frame_fea)  # B, N, nband, T
        x = frame_fea + residual_2  # B, N, nband, T
        return x
class TIGER(BaseModel):
    """Band-split separation model operating on the complex STFT.

    The mixture STFT is cut into ``nband`` frequency sub-bands. Each band's
    real/imag bins are normalized and projected to ``out_channels`` features,
    the stacked band features are processed by ``Recurrent``, and a per-band
    head predicts a complex mask for each of ``num_sources`` sources. The
    masked spectra are inverted with ``torch.istft``.
    """

    def __init__(
        self,
        out_channels=128,
        in_channels=512,
        num_blocks=16,
        upsampling_depth=4,
        att_n_head=4,
        att_hid_chan=4,
        att_kernel_size=8,
        att_stride=1,
        win=2048,
        stride=512,
        num_sources=2,
        sample_rate=44100,
    ):
        """
        :param out_channels: feature channels N per sub-band
        :param in_channels: hidden width passed to the separator
        :param num_blocks: number of separator iterations
        :param upsampling_depth: depth passed to the separator
        :param att_n_head: attention heads
        :param att_hid_chan: per-head query/key channels
        :param att_kernel_size: forwarded to the separator
        :param att_stride: forwarded to the separator
        :param win: STFT size (n_fft) and Hann window length
        :param stride: STFT hop length
        :param num_sources: number of separated outputs K
        :param sample_rate: sample rate in Hz, used to size the sub-bands
        :raises ValueError: if the fixed band layout does not fit into the
            ``win // 2 + 1`` frequency bins (some band would be empty)
        """
        super(TIGER, self).__init__(sample_rate=sample_rate)
        self.sample_rate = sample_rate
        self.win = win
        self.stride = stride
        self.group = self.win // 2
        self.enc_dim = self.win // 2 + 1
        self.feature_dim = out_channels
        self.num_output = num_sources
        self.eps = torch.finfo(torch.float32).eps

        self.band_width = self._compute_band_widths(sample_rate, self.enc_dim)
        self.nband = len(self.band_width)

        # Per-band normalization + 1x1 bottleneck from 2*BW (real, imag) to N.
        self.BN = nn.ModuleList([])
        for width in self.band_width:
            self.BN.append(
                nn.Sequential(
                    nn.GroupNorm(1, width * 2, self.eps),
                    nn.Conv1d(width * 2, self.feature_dim, 1),
                )
            )

        # num_blocks is passed positionally as the separator's iteration count.
        self.separator = Recurrent(
            self.feature_dim, in_channels, self.nband, upsampling_depth,
            att_n_head, att_hid_chan, att_kernel_size, att_stride, num_blocks,
        )

        # Per-band mask head: 2 (value/gate) x 2 (real/imag) x K x BW channels.
        self.mask = nn.ModuleList([])
        for width in self.band_width:
            self.mask.append(
                nn.Sequential(
                    nn.PReLU(),
                    nn.Conv1d(self.feature_dim, width * 4 * num_sources, 1, groups=num_sources),
                )
            )

    @staticmethod
    def _compute_band_widths(sample_rate, enc_dim):
        """Return the number of STFT bins in each sub-band as plain ints.

        Nominal layout: 40 x 25 Hz (0-1 kHz), 10 x 100 Hz (1-2 kHz),
        8 x 250 Hz (2-4 kHz), 8 x 500 Hz (4-8 kHz), then one band holding all
        remaining bins. Each fixed width is floor(hz / nyquist * enc_dim).

        :raises ValueError: if any band would have zero or negative width.
        """
        nyquist = sample_rate / 2.
        layout = ((25, 40), (100, 10), (250, 8), (500, 8))
        widths = []
        for band_hz, count in layout:
            widths += [int(np.floor(band_hz / nyquist * enc_dim))] * count
        widths.append(enc_dim - sum(widths))
        if min(widths) <= 0:
            raise ValueError(
                f"TIGER band split requires positive band widths, got {widths} "
                f"(sample_rate={sample_rate}, enc_dim={enc_dim}); "
                "increase the STFT window or adjust the sample rate."
            )
        return widths

    def pad_input(self, input, window, stride):
        """
        Zero-pad a (batch, nsample) input according to window/stride size.

        Zeros are appended so the length matches the window/stride layout,
        then ``stride`` zeros are added on both sides. Not called by forward.

        :return: (padded input, number of zeros appended at the end)
        """
        batch_size, nsample = input.shape
        # pad the signals at the end for matching the window/stride size
        rest = window - (stride + nsample % window) % window
        if rest > 0:
            pad = torch.zeros(batch_size, rest).type(input.type())
            input = torch.cat([input, pad], 1)
        pad_aux = torch.zeros(batch_size, stride).type(input.type())
        input = torch.cat([pad_aux, input, pad_aux], 1)
        return input, rest

    def forward(self, input):
        """
        :param input: waveform of shape (T,), (B, T) or (B, C, T);
            1-D and 2-D inputs are treated as single-channel
        :return: separated waveforms of shape (B * C, num_sources, T)
        :raises ValueError: for inputs of any other rank
        """
        if input.ndim == 1:
            input = input.unsqueeze(0).unsqueeze(1)
        elif input.ndim == 2:
            input = input.unsqueeze(1)
        if input.ndim != 3:
            raise ValueError(
                f"expected a 1-D, 2-D or 3-D waveform tensor, got shape {tuple(input.shape)}"
            )
        batch_size, nch, nsample = input.shape
        input = input.view(batch_size*nch, -1)

        # Same Hann window (device/dtype of the input) for analysis and synthesis.
        window = torch.hann_window(self.win).to(input.device).type(input.type())

        # frequency-domain separation
        spec = torch.stft(input, n_fft=self.win, hop_length=self.stride,
                          window=window, return_complex=True)  # B*nch, F, T

        # concat real and imag, split to subbands
        spec_RI = torch.stack([spec.real, spec.imag], 1)  # B*nch, 2, F, T
        subband_spec_RI = []
        subband_spec = []
        band_idx = 0
        for width in self.band_width:
            subband_spec_RI.append(spec_RI[:, :, band_idx:band_idx + width].contiguous())
            subband_spec.append(spec[:, band_idx:band_idx + width])  # B*nch, BW, T
            band_idx += width

        # normalization and bottleneck
        subband_feature = torch.stack(
            [
                bn(band_ri.view(batch_size*nch, width * 2, -1))
                for bn, band_ri, width in zip(self.BN, subband_spec_RI, self.band_width)
            ],
            1,
        )  # B*nch, nband, N, T

        # separator
        sep_output = self.separator(subband_feature.view(batch_size*nch, self.nband, self.feature_dim, -1))  # B*nch, nband, N, T
        sep_output = sep_output.view(batch_size*nch, self.nband, self.feature_dim, -1)

        sep_subband_spec = []
        for i, width in enumerate(self.band_width):
            this_output = self.mask[i](sep_output[:, i]).view(batch_size*nch, 2, 2, self.num_output, width, -1)
            # Gated mask: value * sigmoid(gate)
            this_mask = this_output[:, 0] * torch.sigmoid(this_output[:, 1])  # B*nch, 2, K, BW, T
            this_mask_real = this_mask[:, 0]  # B*nch, K, BW, T
            this_mask_imag = this_mask[:, 1]  # B*nch, K, BW, T
            # Shift the masks so that, summed over the K sources, the real part
            # is 1 and the imaginary part is 0: the estimates add up to the mixture.
            this_mask_real_sum = this_mask_real.sum(1).unsqueeze(1)  # B*nch, 1, BW, T
            this_mask_imag_sum = this_mask_imag.sum(1).unsqueeze(1)  # B*nch, 1, BW, T
            this_mask_real = this_mask_real - (this_mask_real_sum - 1) / self.num_output
            this_mask_imag = this_mask_imag - this_mask_imag_sum / self.num_output
            # Complex multiplication mixture * mask.
            mix_real = subband_spec[i].real.unsqueeze(1)
            mix_imag = subband_spec[i].imag.unsqueeze(1)
            est_spec_real = mix_real * this_mask_real - mix_imag * this_mask_imag  # B*nch, K, BW, T
            est_spec_imag = mix_real * this_mask_imag + mix_imag * this_mask_real  # B*nch, K, BW, T
            sep_subband_spec.append(torch.complex(est_spec_real, est_spec_imag))
        sep_subband_spec = torch.cat(sep_subband_spec, 2)  # B*nch, K, F, T

        output = torch.istft(sep_subband_spec.view(batch_size*nch*self.num_output, self.enc_dim, -1),
                             n_fft=self.win, hop_length=self.stride,
                             window=window, length=nsample)
        return output.view(batch_size*nch, self.num_output, -1)

    def get_model_args(self):
        # NOTE(review): the meaning of this constant is not visible in this file.
        model_args = {"n_sample_rate": 2}
        return model_args