PRamoneda committed
Commit df703c7 · 1 Parent(s): 2e9908b
Files changed (2):
  1. get_difficulty.py +6 -6
  2. model.py +5 -5
get_difficulty.py CHANGED
@@ -31,14 +31,14 @@ def get_cqt_from_mp3(mp3_path):
     log_cqt = librosa.amplitude_to_db(np.abs(cqt))
     log_cqt = log_cqt.T  # shape (T, 88)
     log_cqt = downsample_log_cqt(log_cqt, target_fs=5)
-    cqt_tensor = torch.tensor(log_cqt, dtype=torch.float32).unsqueeze(0).unsqueeze(0).cuda()
+    cqt_tensor = torch.tensor(log_cqt, dtype=torch.float32).unsqueeze(0).unsqueeze(0).cpu()
     # pdb.set_trace()
     print(f"cqt shape: {log_cqt.shape}")
     return cqt_tensor

 def get_pianoroll_from_mp3(mp3_path):
     audio, _ = load_audio(mp3_path, sr=sample_rate, mono=True)
-    transcriptor = PianoTranscription(device='cuda')
+    transcriptor = PianoTranscription(device='cpu')
     midi_path = "temp.mid"
     transcriptor.transcribe(audio, midi_path)
     midi_data = pretty_midi.PrettyMIDI(midi_path)
@@ -57,8 +57,8 @@ def get_pianoroll_from_mp3(mp3_path):
         if 0 <= pitch < 88 and onset_frame < time_steps:
             onsets[onset_frame, pitch] = 1.0

-    pr_tensor = torch.tensor(piano_roll.T).unsqueeze(0).unsqueeze(1).cuda().float()
-    on_tensor = torch.tensor(onsets.T).unsqueeze(0).unsqueeze(1).cuda().float()
+    pr_tensor = torch.tensor(piano_roll.T).unsqueeze(0).unsqueeze(1).cpu().float()
+    on_tensor = torch.tensor(onsets.T).unsqueeze(0).unsqueeze(1).cpu().float()
     out_tensor = torch.cat([pr_tensor, on_tensor], dim=1)
     print(f"piano_roll shape: {out_tensor.shape}")
     return out_tensor.transpose(2, 3)
@@ -75,7 +75,7 @@ def predict_difficulty(mp3_path, model_name, rep):
     rep_clean = rep

     model = AudioModel(num_classes=11, rep=rep_clean, modality_dropout=False, only_cqt=only_cqt, only_pr=only_pr)
-    checkpoint = [torch.load(f"models/{model_name}/checkpoint_{i}.pth", map_location="cuda", weights_only=False)
+    checkpoint = [torch.load(f"models/{model_name}/checkpoint_{i}.pth", map_location="cpu", weights_only=False)
                   for i in range(5)]


@@ -93,7 +93,7 @@ def predict_difficulty(mp3_path, model_name, rep):
     preds = []
     for cheks in checkpoint:
         model.load_state_dict(cheks["model_state_dict"])
-        model = model.cuda().eval()
+        model = model.cpu().eval()
         with torch.inference_mode():
             logits = model(inp_data, None)
             pred = prediction2label(logits).item()
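Note: this commit pins every tensor in get_difficulty.py to the CPU. A device-agnostic variant would keep GPU support when one is available; a minimal sketch under that assumption (the DEVICE constant and to_model_input helper are illustrative, not part of this commit):

import torch

# Hypothetical module-level switch: use the GPU if present, otherwise fall back to CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def to_model_input(log_cqt):
    # Same shaping as get_cqt_from_mp3: add batch and channel dims, then move to DEVICE.
    return torch.tensor(log_cqt, dtype=torch.float32).unsqueeze(0).unsqueeze(0).to(DEVICE)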
model.py CHANGED
@@ -222,7 +222,7 @@ def get_mse_macro(y_true, y_pred):

 def get_cqt(rep, k):
     inp_data = utils.load_binary(f"../videos_download/{rep}/{k}.bin")
-    inp_data = torch.tensor(inp_data, dtype=torch.float32).cuda()
+    inp_data = torch.tensor(inp_data, dtype=torch.float32).cpu()
     inp_data = inp_data.unsqueeze(0).unsqueeze(0).transpose(2, 3)
     return inp_data

@@ -230,8 +230,8 @@ def get_cqt(rep, k):
 def get_pianoroll(rep, k):
     inp_pr = utils.load_binary(f"../videos_download/{rep}/{k}.bin")
     inp_on = utils.load_binary(f"../videos_download/{rep}/{k}_onset.bin")
-    inp_pr = torch.from_numpy(inp_pr).float().cuda()
-    inp_on = torch.from_numpy(inp_on).float().cuda()
+    inp_pr = torch.from_numpy(inp_pr).float().cpu()
+    inp_on = torch.from_numpy(inp_on).float().cpu()
     inp_data = torch.stack([inp_pr, inp_on], dim=1)
     inp_data = inp_data.unsqueeze(0).permute(0, 1, 2, 3)
     return inp_data
@@ -255,12 +255,12 @@ def compute_model_basic(model_name, rep, modality_dropout, only_cqt=False, only_
     for split in range(5):
         #load_model
         model = AudioModel(11, rep, modality_dropout, only_cqt, only_pr)
-        checkpoint = torch.load(f"models/{model_name}/checkpoint_{split}.pth", map_location='cuda:0')
+        checkpoint = torch.load(f"models/{model_name}/checkpoint_{split}.pth", map_location='cpu')
         # print(checkpoint["epoch"])
         # print(checkpoint.keys())

         model.load_state_dict(checkpoint['model_state_dict'])
-        model = model.cuda()
+        model = model.cpu()
         pred_labels, true_labels = [], []
         predictions_split = {}
         model.eval()
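Similarly for model.py, the checkpoints could be loaded onto whatever device is available instead of hard-coding the CPU. A minimal sketch under that assumption (the device variable is illustrative; model_name, split, and model come from the surrounding code and are not redefined by this commit):

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# map_location accepts a torch.device, so the same call works with or without a GPU.
checkpoint = torch.load(f"models/{model_name}/checkpoint_{split}.pth", map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])
model = model.to(device).eval()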