Audio-Separator / look2hear /layers /normalizations.py
fffiloni's picture
Migrated from GitHub
406f22d verified
import torch
from torch.autograd import Variable
import torch.nn as nn
import numpy as np
from typing import List
from torch.nn.modules.batchnorm import _BatchNorm
from collections.abc import Iterable
def norm(x, dims: List[int], EPS: float = 1e-8):
    """Standardize ``x`` to zero mean and unit variance over ``dims``.

    Args:
        x: Tensor to normalize.
        dims: Dimensions over which the mean and variance are computed.
        EPS: Constant added to the variance before the square root so the
            division stays finite when the variance is zero.

    Returns:
        A tensor shaped like ``x`` holding ``(x - mean) / sqrt(var + EPS)``.
    """
    centered = x - x.mean(dim=dims, keepdim=True)
    # Population (biased) variance: divide by N, not N - 1.
    variance = torch.var(x, dim=dims, keepdim=True, unbiased=False)
    return centered / torch.sqrt(variance + EPS)
def glob_norm(x, ESP: float = 1e-8):
    """Normalize ``x`` over every dimension except the first one.

    Args:
        x: Tensor to normalize; dimension 0 is left out of the statistics.
        ESP: Stabilizing constant forwarded to ``norm`` as its ``EPS``
            argument (the parameter is spelled ``ESP`` in this signature).

    Returns:
        The result of ``norm`` computed over dimensions ``1 .. x.dim() - 1``.
    """
    non_batch_dims: List[int] = torch.arange(1, x.dim()).tolist()
    return norm(x, non_batch_dims, ESP)
class MLayerNorm(nn.Module):
    """Base class for layer norms with a learnable per-channel gain and bias.

    Subclasses override :meth:`forward` to compute the normalized tensor and
    then pass it through :meth:`apply_gain_and_bias`.

    Args:
        channel_size: Number of channels, i.e. the length of ``gamma`` and
            ``beta``.
    """

    def __init__(self, channel_size):
        super().__init__()
        self.channel_size = channel_size
        self.gamma = nn.Parameter(torch.ones(channel_size), requires_grad=True)
        # The bias starts at zero so that a freshly constructed layer is a pure
        # normalization (identity affine). It was previously initialized to
        # ones, which shifted every output by +1. Parameter names and shapes
        # are unchanged, so existing checkpoints still load.
        self.beta = nn.Parameter(torch.zeros(channel_size), requires_grad=True)

    def apply_gain_and_bias(self, normed_x):
        """Scale by ``gamma`` and shift by ``beta`` along dimension 1.

        Assumes input of size `[batch, channel, *]`.
        """
        # Swap channels to the last axis so the 1-D parameters broadcast
        # against them, then swap back to the original layout.
        return (self.gamma * normed_x.transpose(1, -1) + self.beta).transpose(1, -1)

    def forward(self, x, EPS: float = 1e-8):
        """Normalize ``x``; must be implemented by subclasses.

        Raises:
            NotImplementedError: Always, because the base class defines no
                normalization of its own (it used to return ``None``
                silently).
        """
        raise NotImplementedError(
            "MLayerNorm is a base class; use a subclass that implements forward()."
        )
class GlobalLN(MLayerNorm):
    """Layer norm whose statistics span every non-batch dimension."""

    def forward(self, x, EPS: float = 1e-8):
        """Normalize ``x`` with ``glob_norm`` and apply the gain and bias.

        Args:
            x: Input tensor; ``apply_gain_and_bias`` treats dimension 1 as
                the channel axis.
            EPS: Stabilizing constant forwarded to ``glob_norm``.

        Returns:
            The normalized tensor after the per-channel affine transform.
        """
        return self.apply_gain_and_bias(glob_norm(x, EPS))
class ChannelLN(MLayerNorm):
    """Layer norm whose statistics are computed along dimension 1 only."""

    def forward(self, x, EPS: float = 1e-8):
        """Normalize ``x`` over dimension 1 and apply the gain and bias.

        Args:
            x: Input tensor; dimension 1 is the axis being normalized.
            EPS: Constant added to the variance before the square root.

        Returns:
            The normalized tensor after the per-channel affine transform.
        """
        channel_mean = x.mean(dim=1, keepdim=True)
        # Population (biased) variance along the channel axis.
        channel_var = x.var(dim=1, keepdim=True, unbiased=False)
        normalized = (x - channel_mean) / (channel_var + EPS).sqrt()
        return self.apply_gain_and_bias(normalized)
