PRamoneda committed · Commit df703c7 · Parent: 2e9908b

Commit message: cpu

Files changed:
- get_difficulty.py (+6 -6)
- model.py (+5 -5)
get_difficulty.py CHANGED

@@ -31,14 +31,14 @@ def get_cqt_from_mp3(mp3_path):
     log_cqt = librosa.amplitude_to_db(np.abs(cqt))
     log_cqt = log_cqt.T  # shape (T, 88)
     log_cqt = downsample_log_cqt(log_cqt, target_fs=5)
-    cqt_tensor = torch.tensor(log_cqt, dtype=torch.float32).unsqueeze(0).unsqueeze(0).
+    cqt_tensor = torch.tensor(log_cqt, dtype=torch.float32).unsqueeze(0).unsqueeze(0).cpu()
     # pdb.set_trace()
     print(f"cqt shape: {log_cqt.shape}")
     return cqt_tensor

 def get_pianoroll_from_mp3(mp3_path):
     audio, _ = load_audio(mp3_path, sr=sample_rate, mono=True)
-    transcriptor = PianoTranscription(device='
+    transcriptor = PianoTranscription(device='cpu')
     midi_path = "temp.mid"
     transcriptor.transcribe(audio, midi_path)
     midi_data = pretty_midi.PrettyMIDI(midi_path)

@@ -57,8 +57,8 @@ def get_pianoroll_from_mp3(mp3_path):
         if 0 <= pitch < 88 and onset_frame < time_steps:
             onsets[onset_frame, pitch] = 1.0

-    pr_tensor = torch.tensor(piano_roll.T).unsqueeze(0).unsqueeze(1).
-    on_tensor = torch.tensor(onsets.T).unsqueeze(0).unsqueeze(1).
+    pr_tensor = torch.tensor(piano_roll.T).unsqueeze(0).unsqueeze(1).cpu().float()
+    on_tensor = torch.tensor(onsets.T).unsqueeze(0).unsqueeze(1).cpu().float()
     out_tensor = torch.cat([pr_tensor, on_tensor], dim=1)
     print(f"piano_roll shape: {out_tensor.shape}")
     return out_tensor.transpose(2, 3)

@@ -75,7 +75,7 @@ def predict_difficulty(mp3_path, model_name, rep):
     rep_clean = rep

     model = AudioModel(num_classes=11, rep=rep_clean, modality_dropout=False, only_cqt=only_cqt, only_pr=only_pr)
-    checkpoint = [torch.load(f"models/{model_name}/checkpoint_{i}.pth", map_location="
+    checkpoint = [torch.load(f"models/{model_name}/checkpoint_{i}.pth", map_location="cpu", weights_only=False)
                   for i in range(5)]


@@ -93,7 +93,7 @@ def predict_difficulty(mp3_path, model_name, rep):
     preds = []
     for cheks in checkpoint:
         model.load_state_dict(cheks["model_state_dict"])
-        model = model.
+        model = model.cpu().eval()
         with torch.inference_mode():
             logits = model(inp_data, None)
             pred = prediction2label(logits).item()
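The edits above all apply one pattern: input tensors, the transcription model, and the loaded checkpoints are pinned to the CPU before inference. Below is a minimal, self-contained sketch of that pattern; TinyModel and checkpoint_0.pth are illustrative stand-ins, not the Space's actual AudioModel or its saved weights.

# Minimal sketch of the CPU-only checkpoint loading and inference pattern.
# "TinyModel" and "checkpoint_0.pth" are hypothetical stand-ins for illustration.
import torch
import torch.nn as nn

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(88, 11)

    def forward(self, x):
        return self.fc(x)

# Save a checkpoint in the same {"model_state_dict": ...} layout used above.
model = TinyModel()
torch.save({"model_state_dict": model.state_dict()}, "checkpoint_0.pth")

# Load it back onto the CPU; weights_only=False mirrors the edited torch.load call.
checkpoint = torch.load("checkpoint_0.pth", map_location="cpu", weights_only=False)
model.load_state_dict(checkpoint["model_state_dict"])
model = model.cpu().eval()

# CPU inference without gradient tracking, as in predict_difficulty.
inp = torch.randn(1, 88)
with torch.inference_mode():
    logits = model(inp)
print(logits.shape)  # torch.Size([1, 11])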
model.py CHANGED

@@ -222,7 +222,7 @@ def get_mse_macro(y_true, y_pred):

 def get_cqt(rep, k):
     inp_data = utils.load_binary(f"../videos_download/{rep}/{k}.bin")
-    inp_data = torch.tensor(inp_data, dtype=torch.float32).
+    inp_data = torch.tensor(inp_data, dtype=torch.float32).cpu()
     inp_data = inp_data.unsqueeze(0).unsqueeze(0).transpose(2, 3)
     return inp_data


@@ -230,8 +230,8 @@ def get_cqt(rep, k):
 def get_pianoroll(rep, k):
     inp_pr = utils.load_binary(f"../videos_download/{rep}/{k}.bin")
     inp_on = utils.load_binary(f"../videos_download/{rep}/{k}_onset.bin")
-    inp_pr = torch.from_numpy(inp_pr).float().
-    inp_on = torch.from_numpy(inp_on).float().
+    inp_pr = torch.from_numpy(inp_pr).float().cpu()
+    inp_on = torch.from_numpy(inp_on).float().cpu()
     inp_data = torch.stack([inp_pr, inp_on], dim=1)
     inp_data = inp_data.unsqueeze(0).permute(0, 1, 2, 3)
     return inp_data

@@ -255,12 +255,12 @@ def compute_model_basic(model_name, rep, modality_dropout, only_cqt=False, only_
     for split in range(5):
         #load_model
         model = AudioModel(11, rep, modality_dropout, only_cqt, only_pr)
-        checkpoint = torch.load(f"models/{model_name}/checkpoint_{split}.pth", map_location='
+        checkpoint = torch.load(f"models/{model_name}/checkpoint_{split}.pth", map_location='cpu')
         # print(checkpoint["epoch"])
         # print(checkpoint.keys())

         model.load_state_dict(checkpoint['model_state_dict'])
-        model = model.
+        model = model.cpu()
         pred_labels, true_labels = [], []
         predictions_split = {}
         model.eval()
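In model.py, the get_pianoroll edits keep both the pianoroll and its onset matrix on the CPU before stacking them into a single two-channel input. The sketch below mirrors that stacking with random arrays standing in for the .bin files under ../videos_download/; the (88, 500) shape is an assumption made only for illustration.

# Sketch of the two-channel input assembled by get_pianoroll, using random
# arrays in place of the loaded binaries. Shapes here are assumptions.
import numpy as np
import torch

inp_pr = np.random.rand(88, 500).astype(np.float32)             # pianoroll stand-in
inp_on = (np.random.rand(88, 500) > 0.95).astype(np.float32)    # onset stand-in

# Keep everything on the CPU, mirroring the edited lines.
pr = torch.from_numpy(inp_pr).float().cpu()
on = torch.from_numpy(inp_on).float().cpu()

# Stack the two representations along a new axis, then prepend a batch axis,
# as in get_pianoroll (the permute(0, 1, 2, 3) there is a no-op).
inp_data = torch.stack([pr, on], dim=1).unsqueeze(0)
print(inp_data.shape, inp_data.device)  # torch.Size([1, 88, 2, 500]) cpu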