import torch
import random
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.beta import Beta
def frame_shift(mels, labels, embeddings=None, pseudo_labels=None,
                net_pooling=4, shift_range=0.125):
    """Randomly roll each clip along the time axis and roll its labels (and optional
    pseudo labels / embeddings) by the corresponding pooled amount."""
    bsz, channels, n_bands, frames = mels.shape
    abs_shift_mel = int(frames * shift_range)
    if embeddings is not None:
        embed_frames = embeddings.shape[-1]
        embed_pool_fact = frames / embed_frames
    for bindx in range(bsz):
        shift = int(random.gauss(0, abs_shift_mel))
        mels[bindx] = torch.roll(mels[bindx], shift, dims=-1)
        # shift the frame-level labels by the same amount, expressed in pooled frames
        label_shift = -abs(shift) / net_pooling if shift < 0 else shift / net_pooling
        label_shift = round(label_shift)
        labels[bindx] = torch.roll(labels[bindx], label_shift, dims=-1)
        if pseudo_labels is not None:
            pseudo_labels[bindx] = torch.roll(pseudo_labels[bindx], label_shift, dims=-1)
        if embeddings is not None:
            embed_shift = -abs(shift) / embed_pool_fact if shift < 0 else shift / embed_pool_fact
            embed_shift = round(embed_shift)
            embeddings[bindx] = torch.roll(embeddings[bindx], embed_shift, dims=-1)
    out_args = [mels]
    if embeddings is not None:
        out_args.append(embeddings)
    out_args.append(labels)
    if pseudo_labels is not None:
        out_args.append(pseudo_labels)
    return tuple(out_args)
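
# Usage sketch (shapes here are illustrative assumptions, not taken from a config):
#   mels = torch.randn(8, 1, 128, 624)      # (batch, channels, mel bins, frames)
#   labels = torch.rand(8, 10, 624 // 4)    # frame-level targets pooled by net_pooling=4
#   mels, labels = frame_shift(mels, labels, net_pooling=4)
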
def time_mask(features, labels, embeddings=None, pseudo_labels=None, net_pooling=4,
              min_mask_ratio=0.05, max_mask_ratio=0.2):
    """Zero out a random time span in the features and the corresponding span in the
    labels (and optional pseudo labels / embeddings)."""
    _, _, n_frame = labels.shape
    if embeddings is not None:
        embed_frames = embeddings.shape[-1]
        embed_pool_fact = embed_frames / n_frame
    t_width = torch.randint(low=int(n_frame * min_mask_ratio), high=int(n_frame * max_mask_ratio), size=(1,))
    t_low = torch.randint(low=0, high=n_frame - t_width[0], size=(1,))
    features[:, :, :, t_low * net_pooling:(t_low + t_width) * net_pooling] = 0
    labels[:, :, t_low:t_low + t_width] = 0
    if pseudo_labels is not None:
        # mask the same span in the pseudo labels
        pseudo_labels[:, :, t_low:t_low + t_width] = 0
    if embeddings is not None:
        low = round((t_low * embed_pool_fact).item())
        high = round(((t_low + t_width) * embed_pool_fact).item())
        embeddings[..., low:high] = 0
    out_args = [features]
    if embeddings is not None:
        out_args.append(embeddings)
    out_args.append(labels)
    if pseudo_labels is not None:
        out_args.append(pseudo_labels)
    return tuple(out_args)
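
# Usage sketch (assumed shapes, labels pooled by net_pooling relative to features):
#   feats = torch.randn(8, 1, 128, 624)
#   labels = torch.rand(8, 10, 156)
#   feats, labels = time_mask(feats, labels, net_pooling=4)
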
def mixup(data, embeddings=None, targets=None, pseudo_strong=None, alpha=0.2, beta=0.2, return_mix_coef=False):
    """Mix each example with a randomly permuted partner, using coefficients drawn from
    Beta(alpha, beta) and folded so that the original example keeps the larger weight."""
    with torch.no_grad():
        batch_size = data.size(0)
        c = np.random.beta(alpha, beta, size=batch_size)
        c = np.maximum(c, 1 - c)
        perm = torch.randperm(batch_size)
        cd = torch.tensor(c, dtype=data.dtype, device=data.device).view(batch_size, *([1] * (data.ndim - 1)))
        mixed_data = cd * data + (1 - cd) * data[perm, :]
        if embeddings is not None:
            ce = torch.tensor(c, dtype=embeddings.dtype, device=embeddings.device).view(batch_size, *([1] * (embeddings.ndim - 1)))
            mixed_embeddings = ce * embeddings + (1 - ce) * embeddings[perm, :]
        if targets is not None:
            ct = torch.tensor(c, dtype=data.dtype, device=data.device).view(batch_size, *([1] * (targets.ndim - 1)))
            mixed_target = torch.clamp(
                ct * targets + (1 - ct) * targets[perm, :], min=0, max=1
            )
        if pseudo_strong is not None:
            cp = torch.tensor(c, dtype=pseudo_strong.dtype, device=pseudo_strong.device).view(
                batch_size, *([1] * (pseudo_strong.ndim - 1)))
            mixed_pseudo_strong = cp * pseudo_strong + (1 - cp) * pseudo_strong[perm, :]
        out_args = [mixed_data]
        if embeddings is not None:
            out_args.append(mixed_embeddings)
        if targets is not None:
            out_args.append(mixed_target)
        if pseudo_strong is not None:
            out_args.append(mixed_pseudo_strong)
        if return_mix_coef:
            out_args.append(perm)
            out_args.append(c)
        return tuple(out_args)
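
# Usage sketch (assumed shapes); with return_mix_coef=True the permutation and the
# numpy mixing coefficients are appended to the returned tuple:
#   data = torch.randn(8, 1, 128, 624)
#   targets = torch.rand(8, 10, 156)
#   mixed_data, mixed_targets, perm, c = mixup(data, targets=targets, return_mix_coef=True)
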
def filt_aug_(features, db_range=(-6, 6), n_band=(3, 6), min_bw=6):
    """Apply a random piecewise-linear gain curve (in dB) along the frequency axis of each clip."""
    batch_size, channels, n_freq_bin, _ = features.shape
    n_freq_band = torch.randint(low=n_band[0], high=n_band[1], size=(1,)).item()  # [low, high)
    if n_freq_band > 1:
        while n_freq_bin - n_freq_band * min_bw + 1 < 0:
            min_bw -= 1
        # random band boundaries, at least min_bw bins apart
        band_bndry_freqs = torch.sort(torch.randint(0, n_freq_bin - n_freq_band * min_bw + 1,
                                                    (n_freq_band - 1,)))[0] + \
                           torch.arange(1, n_freq_band) * min_bw
        band_bndry_freqs = torch.cat((torch.tensor([0]), band_bndry_freqs, torch.tensor([n_freq_bin])))
        band_factors = torch.rand((batch_size, n_freq_band + 1)).to(features) * (db_range[1] - db_range[0]) + db_range[0]
        freq_filt = torch.ones((batch_size, n_freq_bin, 1)).to(features)
        for i in range(n_freq_band):
            for j in range(batch_size):
                freq_filt[j, band_bndry_freqs[i]:band_bndry_freqs[i + 1], :] = \
                    torch.linspace(band_factors[j, i], band_factors[j, i + 1],
                                   band_bndry_freqs[i + 1] - band_bndry_freqs[i]).unsqueeze(-1)
        freq_filt = 10 ** (freq_filt / 20)  # dB -> linear gain
        return features * freq_filt.unsqueeze(1)
    else:
        return features
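
# Usage sketch (assumed shape): apply one random frequency filter to a batch of spectrograms.
#   feats = torch.randn(8, 1, 128, 624)
#   feats_aug = filt_aug_(feats, db_range=(-6, 6), n_band=(3, 6), min_bw=6)
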
def filter_augmentation(features, n_transform=1, filter_db_range=(-6, 6),
                        filter_bands=(3, 6), filter_minimum_bandwidth=6):
    """Return two views of `features`: two independent filt_aug_ transforms when
    n_transform == 2, otherwise the same (possibly augmented) tensor twice."""
    if n_transform == 2:
        feature_list = []
        for _ in range(n_transform):
            features_temp = filt_aug_(features, db_range=filter_db_range, n_band=filter_bands,
                                      min_bw=filter_minimum_bandwidth)
            feature_list.append(features_temp)
        return feature_list
    elif n_transform == 1:
        features = filt_aug_(features, db_range=filter_db_range, n_band=filter_bands,
                             min_bw=filter_minimum_bandwidth)
        return [features, features]
    else:
        return [features, features]
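
# Usage sketch (assumed shape): the call always yields a list of two feature tensors,
# independently augmented when n_transform == 2, otherwise the same tensor twice.
#   feats = torch.randn(8, 1, 128, 624)
#   feats_view_1, feats_view_2 = filter_augmentation(feats, n_transform=2)
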
def mixstyle(x, alpha=0.4, eps=1e-6):
    """Mix frequency-wise feature statistics between random pairs of clips in the batch."""
    batch_size = x.size(0)
    # frequency-wise statistics
    f_mu = x.mean(dim=3, keepdim=True)
    f_var = x.var(dim=3, keepdim=True)
    f_sig = (f_var + eps).sqrt()  # compute instance standard deviation
    f_mu, f_sig = f_mu.detach(), f_sig.detach()  # block gradients
    x_normed = (x - f_mu) / f_sig  # normalize input
    lmda = Beta(alpha, alpha).sample((batch_size, 1, 1, 1)).to(x.device, dtype=x.dtype)  # sample instance-wise convex weights
    lmda = torch.max(lmda, 1 - lmda)
    perm = torch.randperm(batch_size).to(x.device)  # generate shuffling indices
    f_mu_perm, f_sig_perm = f_mu[perm], f_sig[perm]  # shuffle statistics across the batch
    mu_mix = f_mu * lmda + f_mu_perm * (1 - lmda)  # mixed mean
    sig_mix = f_sig * lmda + f_sig_perm * (1 - lmda)  # mixed standard deviation
    x = x_normed * sig_mix + mu_mix  # denormalize input using the mixed statistics
    return x
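
# Usage sketch (assumed (batch, channels, freq, time) layout, consistent with the dim=3 statistics):
#   x = torch.randn(8, 1, 128, 624)
#   x_aug = mixstyle(x, alpha=0.4)
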
class RandomResizeCrop(nn.Module):
    """Random Resize Crop block.

    Args:
        virtual_crop_scale: Virtual crop area `(F ratio, T ratio)` in ratio to input size.
        freq_scale: Random frequency range `(min, max)`.
        time_scale: Random time frame range `(min, max)`.
    """

    def __init__(self, virtual_crop_scale=(1.0, 1.5), freq_scale=(0.6, 1.0), time_scale=(0.6, 1.5)):
        super().__init__()
        self.virtual_crop_scale = virtual_crop_scale
        self.freq_scale = freq_scale
        self.time_scale = time_scale
        self.interpolation = 'bicubic'
        assert time_scale[1] >= 1.0 and freq_scale[1] >= 1.0

    @staticmethod
    def get_params(virtual_crop_size, in_size, time_scale, freq_scale):
        canvas_h, canvas_w = virtual_crop_size
        src_h, src_w = in_size
        h = np.clip(int(np.random.uniform(*freq_scale) * src_h), 1, canvas_h)
        w = np.clip(int(np.random.uniform(*time_scale) * src_w), 1, canvas_w)
        i = random.randint(0, canvas_h - h) if canvas_h > h else 0
        j = random.randint(0, canvas_w - w) if canvas_w > w else 0
        return i, j, h, w

    def forward(self, lms):
        # make an empty virtual crop area and copy the input log mel spectrogram to its center
        virtual_crop_size = [int(s * c) for s, c in zip(lms.shape[-2:], self.virtual_crop_scale)]
        virtual_crop_area = (torch.zeros((lms.shape[0], virtual_crop_size[0], virtual_crop_size[1]))
                             .to(torch.float).to(lms.device))
        _, lh, lw = virtual_crop_area.shape
        c, h, w = lms.shape
        x, y = (lw - w) // 2, (lh - h) // 2
        virtual_crop_area[:, y:y + h, x:x + w] = lms
        # crop a random area and resize it back to the original input size
        i, j, h, w = self.get_params(virtual_crop_area.shape[-2:], lms.shape[-2:], self.time_scale, self.freq_scale)
        crop = virtual_crop_area[:, i:i + h, j:j + w]
        lms = F.interpolate(crop.unsqueeze(1), size=lms.shape[-2:],
                            mode=self.interpolation, align_corners=True).squeeze(1)
        return lms.float()

    def __repr__(self):
        format_string = self.__class__.__name__ + f'(virtual_crop_size={self.virtual_crop_scale}'
        format_string += ', time_scale={0}'.format(tuple(round(s, 4) for s in self.time_scale))
        format_string += ', freq_scale={0})'.format(tuple(round(r, 4) for r in self.freq_scale))
        return format_string
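
# Usage sketch: the module expects a single (channels, freq, time) log-mel tensor per call
# (the shape below is an assumption) and returns a tensor of the same shape.
#   rrc = RandomResizeCrop()
#   lms = torch.randn(1, 128, 624)
#   lms_aug = rrc(lms)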