Update ROBERTAmodel.py
ROBERTAmodel.py  CHANGED  (+8 -6)
@@ -186,16 +186,23 @@ class RoBERTaVisualizer(TransformerVisualizer):
         print('Average attentions per layer')
         mean_attns = [a.squeeze(0).mean(dim=0).detach().cpu() for a in attentions]
         print(4,time.time()-start)
+        startloop = time.time()
         start = time.time()

         attn_matrices_all = []
         grad_matrices_all = []
         for target_layer in range(len(attentions)):
+            print(5,target_layer, len(attentions), time.time()-start)
+            start = time.time()
             grad_matrix, attn_matrix = self.get_grad_attn_matrix(inputs_embeds, attentions, mean_attns, target_layer)
+            print(6,target_layer, len(attentions), time.time()-start)
+            start = time.time()
             grad_matrices_all.append(grad_matrix.tolist())
             attn_matrices_all.append(attn_matrix.tolist())
+            print(7,target_layer, len(attentions), time.time()-start)
+            start = time.time()

-            print(
+        print(8,time.time()-startloop)
         return grad_matrices_all, attn_matrices_all

     def get_grad_attn_matrix(self,inputs_embeds, attentions, mean_attns, target_layer):
@@ -209,12 +216,10 @@ class RoBERTaVisualizer(TransformerVisualizer):

         attn_layer = attentions[target_layer].squeeze(0).mean(dim=0)  # [seq, seq]

-        print(6,time.time()-start)
         print('Computing grad norms')
         grad_norms_list = []
         for k in range(seq_len):
             start = time.time()
-            print(7,k,time.time()-start)
             scalar = attn_layer[:, k].sum()
             grad = torch.autograd.grad(scalar, inputs_embeds, retain_graph=True)[0].squeeze(0)
             grad_norms = grad.norm(dim=1)
@@ -224,7 +229,6 @@ class RoBERTaVisualizer(TransformerVisualizer):
             grad_norms = grad_norms.to(torch.float16)

             start = time.time()
-            print(8,k,time.time()-start)

             grad_norms_list.append(grad_norms)

@@ -234,7 +238,6 @@ class RoBERTaVisualizer(TransformerVisualizer):
         grad_matrix = grad_matrix[:seq_len, :seq_len]
         attn_matrix = attn_matrix[:seq_len, :seq_len]
         start = time.time()
-        print(10,time.time()-start)

         attn_matrix = torch.round(attn_matrix.float() * 100) / 100
         attn_matrix = attn_matrix.to(torch.float16)
@@ -242,7 +245,6 @@ class RoBERTaVisualizer(TransformerVisualizer):
         grad_matrix = torch.round(grad_matrix.float() * 100) / 100
         grad_matrix = grad_matrix.to(torch.float16)
         start = time.time()
-        print(11,time.time()-start)

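The stage the new prints bracket most tightly is get_grad_attn_matrix, which backpropagates attention mass through the input embeddings. Below is a minimal, self-contained sketch of that autograd pattern; the toy single-head attention (W_q, W_k) is an assumption standing in for RoBERTa's actual forward pass, and only the per-column torch.autograd.grad bookkeeping mirrors the code above.

```python
import torch

torch.manual_seed(0)
seq_len, dim = 6, 16
inputs_embeds = torch.randn(1, seq_len, dim, requires_grad=True)

# Toy single-head attention standing in for RoBERTa's forward pass (assumption).
W_q = torch.randn(dim, dim) / dim ** 0.5
W_k = torch.randn(dim, dim) / dim ** 0.5
q, k = inputs_embeds @ W_q, inputs_embeds @ W_k
attn_layer = torch.softmax(q @ k.transpose(-1, -2) / dim ** 0.5, dim=-1).squeeze(0)  # [seq, seq]

grad_norms_list = []
for col in range(seq_len):
    # Scalar: total attention mass received by key position `col`.
    scalar = attn_layer[:, col].sum()
    # retain_graph=True, as in the diff: the same graph is reused for every column.
    grad = torch.autograd.grad(scalar, inputs_embeds, retain_graph=True)[0].squeeze(0)
    grad_norms_list.append(grad.norm(dim=1))  # one norm per input token

grad_matrix = torch.stack(grad_norms_list, dim=1)  # [seq, seq]
print(grad_matrix.shape)  # torch.Size([6, 6])
```

Each column needs its own backward pass over the same graph, which is why the code keeps retain_graph=True; that repeated traversal, once per key position and per layer, is what makes the per-layer loop worth timing.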
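The paired start = time.time() / numbered print calls this commit adds could equally be written with a small helper. The timed context manager below is hypothetical (not part of this repo); it is only a sketch of the same measurement pattern with less bookkeeping.

```python
import time
from contextlib import contextmanager

@contextmanager
def timed(label):
    # Print the wall-clock time spent inside the `with` block.
    start = time.time()
    yield
    print(f'{label}: {time.time() - start:.3f}s')

# Usage mirroring the instrumented loop (4 stands in for len(attentions)):
with timed('all layers'):                    # plays the role of startloop / print(8, ...)
    for target_layer in range(4):
        with timed(f'layer {target_layer}'): # plays the role of print(5..7, ...)
            time.sleep(0.01)                 # stand-in for get_grad_attn_matrix
```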