PRamoneda commited on
Commit
a5af45b
·
1 Parent(s): 45e5657

gpu to cpu

Browse files
__pycache__/get_difficulty.cpython-310.pyc CHANGED
Binary files a/__pycache__/get_difficulty.cpython-310.pyc and b/__pycache__/get_difficulty.cpython-310.pyc differ
 
__pycache__/model.cpython-310.pyc CHANGED
Binary files a/__pycache__/model.cpython-310.pyc and b/__pycache__/model.cpython-310.pyc differ
 
get_difficulty.py CHANGED
@@ -32,18 +32,16 @@ def get_cqt_from_mp3(mp3_path):
32
  log_cqt = log_cqt.T # shape (T, 88)
33
  log_cqt = downsample_log_cqt(log_cqt, target_fs=5)
34
  cqt_tensor = torch.tensor(log_cqt, dtype=torch.float32).unsqueeze(0).unsqueeze(0).cpu()
35
- # pdb.set_trace()
36
  print(f"cqt shape: {log_cqt.shape}")
37
  return cqt_tensor
38
 
39
  def get_pianoroll_from_mp3(mp3_path):
40
  audio, _ = load_audio(mp3_path, sr=sample_rate, mono=True)
41
- transcriptor = PianoTranscription(device='cpu')
42
  midi_path = "temp.mid"
43
  transcriptor.transcribe(audio, midi_path)
44
  midi_data = pretty_midi.PrettyMIDI(midi_path)
45
 
46
- # Create pianoroll and onset matrix
47
  fs = 5 # original frames per second
48
  piano_roll = midi_data.get_piano_roll(fs=fs)[21:109].T # shape: (T, 88)
49
  piano_roll = piano_roll / 127
@@ -64,6 +62,8 @@ def get_pianoroll_from_mp3(mp3_path):
64
  return out_tensor.transpose(2, 3)
65
 
66
  def predict_difficulty(mp3_path, model_name, rep):
 
 
67
  if "only_cqt" in rep:
68
  only_cqt, only_pr = True, False
69
  rep_clean = "multimodal5"
@@ -74,18 +74,17 @@ def predict_difficulty(mp3_path, model_name, rep):
74
  only_cqt = only_pr = False
75
  rep_clean = rep
76
 
77
- model = AudioModel(num_classes=11, rep=rep_clean, modality_dropout=False, only_cqt=only_cqt, only_pr=only_pr)
78
- checkpoint = [torch.load(f"models/{model_name}/checkpoint_{i}.pth", map_location="cpu", weights_only=False)
79
  for i in range(5)]
80
 
81
-
82
  if rep == "cqt5":
83
- inp_data = get_cqt_from_mp3(mp3_path)
84
  elif rep == "pianoroll5":
85
- inp_data = get_pianoroll_from_mp3(mp3_path)
86
  elif rep_clean == "multimodal5":
87
- x1 = get_pianoroll_from_mp3(mp3_path)
88
- x2 = get_cqt_from_mp3(mp3_path)
89
  inp_data = [x1, x2]
90
  else:
91
  raise ValueError(f"Representation {rep} not supported")
@@ -93,23 +92,16 @@ def predict_difficulty(mp3_path, model_name, rep):
93
  preds = []
94
  for cheks in checkpoint:
95
  model.load_state_dict(cheks["model_state_dict"])
96
- model = model.cpu().eval()
97
  with torch.inference_mode():
98
  logits = model(inp_data, None)
99
  pred = prediction2label(logits).item()
100
  preds.append(pred)
101
 
102
  return mean(preds)
103
- # return preds
104
 
105
  if __name__ == "__main__":
106
  mp3_path = "yt_audio.mp3"
107
- model_name = ""
108
- # pred_cqt = predict_difficulty(mp3_path, model_name="audio_midi_cqt5_ps_v5", rep="cqt5")
109
- # print(f"Predicción dificultad CQT: {pred_cqt}")
110
-
111
- # pred_pr = predict_difficulty(mp3_path, model_name="audio_midi_pianoroll_ps_5_v4", rep="pianoroll5")
112
- # print(f"Predicción dificultad PR: {pred_pr}")
113
-
114
- pred_multi = predict_difficulty(mp3_path, model_name="audio_midi_multi_ps_v5", rep="multimodal5")
115
- print(f"Predicción dificultad multimodal: {pred_multi}")
 
32
  log_cqt = log_cqt.T # shape (T, 88)
33
  log_cqt = downsample_log_cqt(log_cqt, target_fs=5)
34
  cqt_tensor = torch.tensor(log_cqt, dtype=torch.float32).unsqueeze(0).unsqueeze(0).cpu()
 
