"""Gradio demo: predict the genre distribution of an uploaded audio clip with CNN/CRNN models trained on mel-spectrogram slices."""

import time

import gradio as gr
import librosa
import numpy as np
import torch

from src.models import MusicCNN, MusicCRNN2D
from src.utility import slice_songs

# Configuration
DESIRED_SR = 22050
HOP_LENGTH = 512
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
NUM_CLASSES = 10

GENRE_LABELS = ["Blues", "Classical", "Country", "Disco", "HipHop",
                "Jazz", "Metal", "Pop", "Reggae", "Rock"]

# Model loading
cnn_model_paths = {
    1: "models/cnn/1s.pth",
    3: "models/cnn/3s.pth",
    5: "models/cnn/5s.pth",
    10: "models/cnn/10s.pth",
}
crnn_model_paths = {
    1: "models/crnn/1s.pth",
    3: "models/crnn/3s.pth",
    5: "models/crnn/5s.pth",
    10: "models/crnn/10s.pth",
}


def get_frames(slice_length):
    """Number of spectrogram frames in a slice of the given length (seconds)."""
    return int(slice_length * DESIRED_SR / HOP_LENGTH)


def load_models(model_cls, model_paths):
    """Instantiate one model per slice length and load its trained weights."""
    loaded = {}
    for slice_len, path in model_paths.items():
        model = model_cls(num_classes=NUM_CLASSES, device=DEVICE)
        # Run a dummy forward pass so the model is fully built before loading the checkpoint.
        dummy_input = torch.randn(2, 1, 128, get_frames(slice_len)).to(DEVICE)
        _ = model(dummy_input)
        model.load_state_dict(torch.load(path, map_location=DEVICE))
        model.to(DEVICE)
        model.eval()
        loaded[slice_len] = model
    return loaded


models = {
    "cnn": load_models(MusicCNN, cnn_model_paths),
    "crnn": load_models(MusicCRNN2D, crnn_model_paths),
}


def predict_genre(audio_file, slice_length, architecture):
    """Predict a genre distribution for an audio file.

    The clip is converted to a mel spectrogram, split into fixed-length slices,
    classified slice by slice, and the per-slice probabilities are averaged.
    """
    slice_length = int(slice_length)
    start_time = time.time()

    y, sr = librosa.load(audio_file, sr=DESIRED_SR)

    # Zero-pad the signal up to a whole number of seconds.
    target_length = int(np.ceil(len(y) / sr)) * sr
    if len(y) < target_length:
        y = np.pad(y, (0, target_length - len(y)), mode="constant")

    mel_spectrogram = librosa.feature.melspectrogram(
        y=y, sr=sr, n_fft=2048, hop_length=HOP_LENGTH, n_mels=128)
    mel_spectrogram_db = librosa.power_to_db(mel_spectrogram, ref=np.max)

    # Min-max normalize to [0, 1]; fall back to the raw dB values for a constant spectrogram.
    min_val, max_val = np.min(mel_spectrogram_db), np.max(mel_spectrogram_db)
    if max_val - min_val > 0:
        normalized_spectrogram = (mel_spectrogram_db - min_val) / (max_val - min_val)
    else:
        normalized_spectrogram = mel_spectrogram_db

    # Slice the spectrogram into fixed-length windows and add a channel dimension.
    X_slices, _, _ = slice_songs([normalized_spectrogram], [0], ["temp"],
                                 sr=sr, hop_length=HOP_LENGTH,
                                 length_in_seconds=slice_length)
    X_slices = torch.tensor(X_slices, dtype=torch.float32).unsqueeze(1).to(DEVICE)

    model_used = models[architecture][slice_length]
    with torch.no_grad():
        outputs = model_used(X_slices)
        probabilities = torch.nn.functional.softmax(outputs, dim=1).cpu().numpy()

    # Average the slice-level probabilities into a single distribution over genres.
    avg_probs = np.mean(probabilities, axis=0)
    genre_distribution = {GENRE_LABELS[i]: float(avg_probs[i]) for i in range(NUM_CLASSES)}

    inference_time = time.time() - start_time
    return genre_distribution, f"Inference Time: {inference_time:.2f} seconds"


slice_length_dropdown = gr.Dropdown(choices=["1", "3", "5", "10"], value="1",
                                    label="Slice Length (seconds)")
architecture_dropdown = gr.Dropdown(choices=["cnn", "crnn"], value="cnn",
                                    label="Model Architecture")

demo = gr.Interface(
    fn=predict_genre,
    inputs=[gr.Audio(type="filepath", label="Upload Audio File"),
            slice_length_dropdown,
            architecture_dropdown],
    outputs=[gr.Label(num_top_classes=10, label="Genre Distribution"),
             gr.Textbox(label="Inference Time")],
    title="Music Genre Classifier",
    description="Upload an audio file, select a slice length and model architecture to predict its genre distribution."
)

if __name__ == "__main__":
    demo.launch()
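
# Example usage (a minimal sketch; "samples/blues_clip.wav" is a hypothetical path,
# not part of this repository): bypass the Gradio UI and call predict_genre directly,
# e.g. from a notebook or a quick smoke test.
#
#     distribution, timing = predict_genre("samples/blues_clip.wav", "3", "cnn")
#     print(max(distribution, key=distribution.get), timing)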