# class CumulateLN(MLayerNorm):
# def forward(self, x, EPS: float = 1e-8):
# batch, channels, time = x.size()
# cum_sum = torch.cumsum(x.sum(1, keepdim=True), dim=1)
# cum_pow_sum = torch.cumsum(x.pow(2).sum(1, keepdim=True), dim=1)
# cnt = torch.arange(
# start=channels, end=channels * (time + 1), step=channels, dtype=x.dtype, device=x.device
# ).view(1, 1, -1)
# cum_mean = cum_sum / cnt
# cum_var = (cum_pow_sum / cnt) - cum_mean.pow(2)
# return self.apply_gain_and_bias((x - cum_mean) / (cum_var + EPS).sqrt())
class BatchNorm(_BatchNorm):
    """Wrapper class for pytorch BatchNorm1D and BatchNorm2D.

    Accepts 2D, 3D and 4D inputs through a single class.
    """

    def _check_input_dim(self, input):
        """Validate the rank of ``input``.

        Args:
            input: Tensor about to be batch-normalized.

        Raises:
            ValueError: If ``input`` has fewer than 2 or more than 4
                dimensions.
        """
        ndim = input.dim()
        if ndim < 2 or ndim > 4:
            # The message now lists every accepted rank; it used to omit 2D
            # even though the check lets 2D inputs through.
            raise ValueError(
                f"expected 2D, 3D or 4D input (got {ndim}D input)"
            )
class CumulativeLayerNorm(nn.LayerNorm):
    """``nn.LayerNorm`` applied along dimension 1 of the input.

    Despite the class name, no statistics are accumulated here: the input is
    transposed so dimension 1 becomes the last axis, normalized there by
    ``nn.LayerNorm``, and transposed back.

    Args:
        dim: Size of dimension 1, passed to ``nn.LayerNorm`` as its
            normalized shape.
        elementwise_affine: Whether ``nn.LayerNorm`` learns a gain and bias.
    """

    def __init__(self, dim, elementwise_affine=True):
        super().__init__(dim, elementwise_affine=elementwise_affine)

    def forward(self, x):
        """Normalize ``x`` over dimension 1 and return it in its input layout.

        The original comments describe ``x`` as ``N x C x L``.
        """
        # N x C x L -> N x L x C so the channels sit on the normalized axis.
        channels_last = x.transpose(1, -1)
        normalized = super().forward(channels_last)
        # Back to N x C x L.
        return normalized.transpose(1, -1)
class CumulateLN(nn.Module):
    """Cumulative layer normalization along the time axis.

    At time step ``t`` the input is normalized with the mean and variance of
    every entry across all channels and all time steps ``0 .. t``, then scaled
    and shifted by a per-channel gain and bias.

    Args:
        dimension: Number of channels (size of dimension 1 of the input).
        eps: Constant added to the variance before the square root.
        trainable: If True the gain and bias are learnable parameters;
            otherwise they are fixed at one and zero.
    """

    def __init__(self, dimension, eps=1e-8, trainable=True):
        super().__init__()
        self.eps = eps
        gain = torch.ones(1, dimension, 1)
        bias = torch.zeros(1, dimension, 1)
        if trainable:
            self.gain = nn.Parameter(gain)
            self.bias = nn.Parameter(bias)
        else:
            # Non-persistent buffers replace the deprecated ``Variable``
            # wrapper: they follow ``module.to(...)`` but stay out of
            # ``state_dict``, so checkpoints keep the same set of keys.
            self.register_buffer("gain", gain, persistent=False)
            self.register_buffer("bias", bias, persistent=False)

    def forward(self, input):
        """Normalize ``input`` with running statistics.

        Args:
            input: Tensor of size ``(Batch, Freq, Time)``.

        Returns:
            A tensor of the same size as ``input``.
        """
        channel = input.size(1)
        time_step = input.size(2)

        # Per-step sums over the channel axis, accumulated along time.
        step_sum = input.sum(1)  # B, T
        step_pow_sum = input.pow(2).sum(1)  # B, T
        cum_sum = torch.cumsum(step_sum, dim=1)  # B, T
        cum_pow_sum = torch.cumsum(step_pow_sum, dim=1)  # B, T

        # Number of entries seen up to each step: channel, 2 * channel, ...
        # Built as integers on the input's device and only then cast, which
        # avoids the NumPy/CPU round trip on every forward pass.
        entry_cnt = torch.arange(1, time_step + 1, device=input.device) * channel
        entry_cnt = entry_cnt.to(input.dtype).view(1, -1)  # 1, T

        cum_mean = cum_sum / entry_cnt  # B, T
        # Algebraically E[x^2] - mean^2, kept in the original expanded form.
        cum_var = (cum_pow_sum - 2 * cum_mean * cum_sum) / entry_cnt + cum_mean.pow(
            2
        )  # B, T
        # Rounding can push the variance slightly below zero, which would turn
        # the square root into NaN; clamp it back to the valid range.
        cum_var = cum_var.clamp(min=0)
        cum_std = (cum_var + self.eps).sqrt()  # B, T

        # Broadcast the (B, 1, T) statistics over the channel axis.
        x = (input - cum_mean.unsqueeze(1)) / cum_std.unsqueeze(1)
        gain = self.gain.to(device=x.device, dtype=x.dtype)
        bias = self.bias.to(device=x.device, dtype=x.dtype)
        return x * gain + bias
class LayerNormalization4D(nn.Module):
    """Layer norm for 4-D tensors with gain and bias of shape ``(1, C, 1, F)``.

    Statistics are taken over dimensions 1 and 3 when the second entry of
    ``input_dimension`` is greater than one, and over dimension 1 alone
    otherwise.

    Args:
        input_dimension: Two sizes ``(C, F)`` giving the extent of the
            parameters along dimensions 1 and 3.
        eps: Constant added to the variance before the square root.
    """

    def __init__(self, input_dimension: Iterable, eps: float = 1e-5):
        super().__init__()
        assert len(input_dimension) == 2
        shape = (1, input_dimension[0], 1, input_dimension[1])
        # A trailing size of one leaves nothing to average over on dim 3.
        self.dim = (1, 3) if shape[-1] > 1 else (1,)
        self.gamma = nn.Parameter(torch.ones(shape, dtype=torch.float32))
        self.beta = nn.Parameter(torch.zeros(shape, dtype=torch.float32))
        self.eps = eps

    def forward(self, x: torch.Tensor):
        """Normalize ``x`` over ``self.dim`` and apply the gain and bias.

        ``x`` must broadcast against the ``(1, C, 1, F)`` parameters.
        """
        mean = x.mean(dim=self.dim, keepdim=True)
        variance = x.var(dim=self.dim, unbiased=False, keepdim=True)
        normalized = (x - mean) / torch.sqrt(variance + self.eps)
        return normalized * self.gamma + self.beta
# Aliases: short module-level names, so `get` can resolve them from a string
# through its `globals()` lookup.
gLN = GlobalLN  # statistics over every non-batch dimension
cLN = CumulateLN  # running statistics accumulated along the time axis
LN = CumulativeLayerNorm  # nn.LayerNorm over dimension 1 only, despite the class name
bN = BatchNorm  # batch norm accepting 2D to 4D inputs
LN4D = LayerNormalization4D  # layer norm with (1, C, 1, F) parameters
def get(identifier):
    """Resolve a normalization identifier to a class or callable.

    Args:
        identifier (str or Callable or None): ``None`` and callables are
            returned unchanged; a string is looked up among this module's
            global names (for example the aliases ``gLN``, ``cLN``, ``LN``).

    Returns:
        The resolved object, or None when ``identifier`` is None.

    Raises:
        ValueError: If ``identifier`` is a string that names no module
            global, or is neither None, callable, nor a string.
    """
    if identifier is None or callable(identifier):
        return identifier

    # Only strings can be looked up; anything else falls through to the error.
    resolved = globals().get(identifier) if isinstance(identifier, str) else None
    if resolved is None:
        raise ValueError(
            "Could not interpret normalization identifier: " + str(identifier)
        )
    return resolved