35
  print(f"cqt shape: {log_cqt.shape}")
36
  return cqt_tensor
37
 
38
  def get_pianoroll_from_mp3(mp3_path):
39
  audio, _ = load_audio(mp3_path, sr=sample_rate, mono=True)
40
+ transcriptor = PianoTranscription(device="cuda" if torch.cuda.is_available() else "cpu")
41
  midi_path = "temp.mid"
42
  transcriptor.transcribe(audio, midi_path)
43
  midi_data = pretty_midi.PrettyMIDI(midi_path)
44
 
 
45
  fs = 5 # original frames per second
46
  piano_roll = midi_data.get_piano_roll(fs=fs)[21:109].T # shape: (T, 88)
47
  piano_roll = piano_roll / 127
 
62
  return out_tensor.transpose(2, 3)
63
 
64
  def predict_difficulty(mp3_path, model_name, rep):
65
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
66
+
67
  if "only_cqt" in rep:
68
  only_cqt, only_pr = True, False
69
  rep_clean = "multimodal5"
 
74
  only_cqt = only_pr = False
75
  rep_clean = rep
76
 
77
+ model = AudioModel(num_classes=11, rep=rep_clean, modality_dropout=False, only_cqt=only_cqt, only_pr=only_pr).to(device)
78
+ checkpoint = [torch.load(f"models/{model_name}/checkpoint_{i}.pth", map_location=device, weights_only=False)
79
  for i in range(5)]
80
 
 
81
  if rep == "cqt5":
82
+ inp_data = get_cqt_from_mp3(mp3_path).to(device)
83
  elif rep == "pianoroll5":
84
+ inp_data = get_pianoroll_from_mp3(mp3_path).to(device)
85
  elif rep_clean == "multimodal5":
86
+ x1 = get_pianoroll_from_mp3(mp3_path).to(device)
87
+ x2 = get_cqt_from_mp3(mp3_path).to(device)
88
  inp_data = [x1, x2]
89
  else:
90
  raise ValueError(f"Representation {rep} not supported")
 
92
  preds = []
93
  for cheks in checkpoint:
94
  model.load_state_dict(cheks["model_state_dict"])
95
+ model.eval()
96
  with torch.inference_mode():
97
  logits = model(inp_data, None)
98
  pred = prediction2label(logits).item()
99
  preds.append(pred)
100
 
101
  return mean(preds)
 
102
 
103
  if __name__ == "__main__":
104
  mp3_path = "yt_audio.mp3"
105
+ model_name = "audio_midi_multi_ps_v5"
106
+ pred_multi = predict_difficulty(mp3_path, model_name=model_name, rep="multimodal5")
107
+ print(f"Predicción dificultad multimodal: {pred_multi}")
 
 
 
 
 
 
model.py CHANGED
@@ -212,127 +212,6 @@ class AudioModel(nn.Module):
212
  return x
213
 
214
 
