asigalov61 commited on
Commit
b5806f4
·
verified ·
1 Parent(s): 64f1965

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +94 -146
app.py CHANGED
@@ -61,24 +61,35 @@ print('=' * 70)
61
 
62
  #==================================================================================
63
 
64
- MODEL_CHECKPOINT = 'MIDI_Genre_Classifier_Trained_Model_36457_steps_0.5384_loss_0.8417_acc.pth'
65
 
66
  SOUDFONT_PATH = 'SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2'
67
 
68
- MAX_MELODY_NOTES = 64
 
 
 
 
 
 
 
 
 
69
 
70
- MAX_GEN_TOKS = 3072
 
 
71
 
72
  #==================================================================================
73
 
74
  print('=' * 70)
75
- print('Loading popular hook melodies dataset...')
76
 
77
- popular_hook_melodies_pickle = hf_hub_download(repo_id='asigalov61/MIDI-Genre-Classifier',
78
- filename='popular_hook_melodies_24_64_CC_BY_NC_SA.pickle'
79
  )
80
 
81
- popular_hook_melodies = TMIDIX.Tegridy_Any_Pickle_File_Reader(popular_hook_melodies_pickle)
82
 
83
  print('=' * 70)
84
  print('Done!')
@@ -89,20 +100,20 @@ print('=' * 70)
89
  print('=' * 70)
90
  print('Instantiating model...')
91
 
92
- device_type = 'cuda'
93
  dtype = 'bfloat16'
94
 
95
  ptdtype = {'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
96
  ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)
97
 
98
- SEQ_LEN = 4096
99
- PAD_IDX = 1794
100
 
