yifan0sun committed
Commit d594b53 · verified · 1 Parent(s): 67aa9c5

Update ROBERTAmodel.py

Files changed (1):
  1. ROBERTAmodel.py  +14 -1
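
The diff below threads wall-clock checkpoints through get_grad_attn_matrix so each stage of the gradient/attention computation can be timed. A minimal sketch of the pattern being added, assuming the module already imports time (the bare integers are just stage labels, as in the diff):

    import time

    start = time.time()
    # ... run one stage of the computation ...
    print(6, time.time() - start)  # seconds elapsed since `start`, tagged with stage label 6
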
ROBERTAmodel.py CHANGED

@@ -200,6 +200,7 @@ class RoBERTaVisualizer(TransformerVisualizer):
 
     def get_grad_attn_matrix(self,inputs_embeds, attentions, mean_attns, target_layer):
 
+        start = time.time()
         attn_matrix = mean_attns[target_layer]
         seq_len = attn_matrix.shape[0]
 
@@ -208,9 +209,12 @@ class RoBERTaVisualizer(TransformerVisualizer):
 
         attn_layer = attentions[target_layer].squeeze(0).mean(dim=0) # [seq, seq]
 
+        print(6,time.time()-start)
         print('Computing grad norms')
         grad_norms_list = []
         for k in range(seq_len):
+            start = time.time()
+            print(7,k,time.time()-start)
             scalar = attn_layer[:, k].sum()
             grad = torch.autograd.grad(scalar, inputs_embeds, retain_graph=True)[0].squeeze(0)
             grad_norms = grad.norm(dim=1)
@@ -218,18 +222,27 @@ class RoBERTaVisualizer(TransformerVisualizer):
 
             grad_norms = torch.round(grad_norms.unsqueeze(1).float() * 100) / 100
             grad_norms = grad_norms.to(torch.float16)
+
+            start = time.time()
+            print(8,k,time.time()-start)
 
             grad_norms_list.append(grad_norms)
-
+
+        start = time.time()
+        print(9,time.time()-start)
         grad_matrix = torch.cat(grad_norms_list, dim=1)
         grad_matrix = grad_matrix[:seq_len, :seq_len]
         attn_matrix = attn_matrix[:seq_len, :seq_len]
+        start = time.time()
+        print(10,time.time()-start)
 
         attn_matrix = torch.round(attn_matrix.float() * 100) / 100
         attn_matrix = attn_matrix.to(torch.float16)
 
         grad_matrix = torch.round(grad_matrix.float() * 100) / 100
         grad_matrix = grad_matrix.to(torch.float16)
+        start = time.time()
+        print(11,time.time()-start)
 
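
For context on what is being timed: the loop builds a gradient-based saliency matrix. For each token position k it sums the attention column for k in the target layer, backpropagates that scalar to the input embeddings, and stores the per-token gradient norms; the columns are then concatenated, rounded, and cast to float16. A self-contained sketch of the same technique outside the visualizer class (the checkpoint name, tokenizer usage, and variable names here are illustrative assumptions, not the repo's API):

    import torch
    from transformers import RobertaModel, RobertaTokenizer

    tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
    model = RobertaModel.from_pretrained("roberta-base", output_attentions=True)
    model.eval()

    enc = tokenizer("A short example sentence.", return_tensors="pt")

    # Build the input embeddings as a leaf tensor so gradients can be taken w.r.t. them.
    inputs_embeds = model.embeddings.word_embeddings(enc["input_ids"]).detach().requires_grad_(True)
    outputs = model(inputs_embeds=inputs_embeds, attention_mask=enc["attention_mask"])

    target_layer = 0
    # attentions[layer] has shape [batch, heads, seq, seq]; average over heads as the diff does.
    attn_layer = outputs.attentions[target_layer].squeeze(0).mean(dim=0)  # [seq, seq]
    seq_len = attn_layer.shape[0]

    grad_norms_list = []
    for k in range(seq_len):
        # Total attention received by position k across all query positions.
        scalar = attn_layer[:, k].sum()
        # retain_graph=True because the same forward graph is reused for every k.
        grad = torch.autograd.grad(scalar, inputs_embeds, retain_graph=True)[0].squeeze(0)
        grad_norms_list.append(grad.norm(dim=1, keepdim=True))  # one norm per input token

    grad_matrix = torch.cat(grad_norms_list, dim=1)  # [seq, seq]; column k is the saliency for token k
    # As in the method, round to two decimals and downcast to float16.
    grad_matrix = (torch.round(grad_matrix.float() * 100) / 100).to(torch.float16)
    print(grad_matrix.shape)

Because retain_graph=True keeps the whole forward graph alive while seq_len separate backward passes run, the loop's cost grows with sequence length, which is presumably what the new per-iteration prints are meant to measure.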