215
- def get_mse_macro(y_true, y_pred):
216
- mse_each_class = []
217
- for true_class in set(y_true):
218
- tt, pp = zip(*[[tt, pp] for tt, pp in zip(y_true, y_pred) if tt == true_class])
219
- mse_each_class.append(mean_squared_error(y_true=tt, y_pred=pp))
220
- return mean(mse_each_class)
221
-
222
-
223
- def get_cqt(rep, k):
224
- inp_data = utils.load_binary(f"../videos_download/{rep}/{k}.bin")
225
- inp_data = torch.tensor(inp_data, dtype=torch.float32).cpu()
226
- inp_data = inp_data.unsqueeze(0).unsqueeze(0).transpose(2, 3)
227
- return inp_data
228
-
229
-
230
- def get_pianoroll(rep, k):
231
- inp_pr = utils.load_binary(f"../videos_download/{rep}/{k}.bin")
232
- inp_on = utils.load_binary(f"../videos_download/{rep}/{k}_onset.bin")
233
- inp_pr = torch.from_numpy(inp_pr).float().cpu()
234
- inp_on = torch.from_numpy(inp_on).float().cpu()
235
- inp_data = torch.stack([inp_pr, inp_on], dim=1)
236
- inp_data = inp_data.unsqueeze(0).permute(0, 1, 2, 3)
237
- return inp_data
238
-
239
- def compute_model_basic(model_name, rep, modality_dropout, only_cqt=False, only_pr=False):
240
- seed = 42
241
- np.random.seed(seed)
242
- torch.manual_seed(seed)
243
- if torch.cuda.is_available():
244
- torch.cuda.manual_seed(seed)
245
- data = utils.load_json("../videos_download/split_audio.json")
246
- mse, acc = [], []
247
- predictions = []
248
- if only_cqt:
249
- cache_name = model_name + "_cqt"
250
- elif only_pr:
251
- cache_name = model_name + "_pr"
252
- else:
253
- cache_name = model_name
254
- if not os.path.exists(f"cache/{cache_name}.json"):
255
- for split in range(5):
256
- #load_model
257
- model = AudioModel(11, rep, modality_dropout, only_cqt, only_pr)
258
- checkpoint = torch.load(f"models/{model_name}/checkpoint_{split}.pth", map_location='cpu')
259
- # print(checkpoint["epoch"])
260
- # print(checkpoint.keys())
261
-
262
- model.load_state_dict(checkpoint['model_state_dict'])
263
- model = model.cpu()
264
- pred_labels, true_labels = [], []
265
- predictions_split = {}
266
- model.eval()
267
- with torch.inference_mode():
268
- for k, ps in data[str(split)]["test"].items():
269
- # computar el modelo
270
- if "cqt" in rep:
271
- inp_data = get_cqt(rep, k)
272
- elif "pianoroll" in rep:
273
- inp_data = get_pianoroll(rep, k)
274
- elif rep == "multimodal5":
275
- x1 = get_pianoroll("pianoroll5", k)
276
- x2 = get_cqt("cqt5", k)[:, :, :x1.shape[2]]
277
- inp_data = [x1, x2]
278
- log_prob = model(inp_data, None)
279
- pred = prediction2label(log_prob).cpu().tolist()[0]
280
- print(k, ps, pred)
281
- predictions_split[k] = {
282
- "true": ps,
283
- "pred": pred
284
- }
285
-
286
- true_labels.append(ps)
287
- pred_labels.append(pred)
288
-
289
- predictions.append(predictions_split)
290
- mse.append(get_mse_macro(true_labels, pred_labels))
291
- acc.append(balanced_accuracy_score(true_labels, pred_labels))
292
- # with one decimal
293
- print(f"mse: {mean(mse):.1f}({stdev(mse):.1f})", end=" ")
294
- print(f"acc: {mean(acc)*100:.1f}({stdev(acc)*100:.1f})")
295
- utils.save_json({
296
- "mse": mse,
297
- "acc": acc,
298
- "predictions": predictions
299
- }, f"cache/{cache_name}.json")
300
- else:
301
- data = utils.load_json(f"cache/{cache_name}.json")
302
- tau_c, mse, acc = [], [], []
303
- for i in range(5):
304
- pred, true = [], []
305
- for k, dd in data["predictions"][i].items():
306
- pred.append(dd["pred"])
307
- true.append(dd["true"])
308
- tau_c.append(kendalltau(x=true, y=pred).statistic)
309
- mse.append(get_mse_macro(true, pred))
310
- acc.append(balanced_accuracy_score(true, pred))
311
- print(model_name, end="// ")
312
- print(f"& {mean(mse):.2f}({stdev(mse):.2f})", end=" ")
313
- print(f"& {mean(acc) * 100:.1f}({stdev(acc) * 100:.2f})", end=" ")
314
- print(f"& {mean(tau_c):.3f}({stdev(tau_c):.3f})")
315
-
316
-
317
- def compute_ensemble(truncate=False):
318
- round_func = lambda x: math.ceil(x) if truncate else math.floor(x)
319
- data_pr = utils.load_json(f"cache/audio_midi_cqt5_ps_v5.json")
320
- data_cqt = utils.load_json(f"cache/audio_midi_pianoroll_ps_5_v4.json")
321
- tau_c, mse, acc = [], [], []
322
- for i in range(5):
323
- pred, true = [], []
324
- for k, dd in data_pr["predictions"][i].items():
325
- cqt_pred = data_cqt["predictions"][i][k]
326
- pred.append(round_func((dd["pred"] + cqt_pred["pred"])/2))
327
- true.append(dd["true"])
328
- tau_c.append(kendalltau(x=true, y=pred).statistic)
329
- mse.append(get_mse_macro(true, pred))
330
- acc.append(balanced_accuracy_score(true, pred))
331
- print("ensemble", end="// ")
332
- print(f"& {mean(mse):.2f}({stdev(mse):.2f})", end=" ")
333
- print(f"& {mean(acc) * 100:.1f}({stdev(acc) * 100:.2f})", end=" ")
334
- print(f"& {mean(tau_c):.3f}({stdev(tau_c):.3f})")
335
-
336
 
337
  def load_json(name_file):
338
  with open(name_file, 'r') as fp:
 
212
  return x
213
 
214
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
215
 
216
  def load_json(name_file):
217
  with open(name_file, 'r') as fp:
temp.mid CHANGED
Binary files a/temp.mid and b/temp.mid differ