from statistics import mean

import torch
from torch import nn
import numpy as np
import librosa
from piano_transcription_inference import PianoTranscription, sample_rate, load_audio
import pretty_midi
from utils import prediction2label
from model import AudioModel
from scipy.signal import resample
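
# This script estimates the performance difficulty of a piano recording (mp3).
# Two representations are derived from the audio: a log-CQT spectrogram (5 fps)
# and a piano roll transcribed with piano_transcription_inference (frames + onsets, 5 fps).
# An AudioModel with 11 difficulty classes is evaluated with five checkpoints per model
# and the per-checkpoint predictions are averaged.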


def downsample_log_cqt(cqt_matrix, target_fs=5):
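    """Resample a (time, bins) log-CQT from its native frame rate to target_fs.

    The CQT below is computed at 44.1 kHz with a 160-sample hop, i.e. ~275.6 frames
    per second, and is reduced here to 5 fps by default.
    """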
    original_fs = 44100 / 160
    ratio = original_fs / target_fs
    downsampled = resample(cqt_matrix, int(cqt_matrix.shape[0] / ratio), axis=0)
    return downsampled

def downsample_matrix(mat, original_fs, target_fs):
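    """Resample a (time, features) matrix from original_fs to target_fs along the time axis."""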
    ratio = original_fs / target_fs
    return resample(mat, int(mat.shape[0] / ratio), axis=0)

def get_cqt_from_mp3(mp3_path):
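    """Compute a log-magnitude CQT of the mp3, shaped for the model.

    88 bins at 12 bins per octave (semitone resolution) with a 160-sample hop at
    44.1 kHz, then downsampled along time to 5 frames per second.
    Returns a float32 CPU tensor of shape (1, 1, T, 88).
    """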
    sample_rate = 44100
    hop_length = 160
    y, sr = librosa.load(mp3_path, sr=sample_rate, mono=True)
    cqt = librosa.cqt(y, sr=sr, hop_length=hop_length, n_bins=88, bins_per_octave=12)
    log_cqt = librosa.amplitude_to_db(np.abs(cqt))
    log_cqt = log_cqt.T  # shape (T, 88)
    log_cqt = downsample_log_cqt(log_cqt, target_fs=5)
    cqt_tensor = torch.tensor(log_cqt, dtype=torch.float32).unsqueeze(0).unsqueeze(0).cpu()
    print(f"cqt shape: {log_cqt.shape}")
    return cqt_tensor

def get_pianoroll_from_mp3(mp3_path):
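    """Transcribe the mp3 to MIDI and build a two-channel piano-roll tensor.

    Uses piano_transcription_inference to write a temporary MIDI file, then renders
    a frame roll (velocities scaled by 1/127) and a binary onset roll, both at 5 fps
    over the 88 piano pitches (MIDI 21-108).
    Returns a float32 CPU tensor of shape (1, 2, T, 88).
    """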
    audio, _ = load_audio(mp3_path, sr=sample_rate, mono=True)
    transcriptor = PianoTranscription(device="cuda" if torch.cuda.is_available() else "cpu")
    midi_path = "temp.mid"
    transcriptor.transcribe(audio, midi_path)
    midi_data = pretty_midi.PrettyMIDI(midi_path)

    fs = 5  # frames per second for the piano roll (same 5 fps rate as the downsampled CQT)
    piano_roll = midi_data.get_piano_roll(fs=fs)[21:109].T  # shape: (T, 88)
    piano_roll = piano_roll / 127
    time_steps = piano_roll.shape[0]

    onsets = np.zeros_like(piano_roll)
    for instrument in midi_data.instruments:
        for note in instrument.notes:
            pitch = note.pitch - 21
            onset_frame = int(note.start * fs)
            if 0 <= pitch < 88 and onset_frame < time_steps:
                onsets[onset_frame, pitch] = 1.0

    pr_tensor = torch.tensor(piano_roll.T).unsqueeze(0).unsqueeze(1).cpu().float()
    on_tensor = torch.tensor(onsets.T).unsqueeze(0).unsqueeze(1).cpu().float()
    out_tensor = torch.cat([pr_tensor, on_tensor], dim=1)
    print(f"piano_roll shape: {out_tensor.shape}")
    return out_tensor.transpose(2, 3)

def predict_difficulty(mp3_path, model_name, rep):
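    """Estimate the difficulty of one mp3 with an ensemble of five checkpoints.

    rep selects the input representation: "cqt5", "pianoroll5", "multimodal5",
    or "only_cqt"/"only_pr" to run the multimodal checkpoints on a single modality.
    Returns the mean of the five per-checkpoint predictions (integer class labels, 11 classes).
    """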
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if "only_cqt" in rep:
        only_cqt, only_pr = True, False
        rep_clean = "multimodal5"
    elif "only_pr" in rep:
        only_cqt, only_pr = False, True
        rep_clean = "multimodal5"
    else:
        only_cqt = only_pr = False
        rep_clean = rep

    model = AudioModel(num_classes=11, rep=rep_clean, modality_dropout=False, only_cqt=only_cqt, only_pr=only_pr).to(device)
    checkpoint = [torch.load(f"models/{model_name}/checkpoint_{i}.pth", map_location=device, weights_only=False)
                  for i in range(5)]

    if rep == "cqt5":
        inp_data = get_cqt_from_mp3(mp3_path).to(device)
    elif rep == "pianoroll5":
        inp_data = get_pianoroll_from_mp3(mp3_path).to(device)
    elif rep_clean == "multimodal5":
        x1 = get_pianoroll_from_mp3(mp3_path).to(device)
        x2 = get_cqt_from_mp3(mp3_path).to(device)
        inp_data = [x1, x2]
    else:
        raise ValueError(f"Representation {rep} not supported")

    preds = []
    for ckpt in checkpoint:
        model.load_state_dict(ckpt["model_state_dict"])
        model.eval()
        with torch.inference_mode():
            logits = model(inp_data, None)
            pred = prediction2label(logits).item()
            preds.append(pred)

    return mean(preds)

if __name__ == "__main__":
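    # Checkpoints are expected under models/<model_name>/checkpoint_{0..4}.pth
    # relative to the working directory.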
    mp3_path = "yt_audio.mp3"
    model_name = "audio_midi_multi_ps_v5"
    pred_multi = predict_difficulty(mp3_path, model_name=model_name, rep="multimodal5")
    print(f"Multimodal: {pred_multi}")

    model_name = "audio_midi_pianoroll_ps_5_v4"
    pred_pr = predict_difficulty(mp3_path, model_name=model_name, rep="pianoroll5")
    print(f"Pianoroll: {pred_pr}")

    # CQT-only prediction: reuses the multimodal checkpoints but feeds only the CQT branch
    # ("only_cqt" maps to the multimodal architecture with only_cqt=True).
    model_name = "audio_midi_multi_ps_v5"
    pred_cqt = predict_difficulty(mp3_path, model_name=model_name, rep="only_cqt")
    print(f"CQT: {pred_cqt}")