import numpy as np
import gradio as gr
import torch
from torch import nn
import pandas as pd

from models.atstframe.ATSTF_wrapper import ATSTWrapper
from models.beats.BEATs_wrapper import BEATsWrapper
from models.frame_passt.fpasst_wrapper import FPaSSTWrapper
from models.m2d.M2D_wrapper import M2DWrapper
from models.asit.ASIT_wrapper import ASiTWrapper
from models.frame_mn.Frame_MN_wrapper import FrameMNWrapper
from models.prediction_wrapper import PredictionsWrapper
from models.frame_mn.utils import NAME_TO_WIDTH


class TransformerClassifier(nn.Module):
    """Wraps a pretrained audio embedding model with a linear classification head."""

    def __init__(self, model, n_classes):
        super().__init__()
        self.model = model
        self.linear = nn.Linear(model.embed_dim, n_classes)

    def forward(self, x):
        mel = self.model.mel_forward(x)        # waveform -> mel spectrogram
        features = self.model(mel).squeeze(1)  # mel -> clip-level embedding
        return self.linear(features)           # embedding -> class logits


def get_model(model_name):
    # Build the requested embedding model and wrap it for clip-level predictions.
    if model_name == "BEATs":
        beats = BEATsWrapper()
        model = PredictionsWrapper(beats, checkpoint=None, head_type=None, seq_len=1)
    elif model_name == "ATST-F":
        atst = ATSTWrapper()
        model = PredictionsWrapper(atst, checkpoint=None, head_type=None, seq_len=1)
    elif model_name == "fpasst":
        fpasst = FPaSSTWrapper()
        model = PredictionsWrapper(fpasst, checkpoint=None, head_type=None, seq_len=1)
    elif model_name == "M2D":
        m2d = M2DWrapper()
        model = PredictionsWrapper(m2d, checkpoint=None, head_type=None, seq_len=1,
                                   embed_dim=m2d.m2d.cfg.feature_d)
    elif model_name == "ASIT":
        asit = ASiTWrapper()
        model = PredictionsWrapper(asit, checkpoint=None, head_type=None, seq_len=1)
    elif model_name.startswith("frame_mn"):
        width = NAME_TO_WIDTH(model_name)
        frame_mn = FrameMNWrapper(width)
        # Infer the embedding dimension from the final feature layer of the checkpoint.
        embed_dim = frame_mn.state_dict()['frame_mn.features.16.1.bias'].shape[0]
        model = PredictionsWrapper(frame_mn, checkpoint=None, head_type=None, seq_len=1,
                                   embed_dim=embed_dim)
    else:
        raise NotImplementedError(f"Model {model_name} not (yet) implemented")

    # 88 classes: the NSynth pitch task spans the 88 piano pitches (MIDI 21-108).
    main_model = TransformerClassifier(model, n_classes=88)
    # main_model.compile()
    main_model.load_state_dict(torch.load(f"resources/best_model_{model_name}.pth",
                                          map_location='cpu'))
    print(main_model)
    main_model.eval()
    return main_model


model = get_model("BEATs")
# labelvocabulary.csv maps class indices (read as strings) to MIDI pitch labels.
label_mapping = pd.read_csv("resources/labelvocabulary.csv", header=None, index_col=0).to_dict()[1]


def classify_pitch(input_audio):
    # gr.Audio with the default type="numpy" yields a (sample_rate, data) tuple;
    # data is int16 PCM and carries a trailing channel dimension for stereo input.
    _, data = input_audio
    waveform = torch.from_numpy(data).float()
    if waveform.ndim > 1:
        waveform = waveform.mean(dim=-1)  # down-mix to mono
    if data.dtype == np.int16:
        waveform = waveform / 32768.0     # scale 16-bit PCM to [-1, 1]
    with torch.no_grad():
        logits = model(waveform.unsqueeze(0))
    predicted = logits.argmax(dim=1).item()
    return int(label_mapping[str(predicted)])


demo = gr.Interface(
    fn=classify_pitch,
    inputs=gr.Audio(max_length=4),
    outputs="number",
    title="NSynth Pitch Classification",
)
demo.launch()
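
# ---------------------------------------------------------------------------
# Optional local smoke test (a minimal sketch, not part of the original demo):
# feeds a synthetic 440 Hz sine tone through classify_pitch() to exercise the
# pipeline end-to-end without the UI. The 16 kHz sample rate is an assumption;
# substitute whatever rate the chosen embedding model expects.
#
#   sr = 16000
#   t = np.linspace(0, 4, 4 * sr, endpoint=False)
#   tone = (np.sin(2 * np.pi * 440.0 * t) * 32767).astype(np.int16)
#   print(classify_pitch((sr, tone)))  # ideally near MIDI 69 (A4 = 440 Hz)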
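
# The launched demo can also be queried programmatically. A sketch using the
# gradio_client package (the local URL and the audio file path are assumptions;
# "/predict" is the default endpoint name for a gr.Interface app):
#
#   from gradio_client import Client, handle_file
#   client = Client("http://127.0.0.1:7860")
#   pitch = client.predict(handle_file("clip.wav"), api_name="/predict")
#   print(pitch)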