File size: 3,701 Bytes
af51a98
 
 
 
 
f6d8d83
af51a98
 
 
 
 
 
 
 
f6d8d83
af51a98
f6d8d83
af51a98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import gradio as gr
import torch
import numpy as np
import librosa
import time
from src.utility import slice_songs
from src.models import MusicCNN, MusicCRNN2D

# Configuration
DESIRED_SR = 22050  # every upload is resampled to this rate in predict_genre
HOP_LENGTH = 512  # STFT hop shared by the mel transform and get_frames
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
NUM_CLASSES = 10  # one model output per entry of GENRE_LABELS

# Model loading: architecture -> {slice length in seconds -> model}, populated at import time.
models = {arch: {} for arch in ("cnn", "crnn")}

# One checkpoint per supported slice length, stored as models/<arch>/<N>s.pth.
cnn_model_paths = {seconds: f"models/cnn/{seconds}s.pth" for seconds in (1, 3, 5, 10)}
crnn_model_paths = {seconds: f"models/crnn/{seconds}s.pth" for seconds in (1, 3, 5, 10)}


def get_frames(slice_length):
    """Return the number of spectrogram frames that span ``slice_length`` seconds.

    The count is derived from the module-level ``DESIRED_SR`` and ``HOP_LENGTH``;
    any fractional frame is dropped (``int`` truncates toward zero).
    """
    exact_frames = slice_length * DESIRED_SR / HOP_LENGTH
    return int(exact_frames)


def _load_checkpoints(model_cls, checkpoint_paths):
    """Build, warm up and load one ``model_cls`` instance per slice length.

    Args:
        model_cls: Model class, constructed as ``model_cls(num_classes=NUM_CLASSES, device=DEVICE)``.
        checkpoint_paths: Mapping of slice length in seconds -> state-dict file path.

    Returns:
        Dict mapping each slice length to its loaded model, moved to ``DEVICE`` and
        switched to eval mode, in the iteration order of ``checkpoint_paths``.
    """
    loaded = {}
    for slice_len, path in checkpoint_paths.items():
        model = model_cls(num_classes=NUM_CLASSES, device=DEVICE)
        # A dummy forward pass is run before load_state_dict. Presumably the model
        # creates or sizes some layers from the input width (which depends on
        # slice_len), so they must exist before the weights load — TODO confirm in
        # src.models. NOTE(review): the batch of 2 may be there so BatchNorm accepts
        # the input in training mode; kept unchanged.
        dummy_input = torch.randn(2, 1, 128, get_frames(slice_len)).to(DEVICE)
        with torch.no_grad():  # warm-up only; no autograd graph is needed
            model(dummy_input)
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        model.load_state_dict(torch.load(path, map_location=DEVICE))
        model.to(DEVICE)
        model.eval()
        loaded[slice_len] = model
    return loaded


# Load cnn models, then crnn models, into the shared registry.
models["cnn"].update(_load_checkpoints(MusicCNN, cnn_model_paths))
models["crnn"].update(_load_checkpoints(MusicCRNN2D, crnn_model_paths))

# predict_genre reports model output index i as GENRE_LABELS[i]; the order presumably
# matches the label encoding used in training — confirm before reordering.
GENRE_LABELS = ["Blues", "Classical", "Country", "Disco", "HipHop", "Jazz", "Metal", "Pop", "Reggae", "Rock"]


def _normalized_mel_db(y, sr):
    """Return the 128-band mel spectrogram of ``y`` in dB, min-max scaled to [0, 1].

    If the dB spectrogram has zero dynamic range it is returned unscaled, which
    avoids dividing by zero.
    """
    mel_spectrogram = librosa.feature.melspectrogram(y=y, sr=sr, n_fft=2048, hop_length=HOP_LENGTH, n_mels=128)
    mel_spectrogram_db = librosa.power_to_db(mel_spectrogram, ref=np.max)
    min_val, max_val = np.min(mel_spectrogram_db), np.max(mel_spectrogram_db)
    value_range = max_val - min_val
    if value_range > 0:
        return (mel_spectrogram_db - min_val) / value_range
    return mel_spectrogram_db


