import torch
from torchaudio.transforms import AmplitudeToDB, MelSpectrogram

from models.atstframe.audio_transformer import FrameASTModel
from models.transformer_wrapper import BaseModelWrapper

class ATSTWrapper(BaseModelWrapper):
    """Wraps the frame-level ATST transformer behind the common wrapper interface."""

    def __init__(self, atst_dropout=0.0) -> None:
        super().__init__()
        self.atst_mel = ATSTMel()
        self.atst = FrameASTModel(atst_dropout=atst_dropout)
        # 1001 frames correspond to 10 s of 16 kHz audio at a hop of 160 samples.
        self.fake_length = torch.tensor([1001])
        self.cls_embed = None

    def mel_forward(self, x):
        return self.atst_mel(x)

    def forward(self, spec):
        atst_x = self.atst.get_intermediate_layers(
            spec,
            self.fake_length.to(spec).repeat(len(spec)),
            1,
            scene=False
        )
        return atst_x

    def separate_params(self):
        """Splits the parameters into 12 depth-ordered groups (one per
        transformer block) for layer-wise learning-rate decay. The returned
        list is reversed, so group 0 holds the topmost (last) block."""
        pt_params = [[] for _ in range(12)]
        block0_keys = ['atst.mask_embed', 'atst.pos_embed',
                       'atst.patch_embed.patch_embed.weight',
                       'atst.patch_embed.patch_embed.bias']
        for k, p in self.named_parameters():
            # Embeddings and patch projection train together with the first block.
            if k in block0_keys or "blocks.0." in k:
                pt_params[0].append(p)
            # The final norm trains together with the last block.
            elif "blocks.11." in k or ".norm_frame." in k:
                pt_params[11].append(p)
            else:
                for i in range(1, 11):
                    if f"blocks.{i}." in k:
                        pt_params[i].append(p)
                        break
                else:
                    raise ValueError(f"Check separate params for ATST! Unknown key: {k}")
        return list(reversed(pt_params))
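
# The groups returned by separate_params are meant for layer-wise
# learning-rate decay. Below is a minimal sketch of one way to turn them
# into optimizer parameter groups; the helper name, base_lr, and decay
# factor are illustrative assumptions, not part of the original training code.
def layerwise_lr_groups(wrapper, base_lr=1e-4, decay=0.75):
    groups = []
    for depth, params in enumerate(wrapper.separate_params()):
        # Group 0 holds the last transformer block (see the reversed return
        # above), so earlier layers get progressively smaller learning rates.
        groups.append({"params": params, "lr": base_lr * decay ** depth})
    return groups
# Usage: optimizer = torch.optim.AdamW(layerwise_lr_groups(wrapper))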

class ATSTMel(torch.nn.Module):
    """Mel-spectrogram frontend with the fixed normalization used for ATST."""

    def __init__(self) -> None:
        super().__init__()
        self.mel_transform = MelSpectrogram(
            16000,
            f_min=60,
            f_max=7800,
            hop_length=160,
            win_length=1024,
            n_fft=1024,
            n_mels=64
        )
        self.amp_to_db = AmplitudeToDB(stype="power", top_db=80)
        self.scaler = MinMax(min=-79.6482, max=50.6842)

    def amp2db(self, spec):
        return self.amp_to_db(spec).clamp(min=-50, max=80)

    def forward(self, audio):
        # Spectrogram extraction is numerically sensitive, so it stays in
        # float32 even when the surrounding model runs under autocast.
        with torch.autocast(device_type="cuda", enabled=False):
            spec = self.mel_transform(audio)
            spec = self.scaler(self.amp2db(spec))
            spec = spec.unsqueeze(1)
        return spec

class CustomAudioTransform:
    def __repr__(self):
        return self.__class__.__name__ + '()'

class MinMax(CustomAudioTransform):
    """Min-max normalizes the input to [-1, 1]; falls back to per-input
    statistics when no fixed min/max is given."""

    def __init__(self, min, max):
        self.min = min
        self.max = max

    def __call__(self, input):
        if self.min is None:
            min_ = torch.min(input)
            max_ = torch.max(input)
        else:
            min_ = self.min
            max_ = self.max
        input = (input - min_) / (max_ - min_) * 2. - 1.
        return input
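
# Minimal end-to-end sketch, assuming BaseModelWrapper is an nn.Module
# subclass and that pretrained FrameASTModel weights are loaded elsewhere
# (the randomly initialized model below only demonstrates tensor shapes).
if __name__ == "__main__":
    wrapper = ATSTWrapper()
    audio = torch.randn(2, 160000)      # two 10-second clips at 16 kHz
    spec = wrapper.mel_forward(audio)   # (2, 1, 64, 1001) normalized log-mels
    frames = wrapper(spec)              # frame-level embeddings from the last block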