import gradio as gr
import torch
import numpy as np
import librosa
import time

from src.utility import slice_songs
from src.models import MusicCNN, MusicCRNN2D

# Configuration
DESIRED_SR = 22050
HOP_LENGTH = 512
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
NUM_CLASSES = 10

# Model loading
models = {"cnn": {}, "crnn": {}}
cnn_model_paths = {1: "models/cnn/1s.pth", 3: "models/cnn/3s.pth", 5: "models/cnn/5s.pth", 10: "models/cnn/10s.pth"}
crnn_model_paths = {1: "models/crnn/1s.pth", 3: "models/crnn/3s.pth", 5: "models/crnn/5s.pth",
                    10: "models/crnn/10s.pth"}
def get_frames(slice_length):
    return int(slice_length * DESIRED_SR / HOP_LENGTH)


# Load CNN models, one per supported slice length
for slice_len, path in cnn_model_paths.items():
    model = MusicCNN(num_classes=NUM_CLASSES, device=DEVICE)
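    # Warm-up forward pass before load_state_dict: the model appears to build
    # some layers lazily from the first input's shape (assumed from this load order).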
    dummy_input = torch.randn(2, 1, 128, get_frames(slice_len)).to(DEVICE)
    _ = model(dummy_input)
    model.load_state_dict(torch.load(path, map_location=DEVICE))
    model.to(DEVICE)
    model.eval()
    models["cnn"][slice_len] = model

# Load CRNN models, one per supported slice length
for slice_len, path in crnn_model_paths.items():
    model = MusicCRNN2D(num_classes=NUM_CLASSES, device=DEVICE)
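    # Same warm-up as above: run a dummy batch so lazy layers exist before loading weights.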
    dummy_input = torch.randn(2, 1, 128, get_frames(slice_len)).to(DEVICE)
    _ = model(dummy_input)
    model.load_state_dict(torch.load(path, map_location=DEVICE))
    model.to(DEVICE)
    model.eval()
    models["crnn"][slice_len] = model

GENRE_LABELS = ["Blues", "Classical", "Country", "Disco", "HipHop", "Jazz", "Metal", "Pop", "Reggae", "Rock"]


def predict_genre(audio_file, slice_length, architecture):
    slice_length = int(slice_length)
    start_time = time.time()
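    # Decode the upload and resample to DESIRED_SR so frame counts match get_frames().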
    y, sr = librosa.load(audio_file, sr=DESIRED_SR)
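    # Pad with trailing silence up to a whole number of seconds so the clip
    # divides evenly into fixed-length slices.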
    target_length = int(np.ceil(len(y) / sr)) * sr
    if len(y) < target_length:
        y = np.pad(y, (0, target_length - len(y)), mode='constant')
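    # Compute a 128-band log-mel spectrogram and min-max scale it to [0, 1];
    # the guard below avoids dividing by zero on a constant (e.g. silent) input.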
    mel_spectrogram = librosa.feature.melspectrogram(y=y, sr=sr, n_fft=2048, hop_length=HOP_LENGTH, n_mels=128)
    mel_spectrogram_db = librosa.power_to_db(mel_spectrogram, ref=np.max)
    min_val, max_val = np.min(mel_spectrogram_db), np.max(mel_spectrogram_db)
    if max_val - min_val > 0:
        normalized_spectrogram = (mel_spectrogram_db - min_val) / (max_val - min_val)
    else:
        normalized_spectrogram = mel_spectrogram_db
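    # slice_songs (from src.utility) is assumed, from its use here, to cut the
    # spectrogram into windows of slice_length seconds; only the slices are needed.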
    X_slices, _, _ = slice_songs([normalized_spectrogram], [0], ["temp"], sr=sr, hop_length=HOP_LENGTH,
                                 length_in_seconds=slice_length)
    X_slices = torch.tensor(X_slices, dtype=torch.float32).unsqueeze(1).to(DEVICE)
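    # Classify all slices as one batch, then average the per-slice softmax
    # distributions into a single genre distribution for the whole clip.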
    model_used = models[architecture][slice_length]
    with torch.no_grad():
        outputs = model_used(X_slices)
        probabilities = torch.nn.functional.softmax(outputs, dim=1).cpu().numpy()
    avg_probs = np.mean(probabilities, axis=0)
    genre_distribution = {GENRE_LABELS[i]: float(avg_probs[i]) for i in range(NUM_CLASSES)}
    inference_time = time.time() - start_time
    return genre_distribution, f"Inference Time: {inference_time:.2f} seconds"


slice_length_dropdown = gr.Dropdown(choices=["1", "3", "5", "10"], value="1", label="Slice Length (seconds)")
architecture_dropdown = gr.Dropdown(choices=["cnn", "crnn"], value="cnn", label="Model Architecture")
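
# gr.Interface passes the inputs positionally to predict_genre(audio_file, slice_length, architecture).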
demo = gr.Interface(
    fn=predict_genre,
    inputs=[gr.Audio(type="filepath", label="Upload Audio File"), slice_length_dropdown, architecture_dropdown],
    outputs=[gr.Label(num_top_classes=10, label="Genre Distribution"), gr.Textbox(label="Inference Time")],
    title="Music Genre Classifier",
    description="Upload an audio file, then select a slice length and model architecture to predict its genre distribution."
)

if __name__ == "__main__":
    demo.launch()