# MusicGenrePulse / app.py
import gradio as gr
import torch
import numpy as np
import librosa
import time
from src.utility import slice_songs
from src.models import MusicCNN, MusicCRNN2D
# Configuration
DESIRED_SR = 22050  # target sample rate for librosa.load
HOP_LENGTH = 512    # hop length used for the mel spectrogram
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
NUM_CLASSES = 10    # one per entry in GENRE_LABELS below
# Model checkpoints, keyed by slice length in seconds
models = {"cnn": {}, "crnn": {}}
cnn_model_paths = {1: "models/cnn/1s.pth", 3: "models/cnn/3s.pth", 5: "models/cnn/5s.pth", 10: "models/cnn/10s.pth"}
crnn_model_paths = {1: "models/crnn/1s.pth", 3: "models/crnn/3s.pth", 5: "models/crnn/5s.pth", 10: "models/crnn/10s.pth"}
def get_frames(slice_length):
    """Number of spectrogram frames covered by a slice of the given length."""
    return int(slice_length * DESIRED_SR / HOP_LENGTH)
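# e.g. slice_length=3 -> int(3 * 22050 / 512) = int(129.2) = 129 frames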
# Load all checkpoints up front. A dummy forward pass is run before loading
# the weights so that any lazily-built layers take their final shapes before
# load_state_dict is called.
for arch, model_cls, paths in [("cnn", MusicCNN, cnn_model_paths),
                               ("crnn", MusicCRNN2D, crnn_model_paths)]:
    for slice_len, path in paths.items():
        model = model_cls(num_classes=NUM_CLASSES, device=DEVICE)
        dummy_input = torch.randn(2, 1, 128, get_frames(slice_len)).to(DEVICE)
        _ = model(dummy_input)
        model.load_state_dict(torch.load(path, map_location=DEVICE))
        model.to(DEVICE)
        model.eval()
        models[arch][slice_len] = model
GENRE_LABELS = ["Blues", "Classical", "Country", "Disco", "HipHop", "Jazz", "Metal", "Pop", "Reggae", "Rock"]
def predict_genre(audio_file, slice_length, architecture):
    if audio_file is None:
        raise gr.Error("Please upload an audio file first.")
    slice_length = int(slice_length)  # dropdown values arrive as strings
    start_time = time.time()

    # Load and resample the audio, then zero-pad it to a whole number of
    # seconds so it can be sliced evenly.
    y, sr = librosa.load(audio_file, sr=DESIRED_SR)
    target_length = int(np.ceil(len(y) / sr)) * sr
    if len(y) < target_length:
        y = np.pad(y, (0, target_length - len(y)), mode='constant')

    # Log-scaled mel spectrogram, min-max normalized to [0, 1].
    mel_spectrogram = librosa.feature.melspectrogram(y=y, sr=sr, n_fft=2048, hop_length=HOP_LENGTH, n_mels=128)
    mel_spectrogram_db = librosa.power_to_db(mel_spectrogram, ref=np.max)
    min_val, max_val = np.min(mel_spectrogram_db), np.max(mel_spectrogram_db)
    if max_val - min_val > 0:
        normalized_spectrogram = (mel_spectrogram_db - min_val) / (max_val - min_val)
    else:
        normalized_spectrogram = mel_spectrogram_db  # constant spectrogram; skip normalization

    # Cut the spectrogram into fixed-length slices and batch them for the model.
    X_slices, _, _ = slice_songs([normalized_spectrogram], [0], ["temp"], sr=sr, hop_length=HOP_LENGTH,
                                 length_in_seconds=slice_length)
    if len(X_slices) == 0:
        raise gr.Error(f"The audio is too short for {slice_length}-second slices.")
    X_slices = torch.tensor(X_slices, dtype=torch.float32).unsqueeze(1).to(DEVICE)

    # Average the per-slice softmax probabilities into a song-level distribution.
    model_used = models[architecture][slice_length]
    with torch.no_grad():
        outputs = model_used(X_slices)
        probabilities = torch.nn.functional.softmax(outputs, dim=1).cpu().numpy()
    avg_probs = np.mean(probabilities, axis=0)
    genre_distribution = {GENRE_LABELS[i]: float(avg_probs[i]) for i in range(NUM_CLASSES)}

    inference_time = time.time() - start_time
    return genre_distribution, f"Inference Time: {inference_time:.2f} seconds"
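# The function can also be called outside Gradio, e.g. (hypothetical local file):
#   dist, timing = predict_genre("sample.wav", "3", "cnn")
#   top_genre = max(dist, key=dist.get)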
slice_length_dropdown = gr.Dropdown(choices=["1", "3", "5", "10"], value="1", label="Slice Length (seconds)")
architecture_dropdown = gr.Dropdown(choices=["cnn", "crnn"], value="cnn", label="Model Architecture")

demo = gr.Interface(
    fn=predict_genre,
    inputs=[gr.Audio(type="filepath", label="Upload Audio File"), slice_length_dropdown, architecture_dropdown],
    outputs=[gr.Label(num_top_classes=10, label="Genre Distribution"), gr.Textbox(label="Inference Time")],
    title="Music Genre Classifier",
    description="Upload an audio file, then select a slice length and model architecture to predict its genre distribution.",
)
if __name__ == "__main__":
    demo.launch()