yifan0sun committed
Commit 2015ce0 · verified · 1 Parent(s): 1f7f45a

Update ROBERTAmodel.py

Files changed (1)
  1. ROBERTAmodel.py +40 -2
ROBERTAmodel.py CHANGED
@@ -183,6 +183,42 @@ class RoBERTaVisualizer(TransformerVisualizer):
         print('Average attentions per layer')
         mean_attns = [a.squeeze(0).mean(dim=0).detach().cpu() for a in attentions]
 
+
+
+
+        def scalar_outputs(inputs_embeds):
+
+            outputs = self.model.roberta(
+                inputs_embeds=inputs_embeds,
+                attention_mask=inputs["attention_mask"],
+                output_attentions=True
+            )
+            attentions_condensed = [a.mean(dim=0).mean(dim=0).sum(dim=0) for a in attentions]
+            print([a.shape for a in attentions_condensed])
+            attentions_condensed = torch.vstack(attentions_condensed)
+            print(attentions_condensed.shape)
+            return attentions_condensed
+
+        start = time.time()
+        jac = torch.autograd.functional.jacobian(scalar_outputs, inputs_embeds).norm(dim=-1).squeeze(dim=2)
+        print(jac.shape)
+        grad_matrices_all = [jac[i] for i in range(jac.size(0))]
+
+        print(31,time.time()-start)
+        attn_matrices_all = []
+        for target_layer in range(len(attentions)):
+            #grad_matrix, attn_matrix = self.get_grad_attn_matrix(inputs_embeds, attentions, mean_attns, target_layer)
+
+            attn_matrix = mean_attns[target_layer]
+            seq_len = attn_matrix.shape[0]
+            attn_matrix = attn_matrix[:seq_len, :seq_len]
+            attn_matrices_all.append(attn_matrix.tolist())
+        print(3,time.time()-start)
+
+
+
+        """
+
         attn_matrices_all = []
         grad_matrices_all = []
         for target_layer in range(len(attentions)):
@@ -194,7 +230,7 @@ class RoBERTaVisualizer(TransformerVisualizer):
             attn_matrices_all.append(attn_matrix.tolist())
 
 
-        """
+
         start = time.time()
         def scalar_outputs(inputs_embeds):
 
@@ -210,7 +246,8 @@ class RoBERTaVisualizer(TransformerVisualizer):
 
         grad_matrices_all.append(jac.tolist())
         print(1,time.time()-start)
-        """
+
+
         start = time.time()
         grad_norms_list = []
         scalar_layer = attentions[target_layer].mean(dim=0).mean(dim=0)
@@ -221,6 +258,7 @@ class RoBERTaVisualizer(TransformerVisualizer):
             grad_norms = grad.norm(dim=1)
             grad_norms_list.append(grad_norms.unsqueeze(1))
             print(2,time.time()-start)
+        """
 
         return grad_matrices_all, attn_matrices_all
 
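In effect, the commit disables the previous per-layer gradient loop (the old code is now wrapped between the added `"""` markers, turning it into an inert string literal) and replaces it with a single `torch.autograd.functional.jacobian` call: `scalar_outputs` re-runs `self.model.roberta` on `inputs_embeds`, condenses each layer's attention map to a per-token vector, stacks the layers, and the norm over the hidden dimension then gives one seq_len × seq_len gradient matrix per layer. The sketch below reproduces that pattern outside the repository. It is illustrative only: `roberta-base`, the example sentence, and names such as `condensed_attentions` and `grad_matrices` are assumptions rather than the repo's own API, and the sketch reads the attentions from its own forward pass so that the Jacobian actually depends on the embeddings.

# Minimal, self-contained sketch of the batched-Jacobian pattern used in the diff.
# Assumptions: roberta-base from Hugging Face, batch size 1, illustrative names.
import torch
from transformers import RobertaModel, RobertaTokenizer

tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
model = RobertaModel.from_pretrained("roberta-base")
model.eval()  # disable dropout so the Jacobian is deterministic

enc = tokenizer("A short example sentence.", return_tensors="pt")
# Word embeddings only; RoBERTa adds position embeddings itself when called
# with inputs_embeds, mirroring the inputs_embeds path in the diff.
inputs_embeds = model.embeddings.word_embeddings(enc["input_ids"])

def condensed_attentions(embeds):
    # Re-run the encoder from the embeddings so the attentions depend on them.
    out = model(inputs_embeds=embeds,
                attention_mask=enc["attention_mask"],
                output_attentions=True)
    # Each layer: (1, heads, seq, seq) -> mean over batch and heads,
    # sum over the query axis -> a per-token vector of length seq.
    per_layer = [a.mean(dim=0).mean(dim=0).sum(dim=0) for a in out.attentions]
    return torch.vstack(per_layer)  # (num_layers, seq)

# One Jacobian call covers every layer at once:
# shape (num_layers, seq, 1, seq, hidden) -> norm over hidden,
# squeeze the batch axis -> (num_layers, seq, seq), one matrix per layer.
jac = torch.autograd.functional.jacobian(condensed_attentions, inputs_embeds)
grad_matrices = jac.norm(dim=-1).squeeze(dim=2)
print(grad_matrices.shape)

The single Jacobian call amortizes the forward and backward work across all layers, which appears to be the point of the timing comparison in the diff (the new `print(31, ...)` / `print(3, ...)` stamps versus the old `print(1, ...)` / `print(2, ...)` ones).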