Spaces:
Sleeping
Sleeping
PRamoneda
commited on
Commit
·
a5af45b
1
Parent(s):
45e5657
gpu to cpu
Browse files- __pycache__/get_difficulty.cpython-310.pyc +0 -0
- __pycache__/model.cpython-310.pyc +0 -0
- get_difficulty.py +13 -21
- model.py +0 -121
- temp.mid +0 -0
__pycache__/get_difficulty.cpython-310.pyc
CHANGED
Binary files a/__pycache__/get_difficulty.cpython-310.pyc and b/__pycache__/get_difficulty.cpython-310.pyc differ
|
|
__pycache__/model.cpython-310.pyc
CHANGED
Binary files a/__pycache__/model.cpython-310.pyc and b/__pycache__/model.cpython-310.pyc differ
|
|
get_difficulty.py
CHANGED
@@ -32,18 +32,16 @@ def get_cqt_from_mp3(mp3_path):
|
|
32 |
log_cqt = log_cqt.T # shape (T, 88)
|
33 |
log_cqt = downsample_log_cqt(log_cqt, target_fs=5)
|
34 |
cqt_tensor = torch.tensor(log_cqt, dtype=torch.float32).unsqueeze(0).unsqueeze(0).cpu()
|
35 |
-
# pdb.set_trace()
|
36 |
print(f"cqt shape: {log_cqt.shape}")
|
37 |
return cqt_tensor
|
38 |
|
39 |
def get_pianoroll_from_mp3(mp3_path):
|
40 |
audio, _ = load_audio(mp3_path, sr=sample_rate, mono=True)
|
41 |
-
transcriptor = PianoTranscription(device=
|
42 |
midi_path = "temp.mid"
|
43 |
transcriptor.transcribe(audio, midi_path)
|
44 |
midi_data = pretty_midi.PrettyMIDI(midi_path)
|
45 |
|
46 |
-
# Create pianoroll and onset matrix
|
47 |
fs = 5 # original frames per second
|
48 |
piano_roll = midi_data.get_piano_roll(fs=fs)[21:109].T # shape: (T, 88)
|
49 |
piano_roll = piano_roll / 127
|
@@ -64,6 +62,8 @@ def get_pianoroll_from_mp3(mp3_path):
|
|
64 |
return out_tensor.transpose(2, 3)
|
65 |
|
66 |
def predict_difficulty(mp3_path, model_name, rep):
|
|
|
|
|
67 |
if "only_cqt" in rep:
|
68 |
only_cqt, only_pr = True, False
|
69 |
rep_clean = "multimodal5"
|
@@ -74,18 +74,17 @@ def predict_difficulty(mp3_path, model_name, rep):
|
|
74 |
only_cqt = only_pr = False
|
75 |
rep_clean = rep
|
76 |
|
77 |
-
model = AudioModel(num_classes=11, rep=rep_clean, modality_dropout=False, only_cqt=only_cqt, only_pr=only_pr)
|
78 |
-
checkpoint = [torch.load(f"models/{model_name}/checkpoint_{i}.pth", map_location=
|
79 |
for i in range(5)]
|
80 |
|
81 |
-
|
82 |
if rep == "cqt5":
|
83 |
-
inp_data = get_cqt_from_mp3(mp3_path)
|
84 |
elif rep == "pianoroll5":
|
85 |
-
inp_data = get_pianoroll_from_mp3(mp3_path)
|
86 |
elif rep_clean == "multimodal5":
|
87 |
-
x1 = get_pianoroll_from_mp3(mp3_path)
|
88 |
-
x2 = get_cqt_from_mp3(mp3_path)
|
89 |
inp_data = [x1, x2]
|
90 |
else:
|
91 |
raise ValueError(f"Representation {rep} not supported")
|
@@ -93,23 +92,16 @@ def predict_difficulty(mp3_path, model_name, rep):
|
|
93 |
preds = []
|
94 |
for cheks in checkpoint:
|
95 |
model.load_state_dict(cheks["model_state_dict"])
|
96 |
-
model
|
97 |
with torch.inference_mode():
|
98 |
logits = model(inp_data, None)
|
99 |
pred = prediction2label(logits).item()
|
100 |
preds.append(pred)
|
101 |
|
102 |
return mean(preds)
|
103 |
-
# return preds
|
104 |
|
105 |
if __name__ == "__main__":
|
106 |
mp3_path = "yt_audio.mp3"
|
107 |
-
model_name = ""
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
# pred_pr = predict_difficulty(mp3_path, model_name="audio_midi_pianoroll_ps_5_v4", rep="pianoroll5")
|
112 |
-
# print(f"Predicción dificultad PR: {pred_pr}")
|
113 |
-
|
114 |
-
pred_multi = predict_difficulty(mp3_path, model_name="audio_midi_multi_ps_v5", rep="multimodal5")
|
115 |
-
print(f"Predicción dificultad multimodal: {pred_multi}")
|
|
|
32 |
log_cqt = log_cqt.T # shape (T, 88)
|
33 |
log_cqt = downsample_log_cqt(log_cqt, target_fs=5)
|
34 |
cqt_tensor = torch.tensor(log_cqt, dtype=torch.float32).unsqueeze(0).unsqueeze(0).cpu()
|
|
|
35 |
print(f"cqt shape: {log_cqt.shape}")
|
36 |
return cqt_tensor
|
37 |
|
38 |
def get_pianoroll_from_mp3(mp3_path):
|
39 |
audio, _ = load_audio(mp3_path, sr=sample_rate, mono=True)
|
40 |
+
transcriptor = PianoTranscription(device="cuda" if torch.cuda.is_available() else "cpu")
|
41 |
midi_path = "temp.mid"
|
42 |
transcriptor.transcribe(audio, midi_path)
|
43 |
midi_data = pretty_midi.PrettyMIDI(midi_path)
|
44 |
|
|
|
45 |
fs = 5 # original frames per second
|
46 |
piano_roll = midi_data.get_piano_roll(fs=fs)[21:109].T # shape: (T, 88)
|
47 |
piano_roll = piano_roll / 127
|
|
|
62 |
return out_tensor.transpose(2, 3)
|
63 |
|
64 |
def predict_difficulty(mp3_path, model_name, rep):
|
65 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
66 |
+
|
67 |
if "only_cqt" in rep:
|
68 |
only_cqt, only_pr = True, False
|
69 |
rep_clean = "multimodal5"
|
|
|
74 |
only_cqt = only_pr = False
|
75 |
rep_clean = rep
|
76 |
|
77 |
+
model = AudioModel(num_classes=11, rep=rep_clean, modality_dropout=False, only_cqt=only_cqt, only_pr=only_pr).to(device)
|
78 |
+
checkpoint = [torch.load(f"models/{model_name}/checkpoint_{i}.pth", map_location=device, weights_only=False)
|
79 |
for i in range(5)]
|
80 |
|
|
|
81 |
if rep == "cqt5":
|
82 |
+
inp_data = get_cqt_from_mp3(mp3_path).to(device)
|
83 |
elif rep == "pianoroll5":
|
84 |
+
inp_data = get_pianoroll_from_mp3(mp3_path).to(device)
|
85 |
elif rep_clean == "multimodal5":
|
86 |
+
x1 = get_pianoroll_from_mp3(mp3_path).to(device)
|
87 |
+
x2 = get_cqt_from_mp3(mp3_path).to(device)
|
88 |
inp_data = [x1, x2]
|
89 |
else:
|
90 |
raise ValueError(f"Representation {rep} not supported")
|
|
|
92 |
preds = []
|
93 |
for cheks in checkpoint:
|
94 |
model.load_state_dict(cheks["model_state_dict"])
|
95 |
+
model.eval()
|
96 |
with torch.inference_mode():
|
97 |
logits = model(inp_data, None)
|
98 |
pred = prediction2label(logits).item()
|
99 |
preds.append(pred)
|
100 |
|
101 |
return mean(preds)
|
|
|
102 |
|
103 |
if __name__ == "__main__":
|
104 |
mp3_path = "yt_audio.mp3"
|
105 |
+
model_name = "audio_midi_multi_ps_v5"
|
106 |
+
pred_multi = predict_difficulty(mp3_path, model_name=model_name, rep="multimodal5")
|
107 |
+
print(f"Predicción dificultad multimodal: {pred_multi}")
|
|
|
|
|
|
|
|
|
|
|
|
model.py
CHANGED
@@ -212,127 +212,6 @@ class AudioModel(nn.Module):
|
|
212 |
return x
|
213 |
|
214 |
|
215 |
-
def get_mse_macro(y_true, y_pred):
|
216 |
-
mse_each_class = []
|
217 |
-
for true_class in set(y_true):
|
218 |
-
tt, pp = zip(*[[tt, pp] for tt, pp in zip(y_true, y_pred) if tt == true_class])
|
219 |
-
mse_each_class.append(mean_squared_error(y_true=tt, y_pred=pp))
|
220 |
-
return mean(mse_each_class)
|
221 |
-
|
222 |
-
|
223 |
-
def get_cqt(rep, k):
|
224 |
-
inp_data = utils.load_binary(f"../videos_download/{rep}/{k}.bin")
|
225 |
-
inp_data = torch.tensor(inp_data, dtype=torch.float32).cpu()
|
226 |
-
inp_data = inp_data.unsqueeze(0).unsqueeze(0).transpose(2, 3)
|
227 |
-
return inp_data
|
228 |
-
|
229 |
-
|
230 |
-
def get_pianoroll(rep, k):
|
231 |
-
inp_pr = utils.load_binary(f"../videos_download/{rep}/{k}.bin")
|
232 |
-
inp_on = utils.load_binary(f"../videos_download/{rep}/{k}_onset.bin")
|
233 |
-
inp_pr = torch.from_numpy(inp_pr).float().cpu()
|
234 |
-
inp_on = torch.from_numpy(inp_on).float().cpu()
|
235 |
-
inp_data = torch.stack([inp_pr, inp_on], dim=1)
|
236 |
-
inp_data = inp_data.unsqueeze(0).permute(0, 1, 2, 3)
|
237 |
-
return inp_data
|
238 |
-
|
239 |
-
def compute_model_basic(model_name, rep, modality_dropout, only_cqt=False, only_pr=False):
|
240 |
-
seed = 42
|
241 |
-
np.random.seed(seed)
|
242 |
-
torch.manual_seed(seed)
|
243 |
-
if torch.cuda.is_available():
|
244 |
-
torch.cuda.manual_seed(seed)
|
245 |
-
data = utils.load_json("../videos_download/split_audio.json")
|
246 |
-
mse, acc = [], []
|
247 |
-
predictions = []
|
248 |
-
if only_cqt:
|
249 |
-
cache_name = model_name + "_cqt"
|
250 |
-
elif only_pr:
|
251 |
-
cache_name = model_name + "_pr"
|
252 |
-
else:
|
253 |
-
cache_name = model_name
|
254 |
-
if not os.path.exists(f"cache/{cache_name}.json"):
|
255 |
-
for split in range(5):
|
256 |
-
#load_model
|
257 |
-
model = AudioModel(11, rep, modality_dropout, only_cqt, only_pr)
|
258 |
-
checkpoint = torch.load(f"models/{model_name}/checkpoint_{split}.pth", map_location='cpu')
|
259 |
-
# print(checkpoint["epoch"])
|
260 |
-
# print(checkpoint.keys())
|
261 |
-
|
262 |
-
model.load_state_dict(checkpoint['model_state_dict'])
|
263 |
-
model = model.cpu()
|
264 |
-
pred_labels, true_labels = [], []
|
265 |
-
predictions_split = {}
|
266 |
-
model.eval()
|
267 |
-
with torch.inference_mode():
|
268 |
-
for k, ps in data[str(split)]["test"].items():
|
269 |
-
# computar el modelo
|
270 |
-
if "cqt" in rep:
|
271 |
-
inp_data = get_cqt(rep, k)
|
272 |
-
elif "pianoroll" in rep:
|
273 |
-
inp_data = get_pianoroll(rep, k)
|
274 |
-
elif rep == "multimodal5":
|
275 |
-
x1 = get_pianoroll("pianoroll5", k)
|
276 |
-
x2 = get_cqt("cqt5", k)[:, :, :x1.shape[2]]
|
277 |
-
inp_data = [x1, x2]
|
278 |
-
log_prob = model(inp_data, None)
|
279 |
-
pred = prediction2label(log_prob).cpu().tolist()[0]
|
280 |
-
print(k, ps, pred)
|
281 |
-
predictions_split[k] = {
|
282 |
-
"true": ps,
|
283 |
-
"pred": pred
|
284 |
-
}
|
285 |
-
|
286 |
-
true_labels.append(ps)
|
287 |
-
pred_labels.append(pred)
|
288 |
-
|
289 |
-
predictions.append(predictions_split)
|
290 |
-
mse.append(get_mse_macro(true_labels, pred_labels))
|
291 |
-
acc.append(balanced_accuracy_score(true_labels, pred_labels))
|
292 |
-
# with one decimal
|
293 |
-
print(f"mse: {mean(mse):.1f}({stdev(mse):.1f})", end=" ")
|
294 |
-
print(f"acc: {mean(acc)*100:.1f}({stdev(acc)*100:.1f})")
|
295 |
-
utils.save_json({
|
296 |
-
"mse": mse,
|
297 |
-
"acc": acc,
|
298 |
-
"predictions": predictions
|
299 |
-
}, f"cache/{cache_name}.json")
|
300 |
-
else:
|
301 |
-
data = utils.load_json(f"cache/{cache_name}.json")
|
302 |
-
tau_c, mse, acc = [], [], []
|
303 |
-
for i in range(5):
|
304 |
-
pred, true = [], []
|
305 |
-
for k, dd in data["predictions"][i].items():
|
306 |
-
pred.append(dd["pred"])
|
307 |
-
true.append(dd["true"])
|
308 |
-
tau_c.append(kendalltau(x=true, y=pred).statistic)
|
309 |
-
mse.append(get_mse_macro(true, pred))
|
310 |
-
acc.append(balanced_accuracy_score(true, pred))
|
311 |
-
print(model_name, end="// ")
|
312 |
-
print(f"& {mean(mse):.2f}({stdev(mse):.2f})", end=" ")
|
313 |
-
print(f"& {mean(acc) * 100:.1f}({stdev(acc) * 100:.2f})", end=" ")
|
314 |
-
print(f"& {mean(tau_c):.3f}({stdev(tau_c):.3f})")
|
315 |
-
|
316 |
-
|
317 |
-
def compute_ensemble(truncate=False):
|
318 |
-
round_func = lambda x: math.ceil(x) if truncate else math.floor(x)
|
319 |
-
data_pr = utils.load_json(f"cache/audio_midi_cqt5_ps_v5.json")
|
320 |
-
data_cqt = utils.load_json(f"cache/audio_midi_pianoroll_ps_5_v4.json")
|
321 |
-
tau_c, mse, acc = [], [], []
|
322 |
-
for i in range(5):
|
323 |
-
pred, true = [], []
|
324 |
-
for k, dd in data_pr["predictions"][i].items():
|
325 |
-
cqt_pred = data_cqt["predictions"][i][k]
|
326 |
-
pred.append(round_func((dd["pred"] + cqt_pred["pred"])/2))
|
327 |
-
true.append(dd["true"])
|
328 |
-
tau_c.append(kendalltau(x=true, y=pred).statistic)
|
329 |
-
mse.append(get_mse_macro(true, pred))
|
330 |
-
acc.append(balanced_accuracy_score(true, pred))
|
331 |
-
print("ensemble", end="// ")
|
332 |
-
print(f"& {mean(mse):.2f}({stdev(mse):.2f})", end=" ")
|
333 |
-
print(f"& {mean(acc) * 100:.1f}({stdev(acc) * 100:.2f})", end=" ")
|
334 |
-
print(f"& {mean(tau_c):.3f}({stdev(tau_c):.3f})")
|
335 |
-
|
336 |
|
337 |
def load_json(name_file):
|
338 |
with open(name_file, 'r') as fp:
|
|
|
212 |
return x
|
213 |
|
214 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
215 |
|
216 |
def load_json(name_file):
|
217 |
with open(name_file, 'r') as fp:
|
temp.mid
CHANGED
Binary files a/temp.mid and b/temp.mid differ
|
|