def predict_genre(audio_file, slice_length, architecture):
    """Predict the genre distribution of an uploaded audio clip.

    The clip is resampled to ``DESIRED_SR`` and turned into a normalized mel
    spectrogram. ``slice_songs`` cuts it into ``slice_length``-second slices, each
    slice is classified, and the per-slice softmax outputs are averaged.

    Args:
        audio_file: Path of the uploaded file (from ``gr.Audio(type="filepath")``),
            or ``None`` when nothing was uploaded.
        slice_length: Slice length in seconds; the dropdown passes it as a string.
        architecture: Key into ``models``, e.g. ``"cnn"`` or ``"crnn"``.

    Returns:
        A ``genre -> mean probability`` dict for ``gr.Label`` and a text message
        with the elapsed time. The time covers decoding, feature extraction and
        the forward pass, not the model alone.

    Raises:
        gr.Error: If no file was uploaded, the audio has no samples, or no slice
            could be produced.
        KeyError: If no model is loaded for ``(architecture, slice_length)``.
    """
    if audio_file is None:
        raise gr.Error("Please upload an audio file.")
    slice_length = int(slice_length)
    start_time = time.perf_counter()  # monotonic clock intended for interval timing

    y, sr = librosa.load(audio_file, sr=DESIRED_SR)
    if len(y) == 0:
        raise gr.Error("The uploaded audio file contains no samples.")
    # Zero-pad to the next whole second, and to at least one full slice so that
    # clips shorter than slice_length still produce a slice to classify.
    target_length = max(int(np.ceil(len(y) / sr)), slice_length) * sr
    if len(y) < target_length:
        y = np.pad(y, (0, target_length - len(y)), mode='constant')

    normalized_spectrogram = _normalized_mel_db(y, sr)
    X_slices, _, _ = slice_songs([normalized_spectrogram], [0], ["temp"], sr=sr, hop_length=HOP_LENGTH,
                                 length_in_seconds=slice_length)
    if len(X_slices) == 0:
        # Defensive: averaging over zero slices would return NaN probabilities.
        raise gr.Error("The audio is too short for the selected slice length.")
    # Stack into one float32 array first; torch.tensor on a list of ndarrays is very slow.
    X_slices = torch.as_tensor(np.asarray(X_slices, dtype=np.float32)).unsqueeze(1).to(DEVICE)

    model_used = models[architecture][slice_length]
    with torch.no_grad():
        outputs = model_used(X_slices)
        probabilities = torch.nn.functional.softmax(outputs, dim=1).cpu().numpy()

    avg_probs = np.mean(probabilities, axis=0)
    genre_distribution = {GENRE_LABELS[i]: float(avg_probs[i]) for i in range(NUM_CLASSES)}
    inference_time = time.perf_counter() - start_time
    return genre_distribution, f"Inference Time: {inference_time:.2f} seconds"


# Offer only the slice lengths loaded for every architecture, so each UI choice maps
# to an entry of ``models`` that predict_genre can look up (no KeyError at request time).
slice_length_choices = [str(n) for n in sorted(set.intersection(*(set(by_len) for by_len in models.values())))]
slice_length_dropdown = gr.Dropdown(choices=slice_length_choices, value=slice_length_choices[0],
                                    label="Slice Length (seconds)")
architecture_dropdown = gr.Dropdown(choices=list(models), value="cnn", label="Model Architecture")

demo = gr.Interface(
    fn=predict_genre,
    inputs=[gr.Audio(type="filepath", label="Upload Audio File"), slice_length_dropdown, architecture_dropdown],
    # Show every class rather than a hard-coded count that could drift from NUM_CLASSES.
    outputs=[gr.Label(num_top_classes=NUM_CLASSES, label="Genre Distribution"), gr.Textbox(label="Inference Time")],
    title="Music Genre Classifier",
    description="Upload an audio file, select a slice length and model architecture to predict its genre distribution."
)

if __name__ == "__main__":
    demo.launch()