101
  model = TransformerWrapper(
102
  num_tokens = PAD_IDX+1,
103
  max_seq_len = SEQ_LEN,
104
  attn_layers = Decoder(dim = 2048,
105
- depth = 4,
106
  heads = 32,
107
  rotary_pos_emb = True,
108
  attn_flash = True
@@ -114,12 +125,10 @@ model = AutoregressiveWrapper(model, ignore_index=PAD_IDX, pad_value=PAD_IDX)
114
  print('=' * 70)
115
  print('Loading model checkpoint...')
116
 
117
- model_checkpoint = hf_hub_download(repo_id='asigalov61/MIDI-Genre-Classifier', filename=MODEL_CHECKPOINT)
118
 
119
  model.load_state_dict(torch.load(model_checkpoint, map_location='cpu', weights_only=True))
120
 
121
- model = torch.compile(model, mode='max-autotune')
122
-
123
  print('=' * 70)
124
  print('Done!')
125
  print('=' * 70)
@@ -128,111 +137,46 @@ print('=' * 70)
128
 
129
  #==================================================================================
130
 
131
- def load_midi(input_midi, melody_patch=-1, use_nth_note=1):
132
-
133
- raw_score = TMIDIX.midi2single_track_ms_score(input_midi)
134
-
135
- escore_notes = TMIDIX.advanced_score_processor(raw_score, return_enhanced_score_notes=True)[0]
136
- escore_notes = TMIDIX.augment_enhanced_score_notes(escore_notes, timings_divider=32)
137
-
138
- sp_escore_notes = TMIDIX.solo_piano_escore_notes(escore_notes, keep_drums=False)
139
-
140
- if melody_patch == -1:
141
- zscore = TMIDIX.recalculate_score_timings(sp_escore_notes)
142
-
143
- else:
144
- mel_score = [e for e in sp_escore_notes if e[6] == melody_patch]
145
-
146
- if mel_score:
147
- zscore = TMIDIX.recalculate_score_timings(mel_score)
148
 
149
- else:
150
- zscore = TMIDIX.recalculate_score_timings(sp_escore_notes)
151
-
152
- cscore = TMIDIX.chordify_score([1000, zscore])[:MAX_MELODY_NOTES:use_nth_note]
153
 
154
- score = []
155
 
156
- score_list = []
157
 
158
- pc = cscore[0]
159
 
160
- for c in cscore:
161
- score.append(max(0, min(127, c[0][1]-pc[0][1])))
162
 
163
- scl = [[max(0, min(127, c[0][1]-pc[0][1]))]]
 
164
 
165
- n = c[0]
166
-
167
- score.extend([max(1, min(127, n[2]))+128, max(1, min(127, n[4]))+256])
168
- scl.append([max(1, min(127, n[2]))+128, max(1, min(127, n[4]))+256])
169
 
170
- score_list.append(scl)
171
 
172
- pc = c
173
 
174
- score_list.append(scl)
175
 
176
- return score, score_list
177
 
178
  #==================================================================================
179
 
180
- @spaces.GPU
181
- def Generate_Accompaniment(input_midi,
182
- input_melody,
183
- melody_patch,
184
- use_nth_note,
185
- model_temperature,
186
- model_sampling_top_k
187
- ):
188
 
189
  #===============================================================================
190
 
191
- def generate_full_seq(input_seq,
192
- max_toks=3072,
193
- temperature=0.9,
194
- top_k_value=15,
195
- verbose=True
196
- ):
197
-
198
- seq_abs_run_time = sum([t for t in input_seq if t < 128])
199
-
200
- cur_time = 0
201
-
202
- full_seq = copy.deepcopy(input_seq)
203
-
204
- toks_counter = 0
205
-
206
- while cur_time <= seq_abs_run_time+32:
207
-
208
- if verbose:
209
- if toks_counter % 128 == 0:
210
- print('Generated', toks_counter, 'tokens')
211
-
212
- x = torch.LongTensor(full_seq).cuda()
213
-
214
- with ctx:
215
- out = model.generate(x,
216
- 1,
217
- filter_logits_fn=top_k,
218
- filter_kwargs={'k': top_k_value},
219
- temperature=temperature,
220
- return_prime=False,
221
- verbose=False)
222
-
223
- y = out.tolist()[0][0]
224
-
225
- if y < 128:
226
- cur_time += y
227
-
228
- full_seq.append(y)
229
-
230
- toks_counter += 1
231
 
232
- if toks_counter == max_toks:
233
- return full_seq
234
-
235
- return full_seq
236
 
237
  #===============================================================================
238
 
@@ -326,29 +270,51 @@ def Generate_Accompaniment(input_midi,
326
  vel = 90
327
  pitch = 0
328
  channel = 0
329
- patch = 0
330
-
331
- channels_map = [0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 11, 9, 12, 13, 14, 15]
332
- patches_map = [40, 0, 10, 19, 24, 35, 40, 52, 56, 9, 65, 73, 0, 0, 0, 0]
333
- velocities_map = [125, 80, 100, 80, 90, 100, 100, 80, 110, 110, 110, 110, 80, 80, 80, 80]
334
-
335
- for m in final_song:
336
-
337
- if 0 <= m < 128:
338
- time += m * 32
339
-
340
- elif 128 < m < 256:
341
- dur = (m-128) * 32
342
-
343
- elif 256 < m < 1792:
344
- cha = (m-256) // 128
345
- pitch = (m-256) % 128
346
-
347
- channel = channels_map[cha]
348
- patch = patches_map[channel]
349
- vel = velocities_map[channel]
350
-
351
- song_f.append(['note', time, dur, channel, pitch, vel, patch])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
352
 
353
  fn1 = "MIDI-Genre-Classifier-Composition"
354
 
@@ -356,7 +322,7 @@ def Generate_Accompaniment(input_midi,
356
  output_signature = 'MIDI Genre Classifier',
357
  output_file_name = fn1,
358
  track_name='Project Los Angeles',
359
- list_of_MIDI_patches=patches_map
360
  )
361
 
362
  new_fn = fn1+'.mid'
@@ -408,7 +374,8 @@ with gr.Blocks() as demo:
408
  #==================================================================================
409
 
410
  gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>MIDI Genre Classifier</h1>")
411
- gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Guided melody accompaniment generation with transformers</h1>")
 
412
  gr.HTML("""
413
  <p>
414
  <a href="https://huggingface.co/spaces/asigalov61/MIDI-Genre-Classifier?duplicate=true">
@@ -416,7 +383,7 @@ with gr.Blocks() as demo:
416
  </a>
417
  </p>
418
 
419
- for faster execution and endless generation!
420
  """)
421
 
422
  #==================================================================================
@@ -462,25 +429,6 @@ with gr.Blocks() as demo:
462
  output_midi
463
  ]
464
  )
465
-
466
- gr.Examples(
467
- [["USSR-National-Anthem-Seed-Melody.mid", "Custom MIDI", -1, 1, 0.9, 15],
468
- ["Sparks-Fly-Seed-Melody.mid", "Custom MIDI", -1, 1, 0.9, 15]
469
- ],
470
- [input_midi,
471
- input_melody,
472
- melody_patch,
473
- use_nth_note,
474
- model_temperature,
475
- model_sampling_top_k
476
- ],
477
- [output_title,
478
- output_audio,
479
- output_plot,
480
- output_midi
481
- ],
482
- Generate_Accompaniment
483
- )
484
 
485
  #==================================================================================
486
 
 
61
 
62
  #==================================================================================
63
 
64
+ MODEL_CHECKPOINT = 'Giant_Music_Transformer_Medium_Trained_Model_42174_steps_0.5211_loss_0.8542_acc.pth'
65
 
66
  SOUDFONT_PATH = 'SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2'
67
 
68
+ #==================================================================================
69
+
70
+ print('=' * 70)
71
+ print('Loading MIDI GAS processed scores dataset...')
72
+
73
+ midi_gas_ps_pickle = hf_hub_download(repo_id='asigalov61/MIDI-GAS',
74
+ filename='MIDI_GAS_Processed_Scores_CC_BY_NC_SA.pickle'
75
+ )
76
+
77
+ midi_gas_ps = TMIDIX.Tegridy_Any_Pickle_File_Reader(midi_gas_ps_pickle)
78
 
79
+ print('=' * 70)
80
+ print('Done!')
81
+ print('=' * 70)
82
 
83
  #==================================================================================
84
 
85
  print('=' * 70)
86
+ print('Loading MIDI GAS processed scores embeddings dataset...')
87
 
88
+ midi_gas_pse_pickle = hf_hub_download(repo_id='asigalov61/MIDI-GAS',
89
+ filename='MIDI_GAS_Processed_Scores_Embeddings_CC_BY_NC_SA.pickle'
90
  )
91
 
92
+ midi_gas_pse = TMIDIX.Tegridy_Any_Pickle_File_Reader(midi_gas_pse_pickle)
93
 
94
  print('=' * 70)
95
  print('Done!')
 
100
  print('=' * 70)
101
  print('Instantiating model...')
102
 
103
+ device_type = 'cpu'
104
  dtype = 'bfloat16'
105
 
106
  ptdtype = {'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
107
  ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)
108
 
109
+ SEQ_LEN = 8192
110
+ PAD_IDX = 19463
111
 
112
  model = TransformerWrapper(
113
  num_tokens = PAD_IDX+1,
114
  max_seq_len = SEQ_LEN,
115
  attn_layers = Decoder(dim = 2048,
116
+ depth = 8,
117
  heads = 32,
118
  rotary_pos_emb = True,
119
  attn_flash = True
 
125
  print('=' * 70)
126
  print('Loading model checkpoint...')
127
 
128
+ model_checkpoint = hf_hub_download(repo_id='asigalov61/Giant-Music-Transformer', filename=MODEL_CHECKPOINT)
129
 
130
  model.load_state_dict(torch.load(model_checkpoint, map_location='cpu', weights_only=True))
131
 
 
 
132
  print('=' * 70)
133
  print('Done!')
134
  print('=' * 70)
 
137
 
138
  #==================================================================================
139
 
140
+ def load_midi(input_midi):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
 
142
+ raw_score = TMIDIX.midi2single_track_ms_score(midi_file)
 
 
 
143
 
144
+ escore = TMIDIX.advanced_score_processor(raw_score, return_enhanced_score_notes=True)[0]
145
 
146
+ escore_notes = TMIDIX.augment_enhanced_score_notes(escore)
147
 
148
+ instruments_list = list(set([y[6] for y in escore_notes]))
149
 
150
+ tok_score = []
 
151
 
152
+ if 128 in instruments_list:
153
+ drums_present = 19331
154
 
155
+ else:
156
+ drums_present = 19330
 
 
157
 
158
+ pat = escore_notes[0][6]
159
 
160
+ tok_score.extend([19461, drums_present, 19332+pat])
161
 
162
+ tok_score.extend(TMIDIX.multi_instrumental_escore_notes_tokenized(escore_notes)[:8190])
163
 
164
+ return tok_score
165
 
166
  #==================================================================================
167
 
168
+ # @spaces.GPU
169
+ def Classify_MIDI_Genre(input_midi,
170
+ input_melody,
171
+ melody_patch,
172
+ use_nth_note,
173
+ model_temperature,
174
+ model_sampling_top_k
175
+ ):
176
 
177
  #===============================================================================
178
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
179
 
 
 
 
 
180
 
181
  #===============================================================================
182
 
 
270
  vel = 90
271
  pitch = 0
272
  channel = 0
273
+
274
+ patches = [-1] * 16
275
+ patches[9] = 9
276
+
277
+ channels = [0] * 16
278
+ channels[9] = 1
279
+
280
+ for ss in song:
281
+
282
+ if 0 <= ss < 256:
283
+
284
+ time += ss * 16
285
+
286
+ if 256 <= ss < 2304:
287
+
288
+ dur = ((ss-256) // 8) * 16
289
+ vel = (((ss-256) % 8)+1) * 15
290
+
291
+ if 2304 <= ss < 18945:
292
+
293
+ patch = (ss-2304) // 129
294
+
295
+ if patch < 128:
296
+
297
+ if patch not in patches:
298
+ if 0 in channels:
299
+ cha = channels.index(0)
300
+ channels[cha] = 1
301
+ else:
302
+ cha = 15
303
+
304
+ patches[cha] = patch
305
+ channel = patches.index(patch)
306
+ else:
307
+ channel = patches.index(patch)
308
+
309
+ if patch == 128:
310
+ channel = 9
311
+
312
+ pitch = (ss-2304) % 129
313
+
314
+ song_f.append(['note', time, dur, channel, pitch, vel, patch ])
315
+
316
+ patches = [0 if x==-1 else x for x in patches]
317
+
318
 
319
  fn1 = "MIDI-Genre-Classifier-Composition"
320
 
 
322
  output_signature = 'MIDI Genre Classifier',
323
  output_file_name = fn1,
324
  track_name='Project Los Angeles',
325
+ list_of_MIDI_patches=patches
326
  )
327
 
328
  new_fn = fn1+'.mid'
 
374
  #==================================================================================
375
 
376
  gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>MIDI Genre Classifier</h1>")
377
+ gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Accurately classify any MIDI by top music genre/h1>")
378
+
379
  gr.HTML("""
380
  <p>
381
  <a href="https://huggingface.co/spaces/asigalov61/MIDI-Genre-Classifier?duplicate=true">
 
383
  </a>
384
  </p>
385
 
386
+ for faster execution and endless classification!
387
  """)
388
 
389
  #==================================================================================
 
429
  output_midi
430
  ]
431
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
432
 
433
  #==================================================================================
434