Update app.py
Browse files
app.py
CHANGED
@@ -29,8 +29,8 @@ def classify_GPU(input_data):
|
|
29 |
|
30 |
print('Loading model...')
|
31 |
|
32 |
-
SEQ_LEN =
|
33 |
-
PAD_IDX =
|
34 |
DEVICE = 'cuda' # 'cuda'
|
35 |
|
36 |
# instantiate the model
|
@@ -38,18 +38,21 @@ def classify_GPU(input_data):
|
|
38 |
model = TransformerWrapper(
|
39 |
num_tokens = PAD_IDX+1,
|
40 |
max_seq_len = SEQ_LEN,
|
41 |
-
attn_layers = Decoder(dim = 1024, depth =
|
42 |
-
|
|
|
|
|
43 |
|
44 |
-
model =
|
45 |
|
46 |
model.to(DEVICE)
|
|
|
47 |
print('=' * 70)
|
48 |
|
49 |
print('Loading model checkpoint...')
|
50 |
|
51 |
model.load_state_dict(
|
52 |
-
torch.load('
|
53 |
map_location=DEVICE))
|
54 |
print('=' * 70)
|
55 |
|
@@ -67,12 +70,8 @@ def classify_GPU(input_data):
|
|
67 |
|
68 |
#==================================================================
|
69 |
|
70 |
-
number_of_batches = 1 # @param {type:"slider", min:1, max:100, step:1}
|
71 |
-
|
72 |
-
# @markdown NOTE: You can increase the number of batches on high-ram GPUs for better classification
|
73 |
-
|
74 |
print('=' * 70)
|
75 |
-
print('
|
76 |
print('=' * 70)
|
77 |
print('Classifying...')
|
78 |
|
@@ -80,29 +79,20 @@ def classify_GPU(input_data):
|
|
80 |
|
81 |
model.eval()
|
82 |
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
filter_logits_fn=top_k,
|
94 |
-
filter_kwargs={'k': 1},
|
95 |
-
return_prime=False,
|
96 |
-
verbose=False)
|
97 |
-
|
98 |
-
y = out.tolist()
|
99 |
-
|
100 |
-
output = [l[0] for l in y]
|
101 |
-
result = mode(output)
|
102 |
-
|
103 |
-
results.append(result)
|
104 |
|
105 |
-
|
|
|
|
|
106 |
|
107 |
# =================================================================================================
|
108 |
|
|
|
29 |
|
30 |
print('Loading model...')
|
31 |
|
32 |
+
SEQ_LEN = 1026
|
33 |
+
PAD_IDX = 940
|
34 |
DEVICE = 'cuda' # 'cuda'
|
35 |
|
36 |
# instantiate the model
|
|
|
38 |
model = TransformerWrapper(
|
39 |
num_tokens = PAD_IDX+1,
|
40 |
max_seq_len = SEQ_LEN,
|
41 |
+
attn_layers = Decoder(dim = 1024, depth = 24, heads = 32, attn_flash = True)
|
42 |
+
)
|
43 |
+
|
44 |
+
model = AutoregressiveWrapper(model, ignore_index=PAD_IDX, pad_value=PAD_IDX)
|
45 |
|
46 |
+
model = torch.nn.DataParallel(model)
|
47 |
|
48 |
model.to(DEVICE)
|
49 |
+
|
50 |
print('=' * 70)
|
51 |
|
52 |
print('Loading model checkpoint...')
|
53 |
|
54 |
model.load_state_dict(
|
55 |
+
torch.load('Ultimate_MIDI_Classifier_Trained_Model_29886_steps_0.556_loss_0.8339_acc.pth',
|
56 |
map_location=DEVICE))
|
57 |
print('=' * 70)
|
58 |
|
|
|
70 |
|
71 |
#==================================================================
|
72 |
|
|
|
|
|
|
|
|
|
73 |
print('=' * 70)
|
74 |
+
print('Ultimate MIDI Classifier')
|
75 |
print('=' * 70)
|
76 |
print('Classifying...')
|
77 |
|
|
|
79 |
|
80 |
model.eval()
|
81 |
|
82 |
+
x = torch.tensor(input_data[:1022], dtype=torch.long, device=DEVICE)
|
83 |
+
|
84 |
+
with ctx:
|
85 |
+
out = model.module.generate(x,
|
86 |
+
2,
|
87 |
+
filter_logits_fn=top_k,
|
88 |
+
filter_kwargs={'k': 1},
|
89 |
+
temperature=0.9,
|
90 |
+
return_prime=False,
|
91 |
+
verbose=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
92 |
|
93 |
+
result = tuple(out[0].tolist())
|
94 |
+
|
95 |
+
return result
|
96 |
|
97 |
# =================================================================================================
|
98 |
|