asigalov61 commited on
Commit
4d766d5
·
verified ·
1 Parent(s): 7fd51d1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -33
app.py CHANGED
@@ -29,8 +29,8 @@ def classify_GPU(input_data):
29
 
30
  print('Loading model...')
31
 
32
- SEQ_LEN = 1024
33
- PAD_IDX = 14627
34
  DEVICE = 'cuda' # 'cuda'
35
 
36
  # instantiate the model
@@ -38,18 +38,21 @@ def classify_GPU(input_data):
38
  model = TransformerWrapper(
39
  num_tokens = PAD_IDX+1,
40
  max_seq_len = SEQ_LEN,
41
- attn_layers = Decoder(dim = 1024, depth = 12, heads = 16, attn_flash = True)
42
- )
 
 
43
 
44
- model = AutoregressiveWrapper(model, ignore_index = PAD_IDX)
45
 
46
  model.to(DEVICE)
 
47
  print('=' * 70)
48
 
49
  print('Loading model checkpoint...')
50
 
51
  model.load_state_dict(
52
- torch.load('Annotated_MIDI_Dataset_Classifier_Trained_Model_21269_steps_0.4335_loss_0.8716_acc.pth',
53
  map_location=DEVICE))
54
  print('=' * 70)
55
 
@@ -67,12 +70,8 @@ def classify_GPU(input_data):
67
 
68
  #==================================================================
69
 
70
- number_of_batches = 1 # @param {type:"slider", min:1, max:100, step:1}
71
-
72
- # @markdown NOTE: You can increase the number of batches on high-ram GPUs for better classification
73
-
74
  print('=' * 70)
75
- print('Annotated MIDI Dataset Classifier')
76
  print('=' * 70)
77
  print('Classifying...')
78
 
@@ -80,29 +79,20 @@ def classify_GPU(input_data):
80
 
81
  model.eval()
82
 
83
- results = []
84
-
85
- for input in input_data:
86
-
87
- x = torch.tensor([input[:1022]] * number_of_batches, dtype=torch.long, device='cuda')
88
-
89
- with ctx:
90
- out = model.generate(x,
91
- 1,
92
- temperature=0.3,
93
- filter_logits_fn=top_k,
94
- filter_kwargs={'k': 1},
95
- return_prime=False,
96
- verbose=False)
97
-
98
- y = out.tolist()
99
-
100
- output = [l[0] for l in y]
101
- result = mode(output)
102
-
103
- results.append(result)
104
 
105
- return output, results
 
 
106
 
107
  # =================================================================================================
108
 
 
29
 
30
  print('Loading model...')
31
 
32
+ SEQ_LEN = 1026
33
+ PAD_IDX = 940
34
  DEVICE = 'cuda' # 'cuda'
35
 
36
  # instantiate the model
 
38
  model = TransformerWrapper(
39
  num_tokens = PAD_IDX+1,
40
  max_seq_len = SEQ_LEN,
41
+ attn_layers = Decoder(dim = 1024, depth = 24, heads = 32, attn_flash = True)
42
+ )
43
+
44
+ model = AutoregressiveWrapper(model, ignore_index=PAD_IDX, pad_value=PAD_IDX)
45
 
46
+ model = torch.nn.DataParallel(model)
47
 
48
  model.to(DEVICE)
49
+
50
  print('=' * 70)
51
 
52
  print('Loading model checkpoint...')
53
 
54
  model.load_state_dict(
55
+ torch.load('Ultimate_MIDI_Classifier_Trained_Model_29886_steps_0.556_loss_0.8339_acc.pth',
56
  map_location=DEVICE))
57
  print('=' * 70)
58
 
 
70
 
71
  #==================================================================
72
 
 
 
 
 
73
  print('=' * 70)
74
+ print('Ultimate MIDI Classifier')
75
  print('=' * 70)
76
  print('Classifying...')
77
 
 
79
 
80
  model.eval()
81
 
82
+ x = torch.tensor(input_data[:1022], dtype=torch.long, device=DEVICE)
83
+
84
+ with ctx:
85
+ out = model.module.generate(x,
86
+ 2,
87
+ filter_logits_fn=top_k,
88
+ filter_kwargs={'k': 1},
89
+ temperature=0.9,
90
+ return_prime=False,
91
+ verbose=False)
 
 
 
 
 
 
 
 
 
 
 
92
 
93
+ result = tuple(out[0].tolist())
94
+
95
+ return result
96
 
97
  # =================================================================================================
98