Update ROBERTAmodel.py
ROBERTAmodel.py  CHANGED  (+11 -6)
@@ -193,17 +193,20 @@ class RoBERTaVisualizer(TransformerVisualizer):
                 attention_mask=inputs["attention_mask"],
                 output_attentions=True
             )
+            attentions = outputs.attentions
             attentions_condensed = [a.mean(dim=0).mean(dim=0).sum(dim=0) for a in attentions]
-            print([a.shape for a in attentions_condensed])
             attentions_condensed= torch.vstack(attentions_condensed)
-            print(attentions_condensed.shape)
             return attentions_condensed
 
         start = time.time()
-        jac = torch.autograd.functional.jacobian(scalar_outputs, inputs_embeds)
+        jac = torch.autograd.functional.jacobian(scalar_outputs, inputs_embeds)
         print(jac.shape)
-
-
+        jac = jac.norm(dim=-1).squeeze(dim=2)
+        print(jac.shape)
+        seq_len = jac.shape[0]
+        print(seq_len)
+        grad_matrices_all = [jac[ii,:,:].tolist() for ii in range(seq_len)]
+
         print(31,time.time()-start)
         attn_matrices_all = []
         for target_layer in range(len(attentions)):
@@ -212,7 +215,9 @@ class RoBERTaVisualizer(TransformerVisualizer):
             attn_matrix = mean_attns[target_layer]
             seq_len = attn_matrix.shape[0]
             attn_matrix = attn_matrix[:seq_len, :seq_len]
+            print(4,attn_matrix.shape)
             attn_matrices_all.append(attn_matrix.tolist())
+
             print(3,time.time()-start)
 
 
@@ -259,6 +264,6 @@ class RoBERTaVisualizer(TransformerVisualizer):
             grad_norms_list.append(grad_norms.unsqueeze(1))
             print(2,time.time()-start)
         """
-
+        #print(grad_matrices_all)
         return grad_matrices_all, attn_matrices_all
 
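For context, the added lines compute the gradient matrices from a Jacobian: the condensed attentions (one vector per layer, attention mass summed over query positions) are differentiated with respect to the input embeddings, and the norm over the hidden dimension yields one seq_len x seq_len influence matrix per layer, collected in grad_matrices_all. Below is a minimal, self-contained sketch of that computation under stated assumptions: the model name, tokenizer calls, and the use of the word-embedding table as the Jacobian input are illustrative guesses, not taken from this repository; only the scalar_outputs logic and the jacobian/norm/squeeze steps mirror the diff above.

# Hedged sketch of the attribution step this commit adds (assumptions noted above).
import torch
from transformers import RobertaModel, RobertaTokenizer

tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
model = RobertaModel.from_pretrained("roberta-base")
model.eval()

inputs = tokenizer("The quick brown fox", return_tensors="pt")
# Differentiate w.r.t. the word embeddings so the Jacobian has an explicit input tensor.
inputs_embeds = model.embeddings.word_embeddings(inputs["input_ids"])

def scalar_outputs(embeds):
    # Returns a (num_layers, seq_len) tensor: per-layer attention mass received by
    # each token (mean over batch and heads, sum over query positions), as in the diff.
    outputs = model(
        inputs_embeds=embeds,
        attention_mask=inputs["attention_mask"],
        output_attentions=True,
    )
    attentions = outputs.attentions
    attentions_condensed = [a.mean(dim=0).mean(dim=0).sum(dim=0) for a in attentions]
    return torch.vstack(attentions_condensed)

# jac: (num_layers, seq_len, batch, seq_len, hidden)
jac = torch.autograd.functional.jacobian(scalar_outputs, inputs_embeds)
# Collapse the hidden dimension and drop the batch axis, leaving one
# (seq_len x seq_len) influence matrix per layer.
jac = jac.norm(dim=-1).squeeze(dim=2)
grad_matrices_all = [jac[layer].tolist() for layer in range(jac.shape[0])]

Note that jacobian() runs a forward pass and then one vector-Jacobian (backward) pass per output element, which gets expensive for long sequences; the print(31,time.time()-start) line in the commit appears to be timing exactly that step.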