Update ROBERTAmodel.py
ROBERTAmodel.py  CHANGED  (+14 -1)
@@ -200,6 +200,7 @@ class RoBERTaVisualizer(TransformerVisualizer):
 
     def get_grad_attn_matrix(self,inputs_embeds, attentions, mean_attns, target_layer):
 
+        start = time.time()
         attn_matrix = mean_attns[target_layer]
         seq_len = attn_matrix.shape[0]
 
@@ -208,9 +209,12 @@ class RoBERTaVisualizer(TransformerVisualizer):
 
         attn_layer = attentions[target_layer].squeeze(0).mean(dim=0) # [seq, seq]
 
+        print(6,time.time()-start)
         print('Computing grad norms')
         grad_norms_list = []
         for k in range(seq_len):
+            start = time.time()
+            print(7,k,time.time()-start)
             scalar = attn_layer[:, k].sum()
             grad = torch.autograd.grad(scalar, inputs_embeds, retain_graph=True)[0].squeeze(0)
             grad_norms = grad.norm(dim=1)
@@ -218,18 +222,27 @@ class RoBERTaVisualizer(TransformerVisualizer):
 
             grad_norms = torch.round(grad_norms.unsqueeze(1).float() * 100) / 100
             grad_norms = grad_norms.to(torch.float16)
+
+            start = time.time()
+            print(8,k,time.time()-start)
 
             grad_norms_list.append(grad_norms)
-
+
+        start = time.time()
+        print(9,time.time()-start)
         grad_matrix = torch.cat(grad_norms_list, dim=1)
         grad_matrix = grad_matrix[:seq_len, :seq_len]
         attn_matrix = attn_matrix[:seq_len, :seq_len]
+        start = time.time()
+        print(10,time.time()-start)
 
         attn_matrix = torch.round(attn_matrix.float() * 100) / 100
         attn_matrix = attn_matrix.to(torch.float16)
 
         grad_matrix = torch.round(grad_matrix.float() * 100) / 100
         grad_matrix = grad_matrix.to(torch.float16)
+        start = time.time()
+        print(11,time.time()-start)
 
 
 
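For context, the loop this commit instruments computes, for each attended-to position k, the norm of the gradient of the total attention mass on column k with respect to the input embeddings, one backward pass per column. Below is a minimal, self-contained sketch of that computation; it substitutes a toy softmax attention for the attentions returned by the RoBERTa forward pass, and the names embeds and toy_attention are illustrative, not part of the Space's code.

import torch

torch.manual_seed(0)
seq_len, dim = 6, 16
# Stand-in for inputs_embeds; the real method receives the embeddings fed to RoBERTa.
embeds = torch.randn(1, seq_len, dim, requires_grad=True)

# Toy single-head attention, used only so the example runs without a model download.
def toy_attention(x):
    scores = x @ x.transpose(-1, -2) / dim ** 0.5
    return scores.softmax(dim=-1)              # [1, seq, seq]

# Analogous to attentions[target_layer] averaged over heads in the real method.
attn_layer = toy_attention(embeds).squeeze(0)   # [seq, seq]

# One backward pass per attended-to column k, as in the timed loop.
grad_norms_list = []
for k in range(seq_len):
    scalar = attn_layer[:, k].sum()             # total attention mass on column k
    grad = torch.autograd.grad(scalar, embeds, retain_graph=True)[0].squeeze(0)
    grad_norms_list.append(grad.norm(dim=1, keepdim=True))   # per-token gradient norm, [seq, 1]

grad_matrix = torch.cat(grad_norms_list, dim=1)  # [seq, seq]
print(grad_matrix.shape)

Because the same graph is reused for every column, each call needs retain_graph=True and the cost grows with seq_len backward passes; that per-iteration cost is presumably what the new print(7, ...) and print(8, ...) probes are meant to expose.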