yifan0sun commited on
Commit
b9eb0a3
·
verified ·
1 Parent(s): 7c7c06a

Update models.py

Browse files
Files changed (1) hide show
  1. models.py +16 -15
models.py CHANGED
@@ -1,16 +1,17 @@
1
- import torch
2
-
3
-
4
-
5
-
6
- class TransformerVisualizer():
7
- def __init__(self):
8
- self.device = torch.device('cpu')
9
-
10
- def predict(self, task, text):
11
- return task, text,1
12
-
13
-
14
- def get_attention_gradient_matrix(self, task, text, target_layer):
15
- return task, text,target_layer,1
 
16
 
 
1
+ import torch
2
+
3
+
4
+
5
+
6
+ class TransformerVisualizer():
7
+ def __init__(self):
8
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9
+
10
+
11
+ def predict(self, task, text):
12
+ return task, text,1
13
+
14
+
15
+ def get_attention_gradient_matrix(self, task, text, target_layer):
16
+ return task, text,target_layer,1
17