Spaces: Running on T4
Update ROBERTAmodel.py

ROBERTAmodel.py CHANGED (+5 -35)
@@ -194,7 +194,7 @@ class RoBERTaVisualizer(TransformerVisualizer):
         attn_matrices_all.append(attn_matrix.tolist())
 
 
-
+        """
         start = time.time()
         def scalar_outputs(inputs_embeds):
 
@@ -210,14 +210,12 @@ class RoBERTaVisualizer(TransformerVisualizer):
 
             grad_matrices_all.append(jac.tolist())
         print(1,time.time()-start)
-
+        """
         start = time.time()
         grad_norms_list = []
-
+        scalar_layer = attentions[target_layer].mean(dim=0).mean(dim=0)
         for k in range(seq_len):
-            scalar =
-            scalar = scalar[:, k].sum()
-
+            scalar = scalar_layer[:, k].sum()
             grad = torch.autograd.grad(scalar, inputs_embeds, retain_graph=True)[0].squeeze(0)
 
             grad_norms = grad.norm(dim=1)
@@ -225,32 +223,4 @@ class RoBERTaVisualizer(TransformerVisualizer):
         print(2,time.time()-start)
 
         return grad_matrices_all, attn_matrices_all
-
-    def get_grad_attn_matrix(self,inputs_embeds, attentions, mean_attns, target_layer):
-
-        attn_matrix = mean_attns[target_layer]
-        seq_len = attn_matrix.shape[0]
-        attn_layer = attentions[target_layer].squeeze(0).mean(dim=0) # [seq, seq]
-        """
-        print('Computing grad norms')
-        grad_norms_list = []
-
-        for k in range(seq_len):
-            scalar = attn_layer[:, k].sum()
-
-            grad = torch.autograd.grad(scalar, inputs_embeds, retain_graph=True)[0].squeeze(0)
-
-            grad_norms = grad.norm(dim=1)
-            grad_norms_list.append(grad_norms.unsqueeze(1))
-
-            grad_matrix = torch.cat(grad_norms_list, dim=1)
-
-
-
-
-        grad_matrix = grad_matrix[:seq_len, :seq_len]
-        """
-        attn_matrix = attn_matrix[:seq_len, :seq_len]
-        grad_matrix = attn_matrix
-
-        return grad_matrix, attn_matrix
+
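For context, here is a minimal standalone sketch of the per-token gradient-norm computation that the updated loop performs: average the target layer's attention over batch and heads, take the total attention flowing into each token k as a scalar, and backpropagate it to the input embeddings with torch.autograd.grad. The model name roberta-base, the example sentence, and variable names such as enc and grad_matrix are illustrative assumptions, not taken from the Space's code, which wires inputs_embeds, attentions, and target_layer through its own RoBERTaVisualizer methods.

import torch
from transformers import RobertaModel, RobertaTokenizer

# Illustrative setup (assumption: roberta-base; the Space builds these tensors elsewhere).
tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
model = RobertaModel.from_pretrained("roberta-base", output_attentions=True)
model.eval()

enc = tokenizer("Gradient attribution example.", return_tensors="pt")
# Build the embeddings explicitly so gradients can flow back to them.
inputs_embeds = model.embeddings.word_embeddings(enc["input_ids"]).detach()
inputs_embeds.requires_grad_(True)

outputs = model(inputs_embeds=inputs_embeds, attention_mask=enc["attention_mask"])
attentions = outputs.attentions  # tuple of [batch, heads, seq, seq] tensors
target_layer = 0

# Average over batch and heads -> [seq, seq], matching the added scalar_layer line.
scalar_layer = attentions[target_layer].mean(dim=0).mean(dim=0)
seq_len = scalar_layer.shape[0]

grad_norms_list = []
for k in range(seq_len):
    # Total attention received by token k across all query positions.
    scalar = scalar_layer[:, k].sum()
    # Gradient of that scalar w.r.t. the embeddings: [1, seq, hidden] -> [seq, hidden].
    grad = torch.autograd.grad(scalar, inputs_embeds, retain_graph=True)[0].squeeze(0)
    grad_norms_list.append(grad.norm(dim=1).unsqueeze(1))  # [seq, 1]

grad_matrix = torch.cat(grad_norms_list, dim=1)  # [seq, seq]
print(grad_matrix.shape)

Column k of grad_matrix holds the per-token embedding-gradient norms for the attention mass arriving at token k, which appears to be the quantity the visualizer pairs with the mean attention matrix it already returns.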