# fffiloni's picture
# Migrated from GitHub
# 406f22d verified
import inspect
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import math
from .base_model import BaseModel
from ..layers import activations, normalizations
def GlobLN(nOut):
    """Return a "global" layer norm: a GroupNorm with a single group over all ``nOut`` channels."""
    norm = nn.GroupNorm(num_groups=1, num_channels=nOut, eps=1e-8)
    return norm
class ConvNormAct(nn.Module):
    """
    1-D convolution, then global layer norm (GlobLN), then a PReLU activation.
    """

    def __init__(self, nIn, nOut, kSize, stride=1, groups=1):
        """
        :param nIn: number of input channels
        :param nOut: number of output channels
        :param kSize: kernel size
        :param stride: stride rate for down-sampling. Default is 1
        :param groups: number of blocked connections from input to output channels. Default is 1
        """
        super().__init__()
        # (kSize - 1) / 2 truncated toward zero on each side; for an odd kernel at
        # stride 1 this keeps the sequence length unchanged.
        pad = int((kSize - 1) / 2)
        self.conv = nn.Conv1d(
            nIn, nOut, kSize, stride=stride, padding=pad, bias=True, groups=groups
        )
        self.norm = GlobLN(nOut)
        self.act = nn.PReLU()

    def forward(self, input):
        """Apply convolution, normalization and PReLU in that order."""
        return self.act(self.norm(self.conv(input)))
class ConvNorm(nn.Module):
    """
    1-D convolution followed by global layer norm (GlobLN). No activation is applied.
    """

    def __init__(self, nIn, nOut, kSize, stride=1, groups=1, bias=True):
        """
        :param nIn: number of input channels
        :param nOut: number of output channels
        :param kSize: kernel size
        :param stride: stride rate for down-sampling. Default is 1
        :param groups: number of blocked connections from input to output channels. Default is 1
        :param bias: whether the convolution has a learnable bias. Default is True
        """
        super().__init__()
        # (kSize - 1) / 2 truncated toward zero on each side.
        pad = int((kSize - 1) / 2)
        self.conv = nn.Conv1d(
            nIn, nOut, kSize, stride=stride, padding=pad, bias=bias, groups=groups
        )
        self.norm = GlobLN(nOut)

    def forward(self, input):
        """Convolve, then normalize."""
        convolved = self.conv(input)
        return self.norm(convolved)
class ATTConvActNorm(nn.Module):
    """
    Optional convolution followed by an activation and a normalization layer.

    When ``kernel_size > 0`` a Conv1d (or Conv2d if ``is2d``) is applied first;
    otherwise the convolution is an ``nn.Identity``. The activation and
    normalization are looked up by name through ``activations.get`` and
    ``normalizations.get``. Note the order used in ``forward``: conv -> act -> norm.
    """

    def __init__(
        self,
        in_chan: int = 1,
        out_chan: int = 1,
        kernel_size: int = -1,
        stride: int = 1,
        groups: int = 1,
        dilation: int = 1,
        padding: int = None,
        norm_type: str = None,
        act_type: str = None,
        n_freqs: int = -1,
        xavier_init: bool = False,
        bias: bool = True,
        is2d: bool = False,
        *args,
        **kwargs,
    ):
        """
        :param in_chan: input channels of the convolution
        :param out_chan: output channels (also the channel size given to the norm)
        :param kernel_size: conv kernel size; a value <= 0 disables the convolution
        :param stride: conv stride
        :param groups: conv groups
        :param dilation: conv dilation
        :param padding: conv padding; when None, 0 for stride > 1, otherwise "same"
        :param norm_type: name passed to ``normalizations.get``
        :param act_type: name passed to ``activations.get``
        :param n_freqs: frequency size, only used when norm_type is "LayerNormalization4D"
        :param xavier_init: apply Xavier-uniform init to the conv weight
        :param bias: whether the conv has a bias
        :param is2d: use Conv2d instead of Conv1d
        """
        super(ATTConvActNorm, self).__init__()
        self.in_chan = in_chan
        self.out_chan = out_chan
        self.kernel_size = kernel_size
        self.stride = stride
        self.groups = groups
        self.dilation = dilation
        self.padding = padding
        self.norm_type = norm_type
        self.act_type = act_type
        self.n_freqs = n_freqs
        self.xavier_init = xavier_init
        self.bias = bias
        # Stored so get_config() reports every argument needed to rebuild this layer;
        # without it a rebuilt layer would silently fall back to Conv1d.
        self.is2d = is2d

        if self.padding is None:
            # PyTorch only accepts padding="same" for stride 1.
            self.padding = 0 if self.stride > 1 else "same"

        if kernel_size > 0:
            conv = nn.Conv2d if self.is2d else nn.Conv1d
            self.conv = conv(
                in_channels=self.in_chan,
                out_channels=self.out_chan,
                kernel_size=self.kernel_size,
                stride=self.stride,
                padding=self.padding,
                dilation=self.dilation,
                groups=self.groups,
                bias=self.bias,
            )
            if self.xavier_init:
                nn.init.xavier_uniform_(self.conv.weight)
        else:
            self.conv = nn.Identity()

        self.act = activations.get(self.act_type)()
        # LayerNormalization4D is built from a (channels, freqs) pair; other norms take channels only.
        self.norm = normalizations.get(self.norm_type)(
            (self.out_chan, self.n_freqs) if self.norm_type == "LayerNormalization4D" else self.out_chan
        )

    def forward(self, x: torch.Tensor):
        """Apply conv (or identity), then activation, then normalization."""
        output = self.conv(x)
        output = self.act(output)
        output = self.norm(output)
        return output

    def get_config(self):
        """Return the public, non-method instance attributes (the constructor arguments)."""
        return {
            k: v
            for k, v in self.__dict__.items()
            if not k.startswith("_") and k != "training" and not inspect.ismethod(v)
        }
class DilatedConvNorm(nn.Module):
    """
    Dilated 1-D convolution followed by global layer norm (GlobLN); no activation.
    """

    def __init__(self, nIn, nOut, kSize, stride=1, d=1, groups=1):
        """
        :param nIn: number of input channels
        :param nOut: number of output channels
        :param kSize: kernel size
        :param stride: optional stride rate for down-sampling
        :param d: optional dilation rate
        :param groups: number of blocked connections from input to output channels
        """
        super().__init__()
        # Padding scaled by the dilation so an odd kernel preserves length at stride 1.
        dilated_pad = ((kSize - 1) // 2) * d
        self.conv = nn.Conv1d(
            nIn,
            nOut,
            kSize,
            stride=stride,
            dilation=d,
            padding=dilated_pad,
            groups=groups,
        )
        self.norm = GlobLN(nOut)

    def forward(self, input):
        """Convolve, then normalize."""
        convolved = self.conv(input)
        return self.norm(convolved)
class Mlp(nn.Module):
    """
    Convolutional feed-forward block.

    1x1 ConvNorm expansion -> depthwise conv (kernel 5) -> ReLU -> dropout ->
    1x1 ConvNorm projection back to ``in_features`` -> dropout.
    """

    def __init__(self, in_features, hidden_size, drop=0.1):
        """
        :param in_features: input/output channels
        :param hidden_size: hidden channels
        :param drop: dropout probability
        """
        super().__init__()
        self.fc1 = ConvNorm(in_features, hidden_size, 1, bias=False)
        # Depthwise conv (groups == channels); padding 2 keeps the length for kernel 5.
        self.dwconv = nn.Conv1d(
            hidden_size,
            hidden_size,
            kernel_size=5,
            stride=1,
            padding=2,
            bias=True,
            groups=hidden_size,
        )
        self.act = nn.ReLU()
        self.fc2 = ConvNorm(hidden_size, in_features, 1, bias=False)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Run the layers in sequence; the single dropout module is applied twice."""
        for layer in (self.fc1, self.dwconv, self.act, self.drop, self.fc2, self.drop):
            x = layer(x)
        return x
class InjectionMultiSum(nn.Module):
    """
    Gated fusion of local and global features.

    Computes ``local_embedding(x_l) * sigmoid(global_act(x_g)) + global_embedding(x_g)``.
    Both global terms are resized with nearest-neighbour interpolation to the
    length of ``x_l``.
    """

    def __init__(self, inp: int, oup: int, kernel: int = 1) -> None:
        """
        :param inp: input channels
        :param oup: output channels
        :param kernel: kernel size of the three ConvNorm embeddings
        """
        super().__init__()
        # Depthwise convolutions when the channel count is unchanged.
        groups = inp if inp == oup else 1
        self.local_embedding = ConvNorm(inp, oup, kernel, groups=groups, bias=False)
        self.global_embedding = ConvNorm(inp, oup, kernel, groups=groups, bias=False)
        self.global_act = ConvNorm(inp, oup, kernel, groups=groups, bias=False)
        self.act = nn.Sigmoid()

    def forward(self, x_l, x_g):
        """
        x_l: local features, 3-D [B, N, T]
        x_g: global features
        """
        _, _, length = x_l.shape
        local_feat = self.local_embedding(x_l)
        gate = F.interpolate(self.act(self.global_act(x_g)), size=length, mode="nearest")
        global_feat = F.interpolate(self.global_embedding(x_g), size=length, mode="nearest")
        return local_feat * gate + global_feat
class InjectionMulti(nn.Module):
    """
    Gated injection of global features into local features (no additive term).

    Computes ``local_embedding(x_l) * sigmoid(global_act(x_g))``, with the gate
    resized by nearest-neighbour interpolation to the length of ``x_l``.
    """

    def __init__(self, inp: int, oup: int, kernel: int = 1) -> None:
        """
        :param inp: input channels
        :param oup: output channels
        :param kernel: kernel size of the ConvNorm embeddings
        """
        super().__init__()
        # Depthwise convolutions when the channel count is unchanged.
        groups = inp if inp == oup else 1
        self.local_embedding = ConvNorm(inp, oup, kernel, groups=groups, bias=False)
        self.global_act = ConvNorm(inp, oup, kernel, groups=groups, bias=False)
        self.act = nn.Sigmoid()

    def forward(self, x_l, x_g):
        """
        x_l: local features, 3-D [B, N, T]
        x_g: global features
        """
        _, _, length = x_l.shape
        local_feat = self.local_embedding(x_l)
        gate = F.interpolate(self.act(self.global_act(x_g)), size=length, mode="nearest")
        return local_feat * gate
class UConvBlock(nn.Module):
    """
    Multi-resolution block.

    Projects the input to ``in_channels`` and builds a pyramid of ``upsampling_depth``
    levels (level 0 at full resolution, each further level downsampled with stride 2).
    Each level is fused with a pooled global context. The levels are then merged
    from coarsest to finest, projected back to ``out_channels`` and added to the input.
    """

    def __init__(self, out_channels=128, in_channels=512, upsampling_depth=4, model_T=True):
        """
        :param out_channels: channels of the block input and output
        :param in_channels: hidden channel width used inside the block
        :param upsampling_depth: number of pyramid levels (>= 1)
        :param model_T: unused; kept for signature compatibility
        """
        super().__init__()
        self.proj_1x1 = ConvNormAct(out_channels, in_channels, 1, stride=1, groups=1)
        self.depth = upsampling_depth
        self.spp_dw = nn.ModuleList()
        # Level 0 keeps the resolution.
        self.spp_dw.append(
            DilatedConvNorm(
                in_channels, in_channels, kSize=5, stride=1, groups=in_channels, d=1
            )
        )
        # Every further level halves the resolution of the previous one.
        for _ in range(1, upsampling_depth):
            self.spp_dw.append(
                DilatedConvNorm(
                    in_channels,
                    in_channels,
                    kSize=5,
                    stride=2,
                    groups=in_channels,
                    d=1,
                )
            )
        # One local/global fusion module per pyramid level.
        self.loc_glo_fus = nn.ModuleList(
            [InjectionMultiSum(in_channels, in_channels) for _ in range(upsampling_depth)]
        )
        self.res_conv = nn.Conv1d(in_channels, out_channels, 1)
        self.globalatt = Mlp(in_channels, in_channels, drop=0.1)
        # last_layer[i] injects the (already merged) level i + 1 into level i.
        self.last_layer = nn.ModuleList(
            [InjectionMultiSum(in_channels, in_channels, 5) for _ in range(self.depth - 1)]
        )

    def forward(self, x):
        """
        :param x: input feature map, 3-D with ``out_channels`` channels
        :return: transformed feature map, same shape as ``x``
        """
        residual = x
        # Reduce --> project to the hidden channel width.
        output1 = self.proj_1x1(x)
        output = [self.spp_dw[0](output1)]
        # Downsample successively from the previous level.
        for k in range(1, self.depth):
            output.append(self.spp_dw[k](output[-1]))

        # Global context: average-pool every level to the coarsest length and sum.
        coarse_len = output[-1].shape[-1]
        # zeros_like keeps the features' dtype/device (e.g. fp16/bf16 inference).
        global_f = torch.zeros_like(output[-1])
        for fea in output:
            global_f = global_f + F.adaptive_avg_pool1d(fea, output_size=coarse_len)
        global_f = self.globalatt(global_f)

        # Fuse each level with the global context.
        x_fused = [fuse(local, global_f) for fuse, local in zip(self.loc_glo_fus, output)]

        # Merge coarse-to-fine: start from the coarsest fused level and inject it
        # into each finer level in turn. With depth == 1 this is just x_fused[0].
        expanded = x_fused[-1]
        for i in range(self.depth - 2, -1, -1):
            expanded = self.last_layer[i](x_fused[i], expanded)
        return self.res_conv(expanded) + residual
class MultiHeadSelfAttention2D(nn.Module):
    """
    Multi-head self-attention across the time axis of a 4-D feature map [B, C, T, F].

    Each head owns 1x1 Conv2d projections (ATTConvActNorm) for queries/keys
    (``hid_chan`` channels) and values (``in_chan // n_head`` channels). Scores are
    computed between time frames from the flattened (channel x freq) vectors. Head
    outputs are concatenated along channels, projected, and added to the input.
    With ``dim == 4`` the last two axes are swapped on entry and exit, so attention
    runs over the input's last axis instead.
    """

    def __init__(
        self,
        in_chan: int,
        n_freqs: int,
        n_head: int = 4,
        hid_chan: int = 4,
        act_type: str = "prelu",
        norm_type: str = "LayerNormalization4D",
        dim: int = 3,
        *args,
        **kwargs,
    ):
        """
        :param in_chan: channels C of the input (must be divisible by n_head)
        :param n_freqs: size of the frequency axis given to the normalization
        :param n_head: number of attention heads
        :param hid_chan: query/key channels per head
        :param act_type: activation name for the projections
        :param norm_type: normalization name for the projections
        :param dim: 4 to swap the last two axes before/after attention
        """
        super(MultiHeadSelfAttention2D, self).__init__()
        self.in_chan = in_chan
        self.n_freqs = n_freqs
        self.n_head = n_head
        self.hid_chan = hid_chan
        self.act_type = act_type
        self.norm_type = norm_type
        self.dim = dim
        assert self.in_chan % self.n_head == 0

        def projection(out_chan):
            # 1x1 Conv2d -> activation -> normalization.
            return ATTConvActNorm(
                in_chan=self.in_chan,
                out_chan=out_chan,
                kernel_size=1,
                act_type=self.act_type,
                norm_type=self.norm_type,
                n_freqs=self.n_freqs,
                is2d=True,
            )

        self.Queries = nn.ModuleList()
        self.Keys = nn.ModuleList()
        self.Values = nn.ModuleList()
        # Modules are created per head in Q, K, V order, as before.
        for _ in range(self.n_head):
            self.Queries.append(projection(self.hid_chan))
            self.Keys.append(projection(self.hid_chan))
            self.Values.append(projection(self.in_chan // self.n_head))
        self.attn_concat_proj = projection(self.in_chan)

    def forward(self, x: torch.Tensor):
        """Attend over axis 2 of [B, C, T, F] (or the last axis when dim == 4)."""
        if self.dim == 4:
            x = x.transpose(-2, -1).contiguous()
        batch_size, _, time, freq = x.size()
        residual = x

        # Heads are stacked along the batch axis: B' = n_head * B.
        queries = torch.cat([proj(x) for proj in self.Queries], dim=0)  # [B', E, T, F]
        keys = torch.cat([proj(x) for proj in self.Keys], dim=0)  # [B', E, T, F]
        values = torch.cat([proj(x) for proj in self.Values], dim=0)  # [B', C/n_head, T, F]

        queries = queries.transpose(1, 2).flatten(start_dim=2)  # [B', T, E*F]
        keys = keys.transpose(1, 2).flatten(start_dim=2)  # [B', T, E*F]
        values = values.transpose(1, 2)  # [B', T, C/n_head, F]
        values_shape = values.shape
        values = values.flatten(start_dim=2)  # [B', T, C*F/n_head]

        # Scaled dot-product attention between time frames.
        scale = queries.shape[-1] ** 0.5  # sqrt(E*F)
        scores = torch.matmul(queries, keys.transpose(1, 2)) / scale  # [B', T, T]
        weights = F.softmax(scores, dim=2)
        attended = torch.matmul(weights, values)  # [B', T, C*F/n_head]
        attended = attended.reshape(values_shape).transpose(1, 2)  # [B', C/n_head, T, F]

        # Undo the head stacking and concatenate heads along channels.
        head_dim = attended.shape[1]
        out = attended.view([self.n_head, batch_size, head_dim, time, freq])
        out = out.transpose(0, 1).contiguous()  # [B, n_head, C/n_head, T, F]
        out = out.view([batch_size, self.n_head * head_dim, time, freq])  # [B, C, T, F]
        out = self.attn_concat_proj(out) + residual

        if self.dim == 4:
            out = out.transpose(-2, -1).contiguous()
        return out
class Recurrent(nn.Module):
    """
    Iterative separator over a stack of band features [B, nband, N, T].

    Each iteration runs a band ("frequency") path followed by a frame ("time")
    path, each made of a UConvBlock, a MultiHeadSelfAttention2D and a
    LayerNormalization4D, with a residual connection. Every iteration after the
    first receives ``concat_block(mixture + x)``, where ``mixture`` is the
    (permuted) original input.
    """

    def __init__(
        self,
        out_channels=128,
        in_channels=512,
        nband=8,
        upsampling_depth=3,
        n_head=4,
        att_hid_chan=4,
        kernel_size: int = 8,
        stride: int = 1,
        _iter=4
    ):
        """
        :param out_channels: feature channels N
        :param in_channels: hidden width inside the UConvBlocks
        :param nband: number of bands
        :param upsampling_depth: pyramid depth of the UConvBlocks
        :param n_head: attention heads
        :param att_hid_chan: query/key channels per head
        :param kernel_size: accepted but not used by this module
        :param stride: accepted but not used by this module
        :param _iter: number of band/frame iterations
        """
        super().__init__()
        self.nband = nband

        def make_path():
            # [UConvBlock, attention, norm]; indexed as path[0], path[1], path[2].
            return nn.ModuleList([
                UConvBlock(out_channels, in_channels, upsampling_depth),
                MultiHeadSelfAttention2D(out_channels, 1, n_head=n_head, hid_chan=att_hid_chan, act_type="prelu", norm_type="LayerNormalization4D", dim=4),
                normalizations.get("LayerNormalization4D")((out_channels, 1)),
            ])

        self.freq_path = make_path()
        self.frame_path = make_path()
        self.iter = _iter
        self.concat_block = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, 1, 1, groups=out_channels), nn.PReLU()
        )

    def forward(self, x):
        """x: [B, nband, N, T] -> [B, nband, N, T]."""
        B, nband, N, T = x.shape
        x = x.permute(0, 2, 1, 3).contiguous()  # B, N, nband, T
        mixture = x.clone()
        for step in range(self.iter):
            block_in = x if step == 0 else self.concat_block(mixture + x)
            x = self.freq_time_process(block_in, B, nband, N, T)  # B, N, nband, T
        return x.permute(0, 2, 1, 3).contiguous()  # B, nband, N, T

    def freq_time_process(self, x, B, nband, N, T):
        """One band-path + frame-path pass over x of shape [B, N, nband, T]."""
        uconv_f, attn_f, norm_f = self.freq_path
        uconv_t, attn_t, norm_t = self.frame_path

        # Band path: the UConvBlock runs along the band axis for every frame.
        band_seq = x.permute(0, 3, 1, 2).contiguous().view(B * T, N, nband)
        band_feat = uconv_f(band_seq).view(B, T, N, nband)
        band_feat = band_feat.permute(0, 2, 1, 3).contiguous()  # B, N, T, nband
        band_feat = norm_f(attn_f(band_feat))  # B, N, T, nband
        band_feat = band_feat.permute(0, 1, 3, 2).contiguous()  # B, N, nband, T
        x = band_feat + x

        # Frame path: the UConvBlock runs along the time axis for every band.
        frame_seq = x.permute(0, 2, 1, 3).contiguous().view(B * nband, N, T)
        frame_feat = uconv_t(frame_seq).view(B, nband, N, T)
        frame_feat = frame_feat.permute(0, 2, 1, 3).contiguous()  # B, N, nband, T
        frame_feat = norm_t(attn_t(frame_feat))  # B, N, nband, T
        return frame_feat + x
class TIGER(BaseModel):
    """
    Band-split STFT-domain source separation model.

    The waveform is transformed with an STFT and the spectrum is split into
    sub-bands. Each band is normalized and projected to ``out_channels`` features,
    and the stack of band features is processed by the ``Recurrent`` separator.
    Per-band complex masks are predicted for ``num_sources`` sources, applied to
    the mixture spectrum, and the results are inverted with an iSTFT.
    """

    def __init__(
        self,
        out_channels=128,
        in_channels=512,
        num_blocks=16,
        upsampling_depth=4,
        att_n_head=4,
        att_hid_chan=4,
        att_kernel_size=8,
        att_stride=1,
        win=2048,
        stride=512,
        num_sources=2,
        sample_rate=44100,
    ):
        """
        :param out_channels: feature dimension per band
        :param in_channels: hidden width inside the separator's UConvBlocks
        :param num_blocks: number of separator iterations
        :param upsampling_depth: pyramid depth of each UConvBlock
        :param att_n_head: attention heads
        :param att_hid_chan: query/key channels per attention head
        :param att_kernel_size: forwarded to ``Recurrent``
        :param att_stride: forwarded to ``Recurrent``
        :param win: STFT window / FFT size
        :param stride: STFT hop length
        :param num_sources: number of separated outputs
        :param sample_rate: input sample rate in Hz
        :raises ValueError: if the band layout leaves a band with no STFT bins
        """
        super(TIGER, self).__init__(sample_rate=sample_rate)
        self.sample_rate = sample_rate
        self.win = win
        self.stride = stride
        self.group = self.win // 2
        self.enc_dim = self.win // 2 + 1  # number of one-sided STFT bins
        self.feature_dim = out_channels
        self.num_output = num_sources
        self.eps = torch.finfo(torch.float32).eps

        # Band layout (widths are floored to whole bins, so edges are approximate):
        # 0-1k (25 Hz bands), 1k-2k (100 Hz), 2k-4k (250 Hz), 4k-8k (500 Hz),
        # then one final band holding every remaining bin.
        nyquist = sample_rate / 2.

        def bins_for(hz):
            return int(np.floor(hz / nyquist * self.enc_dim))

        self.band_width = (
            [bins_for(25)] * 40
            + [bins_for(100)] * 10
            + [bins_for(250)] * 8
            + [bins_for(500)] * 8
        )
        # Built-in sum keeps every width a plain Python int.
        self.band_width.append(self.enc_dim - sum(self.band_width))
        if any(width <= 0 for width in self.band_width):
            raise ValueError(
                f"Invalid TIGER band layout {self.band_width} for win={win}, "
                f"sample_rate={sample_rate}: every sub-band needs at least one STFT bin"
            )
        self.nband = len(self.band_width)
        print(self.band_width)

        # Per-band normalization + bottleneck from (real, imag) * width to feature_dim.
        self.BN = nn.ModuleList([
            nn.Sequential(
                nn.GroupNorm(1, width * 2, self.eps),
                nn.Conv1d(width * 2, self.feature_dim, 1),
            )
            for width in self.band_width
        ])

        self.separator = Recurrent(self.feature_dim, in_channels, self.nband, upsampling_depth, att_n_head, att_hid_chan, att_kernel_size, att_stride, num_blocks)

        # Per-band mask heads: 2 (value, gate) x 2 (real, imag) x num_sources x width channels.
        self.mask = nn.ModuleList([
            nn.Sequential(
                nn.PReLU(),
                nn.Conv1d(self.feature_dim, width * 4 * num_sources, 1, groups=num_sources),
            )
            for width in self.band_width
        ])

    def pad_input(self, input, window, stride):
        """
        Zero-padding input according to window/stride size.

        :param input: 2-D tensor [batch, samples]
        :return: (padded input, number of samples appended at the end before the
                 ``stride``-long zero padding added on both sides)
        """
        batch_size, nsample = input.shape
        # pad the signals at the end for matching the window/stride size
        rest = window - (stride + nsample % window) % window
        if rest > 0:
            pad = torch.zeros(batch_size, rest).type(input.type())
            input = torch.cat([input, pad], 1)
        pad_aux = torch.zeros(batch_size, stride).type(input.type())
        input = torch.cat([pad_aux, input, pad_aux], 1)
        return input, rest

    def forward(self, input):
        """
        Separate ``input`` into ``num_sources`` waveforms.

        :param input: waveform of shape (T,), (B, T) or (B, C, T)
        :return: tensor of shape (B*C, num_sources, T); 1-D/2-D inputs are treated as C=1
        :raises ValueError: if ``input`` has more than 3 dimensions
        """
        if input.ndim == 1:
            input = input.unsqueeze(0).unsqueeze(1)
        elif input.ndim == 2:
            input = input.unsqueeze(1)
        if input.ndim != 3:
            raise ValueError(f"TIGER expects a 1-D, 2-D or 3-D waveform, got shape {tuple(input.shape)}")

        batch_size, nch, nsample = input.shape
        input = input.reshape(batch_size * nch, -1)
        # Same window for STFT and iSTFT, matching the input's device and type.
        window = torch.hann_window(self.win).to(input.device).type(input.type())

        # frequency-domain separation
        spec = torch.stft(input, n_fft=self.win, hop_length=self.stride,
                          window=window, return_complex=True)

        # concat real and imag, split to subbands
        spec_RI = torch.stack([spec.real, spec.imag], 1)  # B*nch, 2, F, T
        subband_spec_RI = []
        subband_spec = []
        band_idx = 0
        for width in self.band_width:
            subband_spec_RI.append(spec_RI[:, :, band_idx:band_idx + width].contiguous())
            subband_spec.append(spec[:, band_idx:band_idx + width])  # B*nch, BW, T
            band_idx += width

        # normalization and bottleneck
        subband_feature = torch.stack(
            [
                bn(band_ri.view(batch_size * nch, width * 2, -1))
                for bn, band_ri, width in zip(self.BN, subband_spec_RI, self.band_width)
            ],
            1,
        )  # B*nch, nband, N, T

        # separator
        sep_output = self.separator(subband_feature.view(batch_size * nch, self.nband, self.feature_dim, -1))  # B*nch, nband, N, T
        sep_output = sep_output.view(batch_size * nch, self.nband, self.feature_dim, -1)

        sep_subband_spec = []
        for i in range(self.nband):
            this_output = self.mask[i](sep_output[:, i]).view(batch_size * nch, 2, 2, self.num_output, self.band_width[i], -1)
            # Gated (GLU-style) mask: value * sigmoid(gate).
            this_mask = this_output[:, 0] * torch.sigmoid(this_output[:, 1])  # B*nch, 2, K, BW, T
            this_mask_real = this_mask[:, 0]  # B*nch, K, BW, T
            this_mask_imag = this_mask[:, 1]  # B*nch, K, BW, T
            # force the masks to sum to 1 + 0j over the sources
            this_mask_real_sum = this_mask_real.sum(1).unsqueeze(1)  # B*nch, 1, BW, T
            this_mask_imag_sum = this_mask_imag.sum(1).unsqueeze(1)  # B*nch, 1, BW, T
            this_mask_real = this_mask_real - (this_mask_real_sum - 1) / self.num_output
            this_mask_imag = this_mask_imag - this_mask_imag_sum / self.num_output
            # complex multiplication of mixture spectrum and mask
            mix_real = subband_spec[i].real.unsqueeze(1)
            mix_imag = subband_spec[i].imag.unsqueeze(1)
            est_spec_real = mix_real * this_mask_real - mix_imag * this_mask_imag  # B*nch, K, BW, T
            est_spec_imag = mix_real * this_mask_imag + mix_imag * this_mask_real  # B*nch, K, BW, T
            sep_subband_spec.append(torch.complex(est_spec_real, est_spec_imag))
        sep_subband_spec = torch.cat(sep_subband_spec, 2)

        output = torch.istft(sep_subband_spec.view(batch_size * nch * self.num_output, self.enc_dim, -1),
                             n_fft=self.win, hop_length=self.stride,
                             window=window, length=nsample)
        return output.view(batch_size * nch, self.num_output, -1)

    def get_model_args(self):
        """Return extra model arguments; currently only ``n_sample_rate``."""
        model_args = {"n_sample_rate": 2}
        return model_args