import torch
from torchaudio.transforms import AmplitudeToDB, MelSpectrogram
from models.atstframe.audio_transformer import FrameASTModel
from models.transformer_wrapper import BaseModelWrapper


class ATSTWrapper(BaseModelWrapper):
    """Wraps the frame-level ATST transformer together with its mel front-end."""

    def __init__(self, atst_dropout=0.0) -> None:
        super().__init__()
        self.atst_mel = ATSTMel()
        self.atst = FrameASTModel(atst_dropout=atst_dropout)
        # ATST expects per-example frame counts; 1001 frames correspond to
        # 10 s of 16 kHz audio at a 160-sample hop (160000 / 160 + 1).
        self.fake_length = torch.tensor([1001])
        self.cls_embed = None

    def mel_forward(self, x):
        return self.atst_mel(x)

    def forward(self, spec):
        # Frame-level embeddings from the last transformer block
        # (n=1 intermediate layer, no scene-level pooling).
        atst_x = self.atst.get_intermediate_layers(
            spec,
            self.fake_length.to(spec).repeat(len(spec)),
            1,
            scene=False
        )
        return atst_x

    def separate_params(self):
        """Split parameters into 12 groups, one per transformer block, for
        layer-wise learning-rate decay. Embedding parameters join the first
        block's group; the final frame norm joins the last block's group."""
        pt_params = [[] for _ in range(12)]
        embed_keys = {
            'atst.mask_embed', 'atst.pos_embed',
            'atst.patch_embed.patch_embed.weight',
            'atst.patch_embed.patch_embed.bias',
        }
        for k, p in self.named_parameters():
            if k in embed_keys or "blocks.0." in k:
                pt_params[0].append(p)
            elif ".norm_frame." in k:
                pt_params[11].append(p)
            else:
                for i in range(1, 12):
                    if f"blocks.{i}." in k:
                        pt_params[i].append(p)
                        break
                else:
                    raise ValueError(f"Check separate params for ATST! Unknown key: {k}")
        # Reverse so the group for the last block comes first.
        return list(reversed(pt_params))
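
# A hypothetical optimizer setup built on the groups returned by
# separate_params() (a sketch, not part of the original code; the base lr
# and the 0.5 per-group decay factor are illustrative):
#
#   groups = wrapper.separate_params()  # group for the last block comes first
#   optimizer = torch.optim.AdamW(
#       [{"params": ps, "lr": 1e-4 * 0.5 ** i} for i, ps in enumerate(groups)]
#   )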


class ATSTMel(torch.nn.Module):
    """Mel-spectrogram front-end for ATST: 64 mel bins, 1024-sample window,
    160-sample hop at 16 kHz, converted to dB and min-max scaled to [-1, 1]."""

    def __init__(self) -> None:
        super().__init__()
        self.mel_transform = MelSpectrogram(
            16000,
            f_min=60,
            f_max=7800,
            hop_length=160,
            win_length=1024,
            n_fft=1024,
            n_mels=64
        )
        self.amp_to_db = AmplitudeToDB(stype="power", top_db=80)
        # Fixed dataset-level statistics for min-max scaling.
        self.scaler = MinMax(min=-79.6482, max=50.6842)

    def amp2db(self, spec):
        return self.amp_to_db(spec).clamp(min=-50, max=80)

    def forward(self, audio):
        # Run the STFT in full precision even when autocast is active.
        with torch.autocast(device_type="cuda", enabled=False):
            spec = self.mel_transform(audio)
            spec = self.scaler(self.amp2db(spec))
        spec = spec.unsqueeze(1)  # (batch, mels, frames) -> (batch, 1, mels, frames)
        return spec


class CustomAudioTransform:
    def __repr__(self):
        return self.__class__.__name__ + '()'


class MinMax(CustomAudioTransform):
    """Min-max scale input to [-1, 1], using the fixed statistics passed at
    construction, or the per-input min/max when none are given."""

    def __init__(self, min, max):
        self.min = min
        self.max = max

    def __call__(self, input):
        if self.min is None:
            min_ = torch.min(input)
            max_ = torch.max(input)
        else:
            min_ = self.min
            max_ = self.max
        input = (input - min_) / (max_ - min_) * 2. - 1.
        return input
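

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the original app). Assumes
    # FrameASTModel can be constructed without a checkpoint; weights will be
    # random, so only shapes and types are checked here.
    wrapper = ATSTWrapper()
    audio = torch.randn(2, 160000)     # 2 examples x 10 s of 16 kHz audio
    spec = wrapper.mel_forward(audio)  # expected shape: (2, 1, 64, 1001)
    print("mel spectrogram:", tuple(spec.shape))
    embeddings = wrapper(spec)
    print("ATST output type:", type(embeddings))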