import numpy as np
import pandas as pd
import torch
from torch import nn
import gradio as gr

from models.atstframe.ATSTF_wrapper import ATSTWrapper
from models.beats.BEATs_wrapper import BEATsWrapper
from models.frame_passt.fpasst_wrapper import FPaSSTWrapper
from models.m2d.M2D_wrapper import M2DWrapper
from models.asit.ASIT_wrapper import ASiTWrapper
from models.frame_mn.Frame_MN_wrapper import FrameMNWrapper
from models.frame_mn.utils import NAME_TO_WIDTH
from models.prediction_wrapper import PredictionsWrapper

class TransformerClassifier(nn.Module):
    """Wraps a pretrained audio embedding model with a linear classification head."""

    def __init__(self, model, n_classes):
        super().__init__()
        self.model = model
        self.linear = nn.Linear(model.embed_dim, n_classes)

    def forward(self, x):
        mel = self.model.mel_forward(x)        # waveform -> mel spectrogram
        features = self.model(mel).squeeze(1)  # pooled backbone features
        return self.linear(features)           # class logits
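
# Shape sketch (an assumption inferred from the squeeze(1) above, not verified against
# the wrapper implementations): x is a waveform batch [B, T]; the backbone yields
# [B, 1, embed_dim] features, squeezed to [B, embed_dim] before the linear head.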


def get_model(model_name):
    """Build the requested backbone, attach a classification head, and load its fine-tuned weights."""
    if model_name == "BEATs":
        beats = BEATsWrapper()
        model = PredictionsWrapper(beats, checkpoint=None, head_type=None, seq_len=1)
    elif model_name == "ATST-F":
        atst = ATSTWrapper()
        model = PredictionsWrapper(atst, checkpoint=None, head_type=None, seq_len=1)
    elif model_name == "fpasst":
        fpasst = FPaSSTWrapper()
        model = PredictionsWrapper(fpasst, checkpoint=None, head_type=None, seq_len=1)
    elif model_name == "M2D":
        m2d = M2DWrapper()
        model = PredictionsWrapper(m2d, checkpoint=None, head_type=None, seq_len=1,
                                    embed_dim=m2d.m2d.cfg.feature_d)
    elif model_name == "ASIT":
        asit = ASiTWrapper()
        model = PredictionsWrapper(asit, checkpoint=None, head_type=None, seq_len=1)
    elif model_name.startswith("frame_mn"):
        width = NAME_TO_WIDTH(model_name)
        frame_mn = FrameMNWrapper(width)
        embed_dim = frame_mn.state_dict()['frame_mn.features.16.1.bias'].shape[0]
        model = PredictionsWrapper(frame_mn, checkpoint=None, head_type=None, seq_len=1, embed_dim=embed_dim)
    else:
        raise NotImplementedError(f"Model {model_name} not (yet) implemented")
    main_model = TransformerClassifier(model, n_classes=88)  # 88 pitch classes (piano range)
    main_model.load_state_dict(torch.load(f"resources/best_model_{model_name}.pth", map_location="cpu"))
    main_model.eval()
    return main_model
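
# Usage sketch (hedged: assumes the matching checkpoint ships under resources/).
# The branches above accept "BEATs", "ATST-F", "fpasst", "M2D", "ASIT", and names
# starting with "frame_mn", e.g.
#
#     classifier = get_model("ATST-F")  # loads resources/best_model_ATST-F.pth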

model = get_model("BEATs")
# labelvocabulary.csv maps class indices (column 0) to pitch labels (column 1).
label_mapping = pd.read_csv("resources/labelvocabulary.csv", header=None, index_col=0).to_dict()[1]

def classify_pitch(input_audio):
    # Gradio passes audio as a (sample_rate, samples) tuple; index 1 holds the numpy samples.
    waveform = torch.from_numpy(input_audio[1]).float()
    with torch.no_grad():
        output = model(waveform.unsqueeze(0))  # add a batch dimension
    output = output.cpu().numpy()
    output = np.argmax(output, axis=1)  # most likely class index
    return int(label_mapping[str(output.item())])
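
# Local sanity check, a sketch rather than part of the original app: feed a synthetic
# 440 Hz tone in the (sample_rate, int16 array) format Gradio produces. The 16 kHz
# rate and int16 scaling are assumptions, not confirmed by this file.
#
#     sr = 16000
#     t = np.linspace(0, 1, sr, endpoint=False)
#     tone = (np.sin(2 * np.pi * 440 * t) * 32767).astype(np.int16)
#     print(classify_pitch((sr, tone)))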



demo = gr.Interface(
    fn=classify_pitch,
    inputs=gr.Audio(max_length=4),
    outputs="number",
    title="NSynth Pitch Classification",
)
demo.launch()