import math
from typing import Optional, Tuple, Union

import torch
from torch.nn import functional as F

from funasr_detach.models.transformer.utils.nets_utils import make_non_pad_mask

# Floor applied to the variance before sqrt, to avoid NaNs from sqrt of ~0.
VAR2STD_EPSILON = 1e-12
class StatisticPooling(torch.nn.Module):
    """Pool frame-level features into utterance-level statistics by
    concatenating the mean and standard deviation over the pooling dims."""

    def __init__(self, pooling_dim: Union[int, Tuple] = 2, eps: float = 1e-12):
        super().__init__()
        if isinstance(pooling_dim, int):
            pooling_dim = (pooling_dim,)
        self.pooling_dim = pooling_dim
        self.eps = eps

    def forward(
        self, xs_pad: torch.Tensor, ilens: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        # xs_pad: (Batch, Channel, Time, Frequency)
        if ilens is None:
            masks = torch.ones_like(xs_pad)
        else:
            masks = make_non_pad_mask(ilens, xs_pad, length_dim=2).to(xs_pad)
        # Masked mean: padded frames are excluded from numerator and denominator.
        mean = torch.sum(xs_pad * masks, dim=self.pooling_dim, keepdim=True) / torch.sum(
            masks, dim=self.pooling_dim, keepdim=True
        )
        # Masked (biased) variance over the same dims; the mask keeps padded
        # positions, where (0 - mean)^2 is nonzero, out of the sum.
        squared_difference = torch.pow(xs_pad - mean, 2.0)
        variance = torch.sum(
            squared_difference * masks, dim=self.pooling_dim, keepdim=True
        ) / torch.sum(masks, dim=self.pooling_dim, keepdim=True)
        # Squeeze in reverse order so earlier dim indices remain valid.
        for i in reversed(self.pooling_dim):
            mean, variance = torch.squeeze(mean, dim=i), torch.squeeze(variance, dim=i)
        # Floor the variance at eps so sqrt stays numerically stable.
        mask = torch.less_equal(variance, self.eps).float()
        variance = (1.0 - mask) * variance + mask * self.eps
        stddev = torch.sqrt(variance)
        return torch.cat([mean, stddev], dim=1)

    def convert_tf2torch(self, var_dict_tf, var_dict_torch):
        # Pooling has no trainable parameters, so there is nothing to map.
        return {}
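

def _demo_statistic_pooling_module() -> torch.Tensor:
    # Usage sketch (hypothetical helper, not part of the original file;
    # shapes are assumptions): pool a random (Batch, Channel, Time, Frequency)
    # batch over the time and frequency dimensions.
    pooling = StatisticPooling(pooling_dim=(2, 3))
    xs = torch.randn(2, 16, 100, 80)
    ilens = torch.tensor([100, 73])  # valid frame counts per utterance
    # Output is (2, 32): mean and stddev concatenated along the channel dim.
    return pooling(xs, ilens)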
def statistic_pooling(
    xs_pad: torch.Tensor,
    ilens: Optional[torch.Tensor] = None,
    pooling_dim: Tuple = (2, 3),
) -> torch.Tensor:
    """Functional form of StatisticPooling: returns [mean; stddev] on dim 1."""
    # xs_pad: (Batch, Channel, Time, Frequency)
    if ilens is None:
        seq_mask = torch.ones_like(xs_pad)
    else:
        seq_mask = make_non_pad_mask(ilens, xs_pad, length_dim=2).to(xs_pad)
    # Masked mean and (biased) variance over the pooling dims.
    mean = torch.sum(xs_pad * seq_mask, dim=pooling_dim, keepdim=True) / torch.sum(
        seq_mask, dim=pooling_dim, keepdim=True
    )
    squared_difference = torch.pow(xs_pad - mean, 2.0)
    variance = torch.sum(
        squared_difference * seq_mask, dim=pooling_dim, keepdim=True
    ) / torch.sum(seq_mask, dim=pooling_dim, keepdim=True)
    # Squeeze in reverse order so earlier dim indices remain valid.
    for i in reversed(pooling_dim):
        mean, variance = torch.squeeze(mean, dim=i), torch.squeeze(variance, dim=i)
    # Floor the variance before sqrt for numerical stability.
    value_mask = torch.less_equal(variance, VAR2STD_EPSILON).float()
    variance = (1.0 - value_mask) * variance + value_mask * VAR2STD_EPSILON
    stddev = torch.sqrt(variance)
    return torch.cat([mean, stddev], dim=1)
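

def _demo_statistic_pooling() -> torch.Tensor:
    # Usage sketch (hypothetical helper, not part of the original file): the
    # functional form without length information, so every frame is counted.
    xs = torch.randn(2, 16, 100, 80)  # (Batch, Channel, Time, Frequency)
    return statistic_pooling(xs, pooling_dim=(2, 3))  # -> (2, 32)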
def windowed_statistic_pooling(
    xs_pad: torch.Tensor,
    ilens: Optional[torch.Tensor] = None,
    pooling_dim: Tuple = (2, 3),
    pooling_size: int = 20,
    pooling_stride: int = 1,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Apply statistic_pooling over a sliding window along the time axis."""
    # xs_pad: (Batch, Channel, Time, Frequency)
    tt = xs_pad.shape[2]
    num_chunk = int(math.ceil(tt / pooling_stride))
    # Replicate-pad the time axis so every window is full-sized.
    pad = pooling_size // 2
    if len(xs_pad.shape) == 4:
        features = F.pad(xs_pad, (0, 0, pad, pad), "replicate")
    else:
        features = F.pad(xs_pad, (pad, pad), "replicate")
    stat_list = []
    for i in range(num_chunk):
        st, ed = i * pooling_stride, i * pooling_stride + pooling_size
        # Each window yields a (Batch, 2 * Channel) statistics vector.
        stat = statistic_pooling(features[:, :, st:ed], pooling_dim=pooling_dim)
        stat_list.append(stat.unsqueeze(2))
    # Stack along time: (Batch, 2 * Channel, num_chunk). Output lengths are
    # the input lengths scaled by the stride (None if no lengths were given).
    olens = None if ilens is None else ilens // pooling_stride
    return torch.cat(stat_list, dim=2), olens
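

def _demo_windowed_statistic_pooling() -> Tuple[torch.Tensor, torch.Tensor]:
    # Usage sketch (hypothetical helper, not part of the original file):
    # frame-level statistics with a sliding 20-frame window and stride 1,
    # yielding one [mean; stddev] vector per frame.
    xs = torch.randn(2, 16, 100, 80)  # (Batch, Channel, Time, Frequency)
    ilens = torch.tensor([100, 80])
    stats, olens = windowed_statistic_pooling(
        xs, ilens, pooling_dim=(2, 3), pooling_size=20, pooling_stride=1
    )
    # stats: (2, 32, 100); olens: per-utterance lengths after striding.
    return stats, olens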