import torch
import torch.nn as nn
import torch.nn.functional as F
from . import normalizations, activations
class _Chop1d(nn.Module):
"""To ensure the output length is the same as the input."""
def __init__(self, chop_size):
super().__init__()
self.chop_size = chop_size
def forward(self, x):
return x[..., : -self.chop_size].contiguous()
class Conv1DBlock(nn.Module):
def __init__(
self,
in_chan,
hid_chan,
skip_out_chan,
kernel_size,
padding,
dilation,
norm_type="gLN",
causal=False,
):
super(Conv1DBlock, self).__init__()
self.skip_out_chan = skip_out_chan
conv_norm = normalizations.get(norm_type)
in_conv1d = nn.Conv1d(in_chan, hid_chan, 1)
depth_conv1d = nn.Conv1d(
hid_chan,
hid_chan,
kernel_size,
padding=padding,
dilation=dilation,
groups=hid_chan,
)
if causal:
depth_conv1d = nn.Sequential(depth_conv1d, _Chop1d(padding))
self.shared_block = nn.Sequential(
in_conv1d,
nn.PReLU(),
conv_norm(hid_chan),
depth_conv1d,
nn.PReLU(),
conv_norm(hid_chan),
)
self.res_conv = nn.Conv1d(hid_chan, in_chan, 1)
if skip_out_chan:
self.skip_conv = nn.Conv1d(hid_chan, skip_out_chan, 1)
def forward(self, x):
r"""Input shape $(batch, feats, seq)$."""
shared_out = self.shared_block(x)
res_out = self.res_conv(shared_out)
if not self.skip_out_chan:
return res_out
skip_out = self.skip_conv(shared_out)
return res_out, skip_out
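# Usage sketch for Conv1DBlock (illustrative values only; "same" padding for the
# non-causal case is padding = dilation * (kernel_size - 1) // 2):
#   block = Conv1DBlock(in_chan=128, hid_chan=512, skip_out_chan=128,
#                       kernel_size=3, padding=2, dilation=2)
#   res, skip = block(torch.randn(2, 128, 3200))  # both outputs: (2, 128, 3200)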
class ConvNormAct(nn.Module):
"""
This class defines the convolution layer with normalization and a PReLU
activation
"""
def __init__(
self,
in_chan,
out_chan,
kernel_size,
stride=1,
groups=1,
dilation=1,
padding=0,
norm_type="gLN",
act_type="prelu",
):
super(ConvNormAct, self).__init__()
self.conv = nn.Conv1d(
in_chan,
out_chan,
kernel_size,
stride=stride,
dilation=dilation,
padding=padding,
bias=True,
groups=groups,
)
self.norm = normalizations.get(norm_type)(out_chan)
self.act = activations.get(act_type)()
def forward(self, x):
output = self.conv(x)
output = self.norm(output)
return self.act(output)
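# Usage sketch for ConvNormAct (illustrative shapes only):
#   conv = ConvNormAct(in_chan=128, out_chan=512, kernel_size=5, stride=2, padding=2)
#   y = conv(torch.randn(2, 128, 3200))  # -> (2, 512, 1600)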
class ConvNorm(nn.Module):
def __init__(
self,
in_chan,
out_chan,
kernel_size,
stride=1,
groups=1,
dilation=1,
padding=0,
norm_type="gLN",
):
super(ConvNorm, self).__init__()
self.conv = nn.Conv1d(
in_chan,
out_chan,
kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
bias=True,
groups=groups,
)
self.norm = normalizations.get(norm_type)(out_chan)
def forward(self, x):
output = self.conv(x)
return self.norm(output)
class NormAct(nn.Module):
"""
This class defines a normalization and PReLU activation
"""
def __init__(
self, out_chan, norm_type="gLN", act_type="prelu",
):
"""
:param nOut: number of output channels
"""
super(NormAct, self).__init__()
self.norm = normalizations.get(norm_type)(out_chan)
self.act = activations.get(act_type)()
def forward(self, input):
output = self.norm(input)
return self.act(output)
class Video1DConv(nn.Module):
"""
video part 1-D Conv Block
in_chan: video Encoder output channels
out_chan: dconv channels
kernel_size: the depthwise conv kernel size
dilation: the depthwise conv dilation
residual: Whether to use residual connection
skip_con: Whether to use skip connection
first_block: first block, not residual
"""
def __init__(
self,
in_chan,
out_chan,
kernel_size,
dilation=1,
residual=True,
skip_con=True,
first_block=True,
):
super(Video1DConv, self).__init__()
self.first_block = first_block
# first block, not residual
self.residual = residual and not first_block
self.bn = nn.BatchNorm1d(in_chan) if not first_block else None
self.relu = nn.ReLU() if not first_block else None
self.dconv = nn.Conv1d(
in_chan,
in_chan,
kernel_size,
groups=in_chan,
dilation=dilation,
padding=(dilation * (kernel_size - 1)) // 2,
bias=True,
)
self.bconv = nn.Conv1d(in_chan, out_chan, 1)
self.sconv = nn.Conv1d(in_chan, out_chan, 1)
self.skip_con = skip_con
def forward(self, x):
"""
x: [B, N, T]
out: [B, N, T]
"""
if not self.first_block:
y = self.bn(self.relu(x))
y = self.dconv(y)
else:
y = self.dconv(x)
        # skip connection: return both the skip path and the (optionally residual) main path
        if self.skip_con:
            skip = self.sconv(y)
            if self.residual:
                y = y + x
            return skip, y
        # no skip connection: 1x1 bottleneck conv, optionally with a residual connection
        y = self.bconv(y)
        if self.residual:
            y = y + x
        return y
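# Usage sketch for Video1DConv (illustrative; the depthwise conv keeps the sequence
# length thanks to the "same" padding computed in __init__):
#   vconv = Video1DConv(in_chan=512, out_chan=512, kernel_size=3,
#                       skip_con=True, first_block=False)
#   skip, y = vconv(torch.randn(2, 512, 100))  # skip: (2, 512, 100), y: (2, 512, 100)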
class Concat(nn.Module):
def __init__(self, ain_chan, vin_chan, out_chan):
super(Concat, self).__init__()
self.ain_chan = ain_chan
self.vin_chan = vin_chan
# project
self.conv1d = nn.Sequential(
nn.Conv1d(ain_chan + vin_chan, out_chan, 1), nn.PReLU()
)
def forward(self, a, v):
# up-sample video features
v = torch.nn.functional.interpolate(v, size=a.size(-1))
# concat: n x (A+V) x Ta
y = torch.cat([a, v], dim=1)
# conv1d
return self.conv1d(y)
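# Usage sketch for Concat (illustrative: audio features at frame rate Ta, video features
# at a lower rate; the video stream is upsampled to Ta before concatenation):
#   fuse = Concat(ain_chan=128, vin_chan=512, out_chan=128)
#   a, v = torch.randn(2, 128, 3200), torch.randn(2, 512, 100)
#   y = fuse(a, v)  # -> (2, 128, 3200)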
class FRCNNBlock(nn.Module):
def __init__(
self,
in_chan=128,
out_chan=512,
upsampling_depth=4,
norm_type="gLN",
act_type="prelu",
):
super().__init__()
self.proj_1x1 = ConvNormAct(
in_chan,
out_chan,
kernel_size=1,
stride=1,
groups=1,
dilation=1,
padding=0,
norm_type=norm_type,
act_type=act_type,
)
self.depth = upsampling_depth
self.spp_dw = nn.ModuleList([])
self.spp_dw.append(
ConvNorm(
out_chan,
out_chan,
kernel_size=5,
stride=1,
groups=out_chan,
dilation=1,
padding=((5 - 1) // 2) * 1,
norm_type=norm_type,
)
)
# ----------Down Sample Layer----------
for i in range(1, upsampling_depth):
self.spp_dw.append(
ConvNorm(
out_chan,
out_chan,
kernel_size=5,
stride=2,
groups=out_chan,
dilation=1,
padding=((5 - 1) // 2) * 1,
norm_type=norm_type,
)
)
# ----------Fusion Layer----------
self.fuse_layers = nn.ModuleList([])
for i in range(upsampling_depth):
fuse_layer = nn.ModuleList([])
for j in range(upsampling_depth):
if i == j:
fuse_layer.append(None)
elif j - i == 1:
fuse_layer.append(None)
elif i - j == 1:
fuse_layer.append(
ConvNorm(
out_chan,
out_chan,
kernel_size=5,
stride=2,
groups=out_chan,
dilation=1,
padding=((5 - 1) // 2) * 1,
norm_type=norm_type,
)
)
self.fuse_layers.append(fuse_layer)
self.concat_layer = nn.ModuleList([])
# ----------Concat Layer----------
for i in range(upsampling_depth):
if i == 0 or i == upsampling_depth - 1:
self.concat_layer.append(
ConvNormAct(
out_chan * 2,
out_chan,
1,
1,
norm_type=norm_type,
act_type=act_type,
)
)
else:
self.concat_layer.append(
ConvNormAct(
out_chan * 3,
out_chan,
1,
1,
norm_type=norm_type,
act_type=act_type,
)
)
self.last_layer = nn.Sequential(
ConvNormAct(
out_chan * upsampling_depth,
out_chan,
1,
1,
norm_type=norm_type,
act_type=act_type,
)
)
self.res_conv = nn.Conv1d(out_chan, in_chan, 1)
# ----------parameters-------------
self.depth = upsampling_depth
def forward(self, x):
"""
:param x: input feature map
:return: transformed feature map
"""
residual = x.clone()
        # Project the input to the block's internal channel dimension with a 1x1 conv
output1 = self.proj_1x1(x)
output = [self.spp_dw[0](output1)]
for k in range(1, self.depth):
out_k = self.spp_dw[k](output[-1])
output.append(out_k)
x_fuse = []
for i in range(len(self.fuse_layers)):
wav_length = output[i].shape[-1]
y = torch.cat(
(
self.fuse_layers[i][0](output[i - 1])
if i - 1 >= 0
else torch.Tensor().to(output1.device),
output[i],
F.interpolate(output[i + 1], size=wav_length, mode="nearest")
if i + 1 < self.depth
else torch.Tensor().to(output1.device),
),
dim=1,
)
x_fuse.append(self.concat_layer[i](y))
wav_length = output[0].shape[-1]
for i in range(1, len(x_fuse)):
x_fuse[i] = F.interpolate(x_fuse[i], size=wav_length, mode="nearest")
concat = self.last_layer(torch.cat(x_fuse, dim=1))
expanded = self.res_conv(concat)
return expanded + residual
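# Usage sketch for FRCNNBlock (illustrative; with upsampling_depth=4 the block builds a
# 4-level pyramid at strides 1, 2, 4, 8, fuses neighbouring scales, and returns a tensor
# with the input shape thanks to the residual connection):
#   frcnn = FRCNNBlock(in_chan=128, out_chan=512, upsampling_depth=4)
#   y = frcnn(torch.randn(2, 128, 3200))  # -> (2, 128, 3200)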
class Bottomup(nn.Module):
def __init__(
self,
in_chan=128,
out_chan=512,
upsampling_depth=4,
norm_type="gLN",
act_type="prelu",
):
super().__init__()
self.proj_1x1 = ConvNormAct(
in_chan,
out_chan,
kernel_size=1,
stride=1,
groups=1,
dilation=1,
padding=0,
norm_type=norm_type,
act_type=act_type,
)
self.depth = upsampling_depth
self.spp_dw = nn.ModuleList([])
self.spp_dw.append(
ConvNorm(
out_chan,
out_chan,
kernel_size=5,
stride=1,
groups=out_chan,
dilation=1,
padding=((5 - 1) // 2) * 1,
norm_type=norm_type,
)
)
# ----------Down Sample Layer----------
for i in range(1, upsampling_depth):
self.spp_dw.append(
ConvNorm(
out_chan,
out_chan,
kernel_size=5,
stride=2,
groups=out_chan,
dilation=1,
padding=((5 - 1) // 2) * 1,
norm_type=norm_type,
)
)
def forward(self, x):
residual = x.clone()
        # Project the input to the block's internal channel dimension with a 1x1 conv
output1 = self.proj_1x1(x)
output = [self.spp_dw[0](output1)]
for k in range(1, self.depth):
out_k = self.spp_dw[k](output[-1])
output.append(out_k)
return residual, output[-1], output
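# Usage sketch for Bottomup (illustrative; returns the residual input, the coarsest scale,
# and the full multi-scale pyramid):
#   enc = Bottomup(in_chan=128, out_chan=512, upsampling_depth=4)
#   residual, coarsest, pyramid = enc(torch.randn(2, 128, 3200))
#   # residual: (2, 128, 3200); coarsest: (2, 512, 400)
#   # pyramid lengths: 3200, 1600, 800, 400 (all with 512 channels)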
class BottomupTCN(nn.Module):
def __init__(
self,
in_chan=128,
out_chan=512,
upsampling_depth=4,
norm_type="gLN",
act_type="prelu",
):
super().__init__()
self.proj_1x1 = ConvNormAct(
in_chan,
out_chan,
kernel_size=1,
stride=1,
groups=1,
dilation=1,
padding=0,
norm_type=norm_type,
act_type=act_type,
)
self.depth = upsampling_depth
self.spp_dw = nn.ModuleList([])
self.spp_dw.append(
Video1DConv(out_chan, out_chan, 3, skip_con=False, first_block=True)
)
# ----------Down Sample Layer----------
for i in range(1, upsampling_depth):
self.spp_dw.append(
Video1DConv(out_chan, out_chan, 3, skip_con=False, first_block=False)
)
def forward(self, x):
residual = x.clone()
        # Project the input to the block's internal channel dimension with a 1x1 conv
output1 = self.proj_1x1(x)
output = [self.spp_dw[0](output1)]
for k in range(1, self.depth):
out_k = self.spp_dw[k](output[-1])
output.append(out_k)
return residual, output[-1], output
class Bottomup_Concat_Topdown(nn.Module):
def __init__(
self,
in_chan=128,
out_chan=512,
upsampling_depth=4,
norm_type="gLN",
act_type="prelu",
):
super().__init__()
# ----------Fusion Layer----------
self.fuse_layers = nn.ModuleList([])
for i in range(upsampling_depth):
fuse_layer = nn.ModuleList([])
for j in range(upsampling_depth):
if i == j:
fuse_layer.append(None)
elif j - i == 1:
fuse_layer.append(None)
elif i - j == 1:
fuse_layer.append(
ConvNorm(
out_chan,
out_chan,
kernel_size=5,
stride=2,
groups=out_chan,
dilation=1,
padding=((5 - 1) // 2) * 1,
norm_type=norm_type,
)
)
self.fuse_layers.append(fuse_layer)
self.concat_layer = nn.ModuleList([])
# ----------Concat Layer----------
for i in range(upsampling_depth):
if i == 0 or i == upsampling_depth - 1:
self.concat_layer.append(
ConvNormAct(
out_chan * 3,
out_chan,
1,
1,
norm_type=norm_type,
act_type=act_type,
)
)
else:
self.concat_layer.append(
ConvNormAct(
out_chan * 4,
out_chan,
1,
1,
norm_type=norm_type,
act_type=act_type,
)
)
self.last_layer = nn.Sequential(
ConvNormAct(
out_chan * upsampling_depth,
out_chan,
1,
1,
norm_type=norm_type,
act_type=act_type,
)
)
self.res_conv = nn.Conv1d(out_chan, in_chan, 1)
# ----------parameters-------------
self.depth = upsampling_depth
def forward(self, residual, bottomup, topdown):
x_fuse = []
for i in range(len(self.fuse_layers)):
wav_length = bottomup[i].shape[-1]
y = torch.cat(
(
self.fuse_layers[i][0](bottomup[i - 1])
if i - 1 >= 0
else torch.Tensor().to(bottomup[i].device),
bottomup[i],
F.interpolate(bottomup[i + 1], size=wav_length, mode="nearest")
if i + 1 < self.depth
else torch.Tensor().to(bottomup[i].device),
F.interpolate(topdown, size=wav_length, mode="nearest"),
),
dim=1,
)
x_fuse.append(self.concat_layer[i](y))
wav_length = bottomup[0].shape[-1]
for i in range(1, len(x_fuse)):
x_fuse[i] = F.interpolate(x_fuse[i], size=wav_length, mode="nearest")
concat = self.last_layer(torch.cat(x_fuse, dim=1))
expanded = self.res_conv(concat)
return expanded + residual
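# Usage sketch combining Bottomup and Bottomup_Concat_Topdown (illustrative; here the
# coarsest bottom-up scale doubles as the top-down signal, but any tensor with out_chan
# channels would do since it is interpolated to every scale):
#   enc = Bottomup(in_chan=128, out_chan=512, upsampling_depth=4)
#   fusion = Bottomup_Concat_Topdown(in_chan=128, out_chan=512, upsampling_depth=4)
#   residual, coarsest, pyramid = enc(torch.randn(2, 128, 3200))
#   y = fusion(residual, pyramid, coarsest)  # -> (2, 128, 3200)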
class Bottomup_Concat_Topdown_TCN(nn.Module):
def __init__(
self,
in_chan=128,
out_chan=512,
upsampling_depth=4,
norm_type="gLN",
act_type="prelu",
):
super().__init__()
# ----------Fusion Layer----------
self.fuse_layers = nn.ModuleList([])
for i in range(upsampling_depth):
fuse_layer = nn.ModuleList([])
for j in range(upsampling_depth):
if i == j:
fuse_layer.append(None)
elif j - i == 1:
fuse_layer.append(None)
elif i - j == 1:
fuse_layer.append(None)
self.fuse_layers.append(fuse_layer)
self.concat_layer = nn.ModuleList([])
# ----------Concat Layer----------
for i in range(upsampling_depth):
if i == 0 or i == upsampling_depth - 1:
self.concat_layer.append(
ConvNormAct(
out_chan * 3,
out_chan,
1,
1,
norm_type=norm_type,
act_type=act_type,
)
)
else:
self.concat_layer.append(
ConvNormAct(
out_chan * 4,
out_chan,
1,
1,
norm_type=norm_type,
act_type=act_type,
)
)
self.last_layer = nn.Sequential(
ConvNormAct(
out_chan * upsampling_depth,
out_chan,
1,
1,
norm_type=norm_type,
act_type=act_type,
)
)
self.res_conv = nn.Conv1d(out_chan, in_chan, 1)
# ----------parameters-------------
self.depth = upsampling_depth
def forward(self, residual, bottomup, topdown):
x_fuse = []
for i in range(len(self.fuse_layers)):
wav_length = bottomup[i].shape[-1]
y = torch.cat(
(
bottomup[i - 1]
if i - 1 >= 0
else torch.Tensor().to(bottomup[i].device),
bottomup[i],
bottomup[i + 1]
if i + 1 < self.depth
else torch.Tensor().to(bottomup[i].device),
F.interpolate(topdown, size=wav_length, mode="nearest"),
),
dim=1,
)
x_fuse.append(self.concat_layer[i](y))
concat = self.last_layer(torch.cat(x_fuse, dim=1))
expanded = self.res_conv(concat)
return expanded + residual
class FRCNNBlockTCN(nn.Module):
def __init__(
self,
in_chan=128,
out_chan=512,
upsampling_depth=4,
norm_type="gLN",
act_type="prelu",
):
super().__init__()
self.proj_1x1 = ConvNormAct(
in_chan,
out_chan,
kernel_size=1,
stride=1,
groups=1,
dilation=1,
padding=0,
norm_type=norm_type,
act_type=act_type,
)
self.depth = upsampling_depth
self.spp_dw = nn.ModuleList([])
self.spp_dw.append(
Video1DConv(out_chan, out_chan, 3, skip_con=False, first_block=True)
)
# ----------Down Sample Layer----------
for i in range(1, upsampling_depth):
self.spp_dw.append(
Video1DConv(out_chan, out_chan, 3, skip_con=False, first_block=False)
)
# ----------Fusion Layer----------
self.fuse_layers = nn.ModuleList([])
for i in range(upsampling_depth):
fuse_layer = nn.ModuleList([])
for j in range(upsampling_depth):
if i == j:
fuse_layer.append(None)
elif j - i == 1:
fuse_layer.append(None)
elif i - j == 1:
fuse_layer.append(None)
self.fuse_layers.append(fuse_layer)
self.concat_layer = nn.ModuleList([])
# ----------Concat Layer----------
for i in range(upsampling_depth):
if i == 0 or i == upsampling_depth - 1:
self.concat_layer.append(
ConvNormAct(
out_chan * 2,
out_chan,
1,
1,
norm_type=norm_type,
act_type=act_type,
)
)
else:
self.concat_layer.append(
ConvNormAct(
out_chan * 3,
out_chan,
1,
1,
norm_type=norm_type,
act_type=act_type,
)
)
self.last_layer = nn.Sequential(
ConvNormAct(
out_chan * upsampling_depth,
out_chan,
1,
1,
norm_type=norm_type,
act_type=act_type,
)
)
self.res_conv = nn.Conv1d(out_chan, in_chan, 1)
# ----------parameters-------------
self.depth = upsampling_depth
def forward(self, x):
"""
:param x: input feature map
:return: transformed feature map
"""
residual = x.clone()
        # Project the input to the block's internal channel dimension with a 1x1 conv
output1 = self.proj_1x1(x)
output = [self.spp_dw[0](output1)]
for k in range(1, self.depth):
out_k = self.spp_dw[k](output[-1])
output.append(out_k)
x_fuse = []
for i in range(len(self.fuse_layers)):
wav_length = output[i].shape[-1]
y = torch.cat(
(
output[i - 1] if i - 1 >= 0 else torch.Tensor().to(output1.device),
output[i],
output[i + 1]
if i + 1 < self.depth
else torch.Tensor().to(output1.device),
),
dim=1,
)
x_fuse.append(self.concat_layer[i](y))
concat = self.last_layer(torch.cat(x_fuse, dim=1))
expanded = self.res_conv(concat)
return expanded + residual
class TAC(nn.Module):
"""Transform-Average-Concatenate inter-microphone-channel permutation invariant communication block [1].
Args:
input_dim (int): Number of features of input representation.
hidden_dim (int, optional): size of hidden layers in TAC operations.
activation (str, optional): type of activation used. See asteroid.masknn.activations.
norm_type (str, optional): type of normalization layer used. See asteroid.masknn.norms.
.. note:: Supports inputs of shape :math:`(batch, mic\_channels, features, chunk\_size, n\_chunks)`
        as in FaSNet-TAC. The operations are applied for each element in ``chunk_size`` and ``n_chunks``.
Output is of same shape as input.
References
[1] : Luo, Yi, et al. "End-to-end microphone permutation and number invariant multi-channel
speech separation." ICASSP 2020.
"""
def __init__(self, input_dim, hidden_dim=384, activation="prelu", norm_type="gLN"):
super().__init__()
self.hidden_dim = hidden_dim
self.input_tf = nn.Sequential(
nn.Linear(input_dim, hidden_dim), activations.get(activation)()
)
self.avg_tf = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim), activations.get(activation)()
)
self.concat_tf = nn.Sequential(
nn.Linear(2 * hidden_dim, input_dim), activations.get(activation)()
)
self.norm = normalizations.get(norm_type)(input_dim)
def forward(self, x, valid_mics=None):
"""
Args:
x: (:class:`torch.Tensor`): Input multi-channel DPRNN features.
Shape: :math:`(batch, mic\_channels, features, chunk\_size, n\_chunks)`.
valid_mics: (:class:`torch.LongTensor`): tensor containing effective number of microphones on each batch.
Batches can be composed of examples coming from arrays with a different
number of microphones and thus the ``mic_channels`` dimension is padded.
E.g. torch.tensor([4, 3]) means first example has 4 channels and the second 3.
                Shape: :math:`(batch)`.
Returns:
output (:class:`torch.Tensor`): features for each mic_channel after TAC inter-channel processing.
Shape :math:`(batch, mic\_channels, features, chunk\_size, n\_chunks)`.
"""
# Input is 5D because it is multi-channel DPRNN. DPRNN single channel is 4D.
batch_size, nmics, channels, chunk_size, n_chunks = x.size()
if valid_mics is None:
valid_mics = torch.LongTensor([nmics] * batch_size)
# First operation: transform the input for each frame and independently on each mic channel.
output = self.input_tf(
x.permute(0, 3, 4, 1, 2).reshape(
batch_size * nmics * chunk_size * n_chunks, channels
)
).reshape(batch_size, chunk_size, n_chunks, nmics, self.hidden_dim)
# Mean pooling across channels
if valid_mics.max() == 0:
# Fixed geometry array
mics_mean = output.mean(1)
else:
# Only consider valid channels in each batch element: each example can have different number of microphones.
mics_mean = [
output[b, :, :, : valid_mics[b]].mean(2).unsqueeze(0)
for b in range(batch_size)
] # 1, dim1*dim2, H
mics_mean = torch.cat(mics_mean, 0) # B*dim1*dim2, H
# The average is processed by a non-linear transform
mics_mean = self.avg_tf(
mics_mean.reshape(batch_size * chunk_size * n_chunks, self.hidden_dim)
)
mics_mean = (
mics_mean.reshape(batch_size, chunk_size, n_chunks, self.hidden_dim)
.unsqueeze(3)
.expand_as(output)
)
# Concatenate the transformed average in each channel with the original feats and
# project back to same number of features
output = torch.cat([output, mics_mean], -1)
output = self.concat_tf(
output.reshape(batch_size * chunk_size * n_chunks * nmics, -1)
).reshape(batch_size, chunk_size, n_chunks, nmics, -1)
output = self.norm(
output.permute(0, 3, 4, 1, 2).reshape(
batch_size * nmics, -1, chunk_size, n_chunks
)
).reshape(batch_size, nmics, -1, chunk_size, n_chunks)
output += x
return output
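# Usage sketch for TAC (illustrative; the input is 5-D multi-channel DPRNN features and
# the optional valid_mics marks how many channels of each example are real, the rest
# being padding):
#   tac = TAC(input_dim=64, hidden_dim=128)
#   x = torch.randn(2, 4, 64, 50, 10)  # (batch, mics, feats, chunk_size, n_chunks)
#   y = tac(x, valid_mics=torch.tensor([4, 3]))  # same shape as x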