import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.beta import Beta

def frame_shift(mels, labels, embeddings=None, pseudo_labels=None,
                net_pooling=4, shift_range=0.125):
    """Randomly shift each clip along the time axis, keeping the mels, labels,
    embeddings and pseudo labels aligned."""
    bsz, channels, n_bands, frames = mels.shape
    abs_shift_mel = int(frames * shift_range)
    if embeddings is not None:
        embed_frames = embeddings.shape[-1]
        embed_pool_fact = frames / embed_frames  # mel frames per embedding frame
    for bindx in range(bsz):
        # draw a signed shift (in mel frames) from a zero-mean Gaussian
        shift = int(random.gauss(0, abs_shift_mel))
        mels[bindx] = torch.roll(mels[bindx], shift, dims=-1)
        # scale the shift down to the label resolution (mel frames / net_pooling)
        label_shift = round(shift / net_pooling)
        labels[bindx] = torch.roll(labels[bindx], label_shift, dims=-1)
        if pseudo_labels is not None:
            pseudo_labels[bindx] = torch.roll(pseudo_labels[bindx], label_shift, dims=-1)
        if embeddings is not None:
            # scale the shift down to the embedding resolution
            embed_shift = round(shift / embed_pool_fact)
            embeddings[bindx] = torch.roll(embeddings[bindx], embed_shift, dims=-1)
    out_args = [mels]
    if embeddings is not None:
        out_args.append(embeddings)
    out_args.append(labels)
    if pseudo_labels is not None:
        out_args.append(pseudo_labels)
    return tuple(out_args)
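
# A minimal, hedged usage sketch for frame_shift. The shapes below are
# illustrative assumptions (batch of 4, 128 mel bands, 624 mel frames pooled
# by 4 to 156 label frames), not values this module prescribes.
def _demo_frame_shift():
    mels = torch.randn(4, 1, 128, 624)
    labels = torch.rand(4, 10, 624 // 4)  # strong labels at the pooled resolution
    shifted_mels, shifted_labels = frame_shift(mels, labels)
    # shapes are preserved; content is rolled along the time axis
    assert shifted_mels.shape == (4, 1, 128, 624)
    assert shifted_labels.shape == (4, 10, 156)
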
def time_mask(features, labels, embeddings=None, pseudo_labels=None, net_pooling=4,
              min_mask_ratio=0.05, max_mask_ratio=0.2):
    """Zero out one random time span in the features and in all aligned targets."""
    _, _, n_frame = labels.shape
    if embeddings is not None:
        embed_frames = embeddings.shape[-1]
        embed_pool_fact = embed_frames / n_frame  # embedding frames per label frame
    # mask width and start position, both in label frames
    t_width = torch.randint(low=int(n_frame * min_mask_ratio),
                            high=int(n_frame * max_mask_ratio), size=(1,))
    t_low = torch.randint(low=0, high=int(n_frame - t_width[0]), size=(1,))
    features[:, :, :, t_low * net_pooling:(t_low + t_width) * net_pooling] = 0
    labels[:, :, t_low:t_low + t_width] = 0
    if pseudo_labels is not None:
        pseudo_labels[:, :, t_low:t_low + t_width] = 0
    if embeddings is not None:
        low = round((t_low * embed_pool_fact).item())
        high = round(((t_low + t_width) * embed_pool_fact).item())
        embeddings[..., low:high] = 0
    out_args = [features]
    if embeddings is not None:
        out_args.append(embeddings)
    out_args.append(labels)
    if pseudo_labels is not None:
        out_args.append(pseudo_labels)
    return tuple(out_args)
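
# A hedged usage sketch for time_mask, under the same illustrative shape
# assumptions as above: 624 feature frames map to 156 label frames when
# net_pooling=4, and the same span is zeroed at every resolution.
def _demo_time_mask():
    features = torch.randn(4, 1, 128, 624)
    labels = torch.rand(4, 10, 156)
    masked_features, masked_labels = time_mask(features, labels, net_pooling=4)
    # at least the sampled label span is zeroed
    assert (masked_labels == 0).any()
    assert masked_features.shape == (4, 1, 128, 624)
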
def mixup(data, embeddings=None, targets=None, pseudo_strong=None, alpha=0.2, beta=0.2,
          return_mix_coef=False):
    """Mix random pairs of examples (and their targets) with Beta-distributed weights."""
    with torch.no_grad():
        batch_size = data.size(0)
        # per-example mixing coefficient, folded to [0.5, 1] so that the first
        # element of each pair dominates the mixture
        c = np.random.beta(alpha, beta, size=batch_size)
        c = np.maximum(c, 1 - c)
        perm = torch.randperm(batch_size)
        cd = torch.tensor(c, dtype=data.dtype, device=data.device).view(
            batch_size, *([1] * (data.ndim - 1)))
        mixed_data = cd * data + (1 - cd) * data[perm, :]
        if embeddings is not None:
            ce = torch.tensor(c, dtype=embeddings.dtype, device=embeddings.device).view(
                batch_size, *([1] * (embeddings.ndim - 1)))
            mixed_embeddings = ce * embeddings + (1 - ce) * embeddings[perm, :]
        if targets is not None:
            ct = torch.tensor(c, dtype=targets.dtype, device=targets.device).view(
                batch_size, *([1] * (targets.ndim - 1)))
            mixed_target = torch.clamp(
                ct * targets + (1 - ct) * targets[perm, :], min=0, max=1
            )
        if pseudo_strong is not None:
            cp = torch.tensor(c, dtype=pseudo_strong.dtype, device=pseudo_strong.device).view(
                batch_size, *([1] * (pseudo_strong.ndim - 1)))
            mixed_pseudo_strong = cp * pseudo_strong + (1 - cp) * pseudo_strong[perm, :]
        out_args = [mixed_data]
        if embeddings is not None:
            out_args.append(mixed_embeddings)
        if targets is not None:
            out_args.append(mixed_target)
        if pseudo_strong is not None:
            out_args.append(mixed_pseudo_strong)
        if return_mix_coef:
            out_args.append(perm)
            out_args.append(c)
        return tuple(out_args)
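
# A hedged usage sketch for mixup on weak (clip-level) targets. The class
# count of 10 is an assumption for illustration only.
def _demo_mixup():
    data = torch.randn(4, 1, 128, 624)
    targets = torch.randint(0, 2, (4, 10)).float()
    mixed_data, mixed_targets, perm, c = mixup(data, targets=targets, return_mix_coef=True)
    # each mixed target is a clamped convex combination, so it stays in [0, 1]
    assert mixed_targets.min() >= 0 and mixed_targets.max() <= 1
    assert mixed_data.shape == data.shape
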
def filt_aug_(features, db_range=(-6, 6), n_band=(3, 6), min_bw=6):
    """FilterAugment: apply a random piecewise-linear gain curve over frequency."""
    batch_size, channels, n_freq_bin, _ = features.shape
    n_freq_band = torch.randint(low=n_band[0], high=n_band[1], size=(1,)).item()  # [low, high)
    if n_freq_band > 1:
        # shrink the minimum bandwidth until the bands fit into the spectrum
        while n_freq_bin - n_freq_band * min_bw + 1 < 0:
            min_bw -= 1
        # random band boundaries, at least min_bw bins apart
        band_bndry_freqs = torch.sort(torch.randint(0, n_freq_bin - n_freq_band * min_bw + 1,
                                                    (n_freq_band - 1,)))[0] + \
                           torch.arange(1, n_freq_band) * min_bw
        band_bndry_freqs = torch.cat((torch.tensor([0]), band_bndry_freqs,
                                      torch.tensor([n_freq_bin])))
        # random gain (in dB) at each boundary, interpolated linearly in between
        band_factors = torch.rand((batch_size, n_freq_band + 1)).to(features) \
            * (db_range[1] - db_range[0]) + db_range[0]
        freq_filt = torch.ones((batch_size, n_freq_bin, 1)).to(features)
        for i in range(n_freq_band):
            for j in range(batch_size):
                freq_filt[j, band_bndry_freqs[i]:band_bndry_freqs[i + 1], :] = \
                    torch.linspace(band_factors[j, i], band_factors[j, i + 1],
                                   int(band_bndry_freqs[i + 1] - band_bndry_freqs[i])).unsqueeze(-1)
        freq_filt = 10 ** (freq_filt / 20)  # dB -> amplitude
        return features * freq_filt.unsqueeze(1)
    else:
        return features
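
# A hedged sanity check for filt_aug_: with a degenerate zero-width dB range
# every band factor is 0 dB, the filter collapses to unity gain, and the
# output should match the input exactly.
def _demo_filt_aug():
    features = torch.randn(4, 1, 128, 624)
    out = filt_aug_(features, db_range=(0, 0))
    assert torch.allclose(out, features)
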
def filter_augmentation(features, n_transform=1, filter_db_range=(-6, 6),
                        filter_bands=(3, 6), filter_minimum_bandwidth=6):
    """Return two views of the features: two independent FilterAugment draws
    (n_transform=2), one shared draw (n_transform=1), or the unmodified input."""
    if n_transform == 2:
        feature_list = []
        for _ in range(n_transform):
            features_temp = filt_aug_(features, db_range=filter_db_range, n_band=filter_bands,
                                      min_bw=filter_minimum_bandwidth)
            feature_list.append(features_temp)
        return feature_list
    elif n_transform == 1:
        features = filt_aug_(features, db_range=filter_db_range, n_band=filter_bands,
                             min_bw=filter_minimum_bandwidth)
        return [features, features]
    else:
        return [features, features]
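
# A hedged usage sketch for filter_augmentation as a paired transform: with
# n_transform=2 the two returned views are augmented independently, which is
# the usual setup for consistency-style training.
def _demo_filter_augmentation():
    features = torch.randn(4, 1, 128, 624)
    view_a, view_b = filter_augmentation(features, n_transform=2)
    assert view_a.shape == view_b.shape == features.shape
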
def mixstyle(x, alpha=0.4, eps=1e-6):
    """Freq-MixStyle: mix frequency-wise statistics between random pairs of examples."""
    batch_size = x.size(0)
    # frequency-wise statistics (mean/std over the time axis)
    f_mu = x.mean(dim=3, keepdim=True)
    f_var = x.var(dim=3, keepdim=True)
    f_sig = (f_var + eps).sqrt()  # instance standard deviation
    f_mu, f_sig = f_mu.detach(), f_sig.detach()  # block gradients through the statistics
    x_normed = (x - f_mu) / f_sig  # normalize input
    lmda = Beta(alpha, alpha).sample((batch_size, 1, 1, 1)).to(x.device, dtype=x.dtype)  # instance-wise convex weights
    lmda = torch.max(lmda, 1 - lmda)  # fold to [0.5, 1] so own statistics dominate
    perm = torch.randperm(batch_size).to(x.device)  # shuffling indices
    f_mu_perm, f_sig_perm = f_mu[perm], f_sig[perm]  # shuffled statistics
    mu_mix = f_mu * lmda + f_mu_perm * (1 - lmda)  # mixed mean
    sig_mix = f_sig * lmda + f_sig_perm * (1 - lmda)  # mixed standard deviation
    return x_normed * sig_mix + mu_mix  # denormalize with the mixed statistics
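
# A hedged sanity check for mixstyle: only frequency statistics are exchanged,
# so the output keeps the input shape. The input shape is an illustrative
# assumption, as in the sketches above.
def _demo_mixstyle():
    x = torch.randn(4, 1, 128, 624)
    out = mixstyle(x)
    assert out.shape == x.shape
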
class RandomResizeCrop(nn.Module):
    """Random Resize Crop block.

    Args:
        virtual_crop_scale: Virtual crop area `(F ratio, T ratio)` in ratio to input size.
        freq_scale: Random frequency range `(min, max)`.
        time_scale: Random time frame range `(min, max)`.
    """

    def __init__(self, virtual_crop_scale=(1.0, 1.5), freq_scale=(0.6, 1.0), time_scale=(0.6, 1.5)):
        super().__init__()
        self.virtual_crop_scale = virtual_crop_scale
        self.freq_scale = freq_scale
        self.time_scale = time_scale
        self.interpolation = 'bicubic'
        assert time_scale[1] >= 1.0 and freq_scale[1] >= 1.0

    @staticmethod
    def get_params(virtual_crop_size, in_size, time_scale, freq_scale):
        canvas_h, canvas_w = virtual_crop_size
        src_h, src_w = in_size
        # random crop size, clipped to the canvas
        h = np.clip(int(np.random.uniform(*freq_scale) * src_h), 1, canvas_h)
        w = np.clip(int(np.random.uniform(*time_scale) * src_w), 1, canvas_w)
        # random top-left corner of the crop
        i = random.randint(0, canvas_h - h) if canvas_h > h else 0
        j = random.randint(0, canvas_w - w) if canvas_w > w else 0
        return i, j, h, w

    def forward(self, lms):
        # create an empty virtual crop area and copy the input log-mel
        # spectrogram to its center
        virtual_crop_size = [int(s * c) for s, c in zip(lms.shape[-2:], self.virtual_crop_scale)]
        virtual_crop_area = (torch.zeros((lms.shape[0], virtual_crop_size[0], virtual_crop_size[1]))
                             .to(torch.float).to(lms.device))
        _, lh, lw = virtual_crop_area.shape
        c, h, w = lms.shape
        x, y = (lw - w) // 2, (lh - h) // 2
        virtual_crop_area[:, y:y + h, x:x + w] = lms
        # take a random crop and resize it back to the input resolution
        i, j, h, w = self.get_params(virtual_crop_area.shape[-2:], lms.shape[-2:],
                                     self.time_scale, self.freq_scale)
        crop = virtual_crop_area[:, i:i + h, j:j + w]
        lms = F.interpolate(crop.unsqueeze(1), size=lms.shape[-2:],
                            mode=self.interpolation, align_corners=True).squeeze(1)
        return lms.float()

    def __repr__(self):
        format_string = self.__class__.__name__ + f'(virtual_crop_size={self.virtual_crop_scale}'
        format_string += ', time_scale={0}'.format(tuple(round(s, 4) for s in self.time_scale))
        format_string += ', freq_scale={0})'.format(tuple(round(r, 4) for r in self.freq_scale))
        return format_string
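
# A hedged usage sketch for RandomResizeCrop. The module expects a single
# spectrogram shaped (channels, freq, time) and returns the same shape after
# the random crop and bicubic resize; the shape below is an illustrative
# assumption.
def _demo_random_resize_crop():
    rrc = RandomResizeCrop()
    lms = torch.randn(1, 128, 624)
    out = rrc(lms)
    assert out.shape == lms.shape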