# (Scraped Hugging Face Spaces page header — "Spaces:" / "Sleeping" status
# badges — removed; the app script begins below.)
import numpy as np
import gradio as gr
import pandas as pd
import torch
from torch import nn

from models.asit.ASIT_wrapper import ASiTWrapper
from models.atstframe.ATSTF_wrapper import ATSTWrapper
from models.beats.BEATs_wrapper import BEATsWrapper
from models.frame_mn.Frame_MN_wrapper import FrameMNWrapper
from models.frame_mn.utils import NAME_TO_WIDTH
from models.frame_passt.fpasst_wrapper import FPaSSTWrapper
from models.m2d.M2D_wrapper import M2DWrapper
from models.prediction_wrapper import PredictionsWrapper
class TransformerClassifier(nn.Module):
    """Linear classification head on top of a pretrained audio backbone.

    Args:
        model: Backbone wrapper exposing ``mel_forward`` (waveform ->
            mel features), a forward pass producing embeddings, and an
            ``embed_dim`` attribute sized to those embeddings.
        n_classes: Number of output classes.
    """

    def __init__(self, model, n_classes):
        super().__init__()  # modern zero-arg super (Py3 idiom)
        self.model = model
        self.linear = nn.Linear(model.embed_dim, n_classes)

    def forward(self, x):
        """Map a batch of waveforms to class logits of shape (batch, n_classes)."""
        mel = self.model.mel_forward(x)
        # squeeze(1) drops the singleton sequence axis — presumably the
        # backbone returns (batch, 1, embed_dim) given seq_len=1; TODO confirm.
        features = self.model(mel).squeeze(1)
        return self.linear(features)
def get_model(model_name):
    """Build the pitch classifier for the requested backbone.

    Instantiates the named pretrained audio backbone, wraps it in
    ``PredictionsWrapper`` (no checkpoint/head, seq_len=1), attaches a
    linear head with 88 classes, loads the fine-tuned weights from
    ``resources/best_model_<model_name>.pth`` and returns the model in
    eval mode.

    Args:
        model_name: One of "BEATs", "ATST-F", "fpasst", "M2D", "ASIT",
            or a "frame_mn*" name (width encoded in the name).

    Returns:
        A ``TransformerClassifier`` with loaded weights, in eval mode.

    Raises:
        NotImplementedError: If ``model_name`` is not recognized.
    """
    if model_name == "BEATs":
        beats = BEATsWrapper()
        model = PredictionsWrapper(beats, checkpoint=None, head_type=None, seq_len=1)
    elif model_name == "ATST-F":
        atst = ATSTWrapper()
        model = PredictionsWrapper(atst, checkpoint=None, head_type=None, seq_len=1)
    elif model_name == "fpasst":
        fpasst = FPaSSTWrapper()
        model = PredictionsWrapper(fpasst, checkpoint=None, head_type=None, seq_len=1)
    elif model_name == "M2D":
        m2d = M2DWrapper()
        model = PredictionsWrapper(m2d, checkpoint=None, head_type=None, seq_len=1,
                                   embed_dim=m2d.m2d.cfg.feature_d)
    elif model_name == "ASIT":
        asit = ASiTWrapper()
        model = PredictionsWrapper(asit, checkpoint=None, head_type=None, seq_len=1)
    elif model_name.startswith("frame_mn"):
        width = NAME_TO_WIDTH(model_name)
        frame_mn = FrameMNWrapper(width)
        # Infer the embedding size from the last conv layer's bias vector.
        embed_dim = frame_mn.state_dict()['frame_mn.features.16.1.bias'].shape[0]
        model = PredictionsWrapper(frame_mn, checkpoint=None, head_type=None, seq_len=1, embed_dim=embed_dim)
    else:
        raise NotImplementedError(f"Model {model_name} not (yet) implemented")
    main_model = TransformerClassifier(model, n_classes=88)
    # NOTE(review): torch.load unpickles arbitrary objects. The checkpoint is a
    # local trusted file here, but prefer weights_only=True on torch >= 1.13.
    main_model.load_state_dict(torch.load(f"resources/best_model_{model_name}.pth", map_location='cpu'))
    main_model.eval()
    return main_model
# Build the classifier once at module load; the Gradio handler reuses it.
model = get_model("BEATs")
# Class-index -> pitch-label lookup from the headerless vocabulary CSV:
# column 0 becomes the dict key, column 1 the value. Keys are presumably
# string class indices (they are looked up via str(...)) — TODO confirm.
label_mapping = pd.read_csv("resources/labelvocabulary.csv", header=None, index_col=0).to_dict()[1]
def apply_sepia(input_audio):
    """Classify the pitch of a recorded audio clip.

    NOTE: despite the name (left over from the Gradio sepia-filter
    template), this does NOT apply a sepia effect — it runs the pitch
    classifier. The name is kept because the Interface below binds to it.

    Args:
        input_audio: Gradio audio value, a ``(sample_rate, samples)``
            tuple where ``samples`` is a numpy array.

    Returns:
        The predicted pitch label as an int.
    """
    # NOTE(review): samples may arrive as int16 from the browser and are not
    # rescaled to [-1, 1] here — confirm the model expects raw amplitudes.
    waveform = torch.from_numpy(input_audio[1]).float()
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        logits = model(waveform.unsqueeze(0))
    predicted = logits.cpu().numpy()
    predicted = np.argmax(predicted, axis=1)
    return int(label_mapping[str(predicted.item())])
# Wire the classifier into a minimal Gradio UI: a short audio clip in,
# the predicted pitch (as a number) out.
demo = gr.Interface(
    fn=apply_sepia,
    inputs=gr.Audio(max_length=4),
    outputs="number",
    title="NSynth Pitch Classification",
)
demo.launch()