"""Predict the difficulty of a piano recording with an ensemble of five
trained checkpoints, using CQT, piano-roll, or multimodal input
representations."""
from statistics import mean

import librosa
import numpy as np
import pretty_midi
import torch
from piano_transcription_inference import PianoTranscription, load_audio, sample_rate
from scipy.signal import resample

from model import AudioModel
from utils import prediction2label


def downsample_matrix(mat, original_fs, target_fs):
    """Resample a (time, features) matrix from original_fs to target_fs frames per second."""
    ratio = original_fs / target_fs
    return resample(mat, int(mat.shape[0] / ratio), axis=0)


def downsample_log_cqt(cqt_matrix, target_fs=5):
    # The CQT below is computed at 44100 Hz with a hop of 160 samples,
    # i.e. 44100 / 160 = 275.625 frames per second.
    return downsample_matrix(cqt_matrix, original_fs=44100 / 160, target_fs=target_fs)


def get_cqt_from_mp3(mp3_path):
    """Return a log-magnitude CQT tensor of shape (1, 1, T, 88) at 5 frames per second."""
    cqt_sr = 44100  # analysis rate for the CQT; the transcription model uses its own `sample_rate`
    hop_length = 160
    y, sr = librosa.load(mp3_path, sr=cqt_sr, mono=True)
    cqt = librosa.cqt(y, sr=sr, hop_length=hop_length, n_bins=88, bins_per_octave=12)
    log_cqt = librosa.amplitude_to_db(np.abs(cqt)).T  # shape (T, 88)
    log_cqt = downsample_log_cqt(log_cqt, target_fs=5)
    cqt_tensor = torch.tensor(log_cqt, dtype=torch.float32).unsqueeze(0).unsqueeze(0).cpu()
    print(f"cqt shape: {log_cqt.shape}")
    return cqt_tensor


def get_pianoroll_from_mp3(mp3_path):
    """Transcribe the audio to MIDI and return stacked piano-roll and onset
    channels as a tensor of shape (1, 2, T, 88) at 5 frames per second."""
    audio, _ = load_audio(mp3_path, sr=sample_rate, mono=True)
    transcriptor = PianoTranscription(device="cuda" if torch.cuda.is_available() else "cpu")
    midi_path = "temp.mid"
    transcriptor.transcribe(audio, midi_path)
    midi_data = pretty_midi.PrettyMIDI(midi_path)

    fs = 5  # frames per second of the piano roll
    # Keep the 88 piano keys (MIDI pitches 21-108) and normalise velocities to [0, 1].
    piano_roll = midi_data.get_piano_roll(fs=fs)[21:109].T  # shape (T, 88)
    piano_roll = piano_roll / 127
    time_steps = piano_roll.shape[0]

    # Binary onset map: 1.0 at the frame where each note starts.
    onsets = np.zeros_like(piano_roll)
    for instrument in midi_data.instruments:
        for note in instrument.notes:
            pitch = note.pitch - 21
            onset_frame = int(note.start * fs)
            if 0 <= pitch < 88 and onset_frame < time_steps:
                onsets[onset_frame, pitch] = 1.0

    pr_tensor = torch.tensor(piano_roll.T).unsqueeze(0).unsqueeze(1).cpu().float()
    on_tensor = torch.tensor(onsets.T).unsqueeze(0).unsqueeze(1).cpu().float()
    out_tensor = torch.cat([pr_tensor, on_tensor], dim=1)  # (1, 2, 88, T)
    print(f"piano_roll shape: {out_tensor.shape}")
    return out_tensor.transpose(2, 3)  # (1, 2, T, 88)


def predict_difficulty(mp3_path, model_name, rep):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # "only_cqt"/"only_pr" rep strings select the multimodal architecture
    # with a single active modality.
    if "only_cqt" in rep:
        only_cqt, only_pr = True, False
        rep_clean = "multimodal5"
    elif "only_pr" in rep:
        only_cqt, only_pr = False, True
        rep_clean = "multimodal5"
    else:
        only_cqt = only_pr = False
        rep_clean = rep

    model = AudioModel(
        num_classes=11,
        rep=rep_clean,
        modality_dropout=False,
        only_cqt=only_cqt,
        only_pr=only_pr,
    ).to(device)

    checkpoints = [
        torch.load(f"models/{model_name}/checkpoint_{i}.pth", map_location=device, weights_only=False)
        for i in range(5)
    ]

    if rep == "cqt5":
        inp_data = get_cqt_from_mp3(mp3_path).to(device)
    elif rep == "pianoroll5":
        inp_data = get_pianoroll_from_mp3(mp3_path).to(device)
    elif rep_clean == "multimodal5":
        x1 = get_pianoroll_from_mp3(mp3_path).to(device)
        x2 = get_cqt_from_mp3(mp3_path).to(device)
        inp_data = [x1, x2]
    else:
        raise ValueError(f"Representation {rep} not supported")

    # Average the per-checkpoint predictions into a single difficulty score.
    preds = []
    for ckpt in checkpoints:
        model.load_state_dict(ckpt["model_state_dict"])
        model.eval()
        with torch.inference_mode():
            logits = model(inp_data, None)
            preds.append(prediction2label(logits).item())
    return mean(preds)
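

# Checkpoint layout expected by predict_difficulty (as exercised below):
#   models/<model_name>/checkpoint_0.pth ... checkpoint_4.pth
# where each file is a dict containing at least a "model_state_dict" entry.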
"audio_midi_pianoroll_ps_5_v4" pred_multi = predict_difficulty(mp3_path, model_name=model_name, rep="pianoroll5") print(f"Pianoroll: {pred_multi}") model_name = "audio_midi_multi_ps_v5" pred_multi = predict_difficulty(mp3_path, model_name=model_name, rep="pianoroll5") print(f"CQT: {pred_multi}")