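"""Gradio demo: NSynth pitch classification with pretrained audio backbones."""
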
import numpy as np
import pandas as pd
import torch
from torch import nn
import gradio as gr

from models.atstframe.ATSTF_wrapper import ATSTWrapper
from models.beats.BEATs_wrapper import BEATsWrapper
from models.frame_passt.fpasst_wrapper import FPaSSTWrapper
from models.m2d.M2D_wrapper import M2DWrapper
from models.asit.ASIT_wrapper import ASiTWrapper
from models.frame_mn.Frame_MN_wrapper import FrameMNWrapper
from models.frame_mn.utils import NAME_TO_WIDTH
from models.prediction_wrapper import PredictionsWrapper

class TransformerClassifier(nn.Module):
    """Pretrained audio backbone plus a linear classification head."""

    def __init__(self, model, n_classes):
        super().__init__()
        self.model = model
        self.linear = nn.Linear(model.embed_dim, n_classes)

    def forward(self, x):
        # Convert the raw waveform to a mel spectrogram, embed it, and classify.
        mel = self.model.mel_forward(x)
        features = self.model(mel).squeeze(1)
        return self.linear(features)

def get_model(model_name):
    """Build the requested backbone, attach a classification head, and load fine-tuned weights."""
    if model_name == "BEATs":
        beats = BEATsWrapper()
        model = PredictionsWrapper(beats, checkpoint=None, head_type=None, seq_len=1)
    elif model_name == "ATST-F":
        atst = ATSTWrapper()
        model = PredictionsWrapper(atst, checkpoint=None, head_type=None, seq_len=1)
    elif model_name == "fpasst":
        fpasst = FPaSSTWrapper()
        model = PredictionsWrapper(fpasst, checkpoint=None, head_type=None, seq_len=1)
    elif model_name == "M2D":
        m2d = M2DWrapper()
        model = PredictionsWrapper(m2d, checkpoint=None, head_type=None, seq_len=1,
                                   embed_dim=m2d.m2d.cfg.feature_d)
    elif model_name == "ASIT":
        asit = ASiTWrapper()
        model = PredictionsWrapper(asit, checkpoint=None, head_type=None, seq_len=1)
    elif model_name.startswith("frame_mn"):
        width = NAME_TO_WIDTH(model_name)
        frame_mn = FrameMNWrapper(width)
        # Infer the embedding dimension from the last feature layer's bias.
        embed_dim = frame_mn.state_dict()['frame_mn.features.16.1.bias'].shape[0]
        model = PredictionsWrapper(frame_mn, checkpoint=None, head_type=None, seq_len=1, embed_dim=embed_dim)
    else:
        raise NotImplementedError(f"Model {model_name} not (yet) implemented")

    # 88 classes: one per pitch in the NSynth range (MIDI 21-108).
    main_model = TransformerClassifier(model, n_classes=88)
    # main_model.compile()
    main_model.load_state_dict(torch.load(f"resources/best_model_{model_name}.pth", map_location='cpu'))
    print(main_model)
    main_model.eval()
    return main_model
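
# Other backbones supported by get_model() could be swapped in here:
# "ATST-F", "fpasst", "M2D", "ASIT", or a "frame_mn*" variant.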
model = get_model("BEATs")
# Map predicted class indices to their pitch labels.
label_mapping = pd.read_csv("resources/labelvocabulary.csv", header=None, index_col=0).to_dict()[1]

def classify_pitch(input_audio):
    # Gradio passes audio as a (sample_rate, samples) tuple; samples are
    # typically int16, so scale them to [-1, 1] before inference.
    sr, samples = input_audio
    waveform = torch.from_numpy(samples).float()
    if samples.dtype == np.int16:
        waveform = waveform / 32768.0
    with torch.no_grad():
        output = model(waveform.unsqueeze(0))
    output = output.cpu().numpy()
    # Take the most likely class and map it to its pitch label.
    pred = np.argmax(output, axis=1)
    return int(label_mapping[str(pred.item())])
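
# Example invocation with 4 seconds of silence (the 16 kHz rate here is an
# illustrative assumption, not a requirement stated by this file):
#   classify_pitch((16000, np.zeros(4 * 16000, dtype=np.int16)))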
demo = gr.Interface(
    classify_pitch,
    gr.Audio(max_length=4),
    "number",
    title="NSynth Pitch Classification",
)
demo.launch()