yifan0sun committed
Commit 19c6bd3 · verified · 1 Parent(s): e60ad79

Update BERTmodel.py

Files changed (1)
  1. BERTmodel.py +292 -294
BERTmodel.py CHANGED
@@ -1,295 +1,293 @@
- import torch
- import torch.nn as nn
- from transformers import BertTokenizer
-
- from models import TransformerVisualizer
-
- from transformers import (
-     BertTokenizer,
-     BertForMaskedLM,
-     BertForSequenceClassification,
- )
- import torch.nn.functional as F
- import os
-
- CACHE_DIR = "/data/hf_cache"
-
-
- class BERTVisualizer(TransformerVisualizer):
-     def __init__(self,task):
-         super().__init__()
-         self.task = task
-         print(task,'BERT VIS START')
-
-         TOKENIZER = 'bert-base-uncased'
-         LOCAL_PATH = os.path.join(CACHE_DIR, "tokenizers",TOKENIZER)
-
-
-         self.tokenizer = BertTokenizer.from_pretrained(LOCAL_PATH, local_files_only=True)
-         """
-         try:
-             self.tokenizer = BertTokenizer.from_pretrained(LOCAL_PATH, local_files_only=True)
-         except Exception as e:
-             self.tokenizer = BertTokenizer.from_pretrained(TOKENIZER)
-             self.tokenizer.save_pretrained(LOCAL_PATH)
-         """
-
-
-         print('finding model', self.task)
-         if self.task == 'mlm':
-
-             MODEL = 'bert-base-uncased'
-             LOCAL_PATH = os.path.join(CACHE_DIR, "models",MODEL)
-
-             self.model = BertForMaskedLM.from_pretrained( LOCAL_PATH, local_files_only=True, attn_implementation="eager" ).to(self.device)
-             """
-             try:
-                 self.model = BertForMaskedLM.from_pretrained( LOCAL_PATH, local_files_only=True, attn_implementation="eager" ).to(self.device)
-             except Exception as e:
-                 self.model = BertForMaskedLM.from_pretrained( MODEL, attn_implementation="eager" ).to(self.device)
-                 self.model.save_pretrained(LOCAL_PATH)
-             """
-         elif self.task == 'sst':
-             MODEL = "textattack_bert-base-uncased-SST-2"
-             LOCAL_PATH = os.path.join(CACHE_DIR, "models",MODEL)
-
-             self.model = BertForSequenceClassification.from_pretrained( LOCAL_PATH, local_files_only=True, device_map=None )
-             """
-             try:
-                 self.model = BertForSequenceClassification.from_pretrained( LOCAL_PATH, local_files_only=True, device_map=None )
-             except Exception as e:
-                 self.model = BertForSequenceClassification.from_pretrained( MODEL, device_map=None )
-                 self.model.save_pretrained(LOCAL_PATH)
-             """
-
-         elif self.task == 'mnli':
-             MODEL = 'textattack_bert-base-uncased-MNLI'
-
-
-             LOCAL_PATH = os.path.join(CACHE_DIR, "models",MODEL)
-
-             self.model = BertForSequenceClassification.from_pretrained( LOCAL_PATH, local_files_only=True, device_map=None )
-             """
-             try:
-                 self.model = BertForSequenceClassification.from_pretrained( LOCAL_PATH, local_files_only=True, device_map=None )
-             except Exception as e:
-                 self.model = BertForSequenceClassification.from_pretrained( MODEL, device_map=None)
-                 self.model.save_pretrained(LOCAL_PATH)
-             """
-
-
-
-         else:
-             raise ValueError(f"Unsupported task: {self.task}")
-
-
-
-         print('model found')
-         #self.model.to(self.device)
-         print('self device junk')
-         self.model.eval()
-         print('self model eval')
-         self.num_attention_layers = len(self.model.bert.encoder.layer)
-         print('init finished')
-
-     def tokenize(self, text, hypothesis = ''):
-         print('TTTokenize',text,'H:', hypothesis)
-         if len(hypothesis) == 0:
-             encoded = self.tokenizer(text, return_tensors='pt', return_attention_mask=True)
-         else:
-             encoded = self.tokenizer(text, hypothesis, return_tensors='pt', return_attention_mask=True)
-         input_ids = encoded['input_ids'].to(self.device)
-         attention_mask = encoded['attention_mask'].to(self.device)
-         tokens = self.tokenizer.convert_ids_to_tokens(input_ids[0])
-         return {
-             'input_ids': input_ids,
-             'attention_mask': attention_mask,
-             'tokens': tokens
-         }
-
-
-     def predict(self, task, text, hypothesis='', maskID = None):
-
-         print(task,text,hypothesis)
-
-
-
-         if task == 'mlm':
-
-             # Tokenize and find [MASK] position
-             print('Tokenize and find [MASK] position')
-             inputs = self.tokenizer(text, return_tensors='pt', padding=False, truncation=True)
-             if maskID is not None and 0 <= maskID < inputs['input_ids'].size(1):
-                 inputs['input_ids'][0][maskID] = self.tokenizer.mask_token_id
-                 mask_index = maskID
-             else:
-                 raise ValueError(f"Invalid maskID {maskID} for input length {inputs['input_ids'].size(1)}")
-
-
-
-             # Move to device
-             inputs = {k: v.to(self.device) for k, v in inputs.items()}
-
-             # Get embeddings
-             embedding_layer = self.model.bert.embeddings.word_embeddings
-             inputs_embeds = embedding_layer(inputs['input_ids'])
-
-             # Forward through BERT encoder
-
-             hidden_states = self.model.bert(inputs_embeds=inputs_embeds,
-                                             attention_mask=inputs['attention_mask']).last_hidden_state
-
-             # Predict logits via MLM head
-             logits = self.model.cls(hidden_states)
-             mask_logits = logits[0, mask_index]
-
-             top_probs, top_indices = torch.topk(mask_logits, k=10, dim=-1)
-             top_probs = F.softmax(top_probs, dim=-1)
-             decoded = self.tokenizer.convert_ids_to_tokens(top_indices.tolist())
-
-             return decoded, top_probs
-
-         elif task == 'sst':
-             print('input')
-             inputs = self.tokenizer(text, return_tensors='pt', padding=False, truncation=True).to(self.device)
-             print('output')
-             with torch.no_grad():
-                 outputs = self.model(**inputs)
-                 logits = outputs.logits # shape: [1, 2]
-                 probs = F.softmax(logits, dim=1).squeeze()
-
-             labels = ["negative", "positive"]
-             print('ready to return')
-             return labels, probs
-
-         elif task == 'mnli':
-             inputs = self.tokenizer(text, hypothesis, return_tensors='pt', padding=True, truncation=True).to(self.device)
-
-             with torch.no_grad():
-                 outputs = self.model(**inputs)
-                 logits = outputs.logits
-                 probs = F.softmax(logits, dim=1).squeeze()
-
-             labels = ["entailment", "neutral", "contradiction"]
-             return labels, probs
-
-
-     def get_all_grad_attn_matrix(self, task, sentence, hypothesis='', maskID = 0):
-
-         print('GET GRAD:', task,'sentence',sentence, 'hypothesis', hypothesis)
-
-
-
-         print('Tokenize')
-         if task == 'mnli':
-             inputs = self.tokenizer(sentence, hypothesis, return_tensors='pt', padding=False, truncation=True)
-         elif task == 'mlm':
-             inputs = self.tokenizer(sentence, return_tensors='pt', padding=False, truncation=True)
-             if maskID is not None and 0 <= maskID < inputs['input_ids'].size(1):
-                 inputs['input_ids'][0][maskID] = self.tokenizer.mask_token_id
-             else:
-                 raise ValueError(f"Invalid maskID {maskID} for input length {inputs['input_ids'].size(1)}")
-         else:
-             inputs = self.tokenizer(sentence, return_tensors='pt', padding=False, truncation=True)
-         tokens = self.tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
-
-         inputs = {k: v.to(self.device) for k, v in inputs.items()}
-         print(inputs['input_ids'].shape)
-         print(tokens,len(tokens))
-         print('Input embeddings with grad')
-         embedding_layer = self.model.bert.embeddings.word_embeddings
-         inputs_embeds = embedding_layer(inputs["input_ids"])
-         inputs_embeds.requires_grad_()
-
-         print('Forward pass')
-         outputs = self.model.bert(
-             inputs_embeds=inputs_embeds,
-             attention_mask=inputs["attention_mask"],
-             output_attentions=True
-         )
-         attentions = outputs.attentions # list of [1, heads, seq, seq]
-
-         print('Optional: store average attentions per layer')
-         mean_attns = [a.squeeze(0).mean(dim=0).detach().cpu() for a in attentions]
-
-         attn_matrices_all = []
-         grad_matrices_all = []
-         for target_layer in range(len(attentions)):
-             grad_matrix, attn_matrix = self.get_grad_attn_matrix(inputs_embeds, attentions, mean_attns, target_layer)
-             grad_matrices_all.append(grad_matrix.tolist())
-             attn_matrices_all.append(attn_matrix.tolist())
-         return grad_matrices_all, attn_matrices_all
-
-     def get_grad_attn_matrix(self,inputs_embeds, attentions, mean_attns, target_layer):
-
-
-         attn_matrix = mean_attns[target_layer]
-         seq_len = attn_matrix.shape[0]
-         attn_layer = attentions[target_layer].squeeze(0).mean(dim=0) # [seq, seq]
-
-
-         print('computing gradnorms now')
-
-
-         grad_norms_list = []
-
-         for k in range(seq_len):
-             scalar = attn_layer[:, k].sum() # ✅ total attention received by token k
-
-             # Compute gradient: d scalar / d inputs_embeds
-
-             grad = torch.autograd.grad(scalar, inputs_embeds, retain_graph=True)[0].squeeze(0) # shape: [seq, hidden]
-
-             grad_norms = grad.norm(dim=1) # shape: [seq]
-
-             grad_norms_list.append(grad_norms.unsqueeze(1)) # shape: [seq, 1]
-
-
-         grad_matrix = torch.cat(grad_norms_list, dim=1) # shape: [seq, seq]
-         print('ready to send!')
-
-         grad_matrix = grad_matrix[:seq_len, :seq_len]
-         attn_matrix = attn_matrix[:seq_len, :seq_len]
-
-         #tokens = self.tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
-
-         return grad_matrix, attn_matrix
-
-
- if __name__ == "__main__":
-     import sys
-
-     MODEL_CLASSES = {
-         "bert": BERTVisualizer,
-         "roberta": RoBERTaVisualizer,
-         "distilbert": DistilBERTVisualizer,
-         "bart": BARTVisualizer,
-     }
-
-     # Parse command-line args or fallback to default
-     model_name = sys.argv[1] if len(sys.argv) > 1 else "bert"
-     text = " ".join(sys.argv[2:]) if len(sys.argv) > 2 else "The quick brown fox jumps over the lazy dog."
-
-     if model_name.lower() not in MODEL_CLASSES:
-         print(f"Supported models: {list(MODEL_CLASSES.keys())}")
-         sys.exit(1)
-
-     # Instantiate the visualizer
-     visualizer_class = MODEL_CLASSES[model_name.lower()]
-     visualizer = visualizer_class()
-
-     # Tokenize
-     token_info = visualizer.tokenize(text)
-
-     # Report
-     print(f"\nModel: {model_name}")
-     print(f"Num attention layers: {visualizer.num_attention_layers}")
-     print(f"Tokens: {token_info['tokens']}")
-     print(f"Input IDs: {token_info['input_ids'].tolist()}")
-     print(f"Attention mask: {token_info['attention_mask'].tolist()}")
-
-
- """
- usage for debug:
- python your_file.py bert "The rain in Spain falls mainly on the plain."
  """
 
+ import torch
+ import torch.nn as nn
+ from transformers import BertTokenizer
+
+ from models import TransformerVisualizer
+
+ from transformers import (
+     BertTokenizer,
+     BertForMaskedLM,
+     BertForSequenceClassification,
+ )
+ import torch.nn.functional as F
+ import os, time
+
+ CACHE_DIR = "/data/hf_cache"
+
+
+ class BERTVisualizer(TransformerVisualizer):
+     def __init__(self,task):
+         super().__init__()
+         self.task = task
+         print(task,'BERT VIS START')
+
+         TOKENIZER = 'bert-base-uncased'
+         LOCAL_PATH = os.path.join(CACHE_DIR, "tokenizers",TOKENIZER)
+
+
+         self.tokenizer = BertTokenizer.from_pretrained(LOCAL_PATH, local_files_only=True)
+         """
+         try:
+             self.tokenizer = BertTokenizer.from_pretrained(LOCAL_PATH, local_files_only=True)
+         except Exception as e:
+             self.tokenizer = BertTokenizer.from_pretrained(TOKENIZER)
+             self.tokenizer.save_pretrained(LOCAL_PATH)
+         """
+
+
+         print('finding model', self.task)
+         if self.task == 'mlm':
+
+             MODEL = 'bert-base-uncased'
+             LOCAL_PATH = os.path.join(CACHE_DIR, "models",MODEL)
+
+             self.model = BertForMaskedLM.from_pretrained( LOCAL_PATH, local_files_only=True, attn_implementation="eager" ).to(self.device)
+             """
+             try:
+                 self.model = BertForMaskedLM.from_pretrained( LOCAL_PATH, local_files_only=True, attn_implementation="eager" ).to(self.device)
+             except Exception as e:
+                 self.model = BertForMaskedLM.from_pretrained( MODEL, attn_implementation="eager" ).to(self.device)
+                 self.model.save_pretrained(LOCAL_PATH)
+             """
+         elif self.task == 'sst':
+             MODEL = "textattack_bert-base-uncased-SST-2"
+             LOCAL_PATH = os.path.join(CACHE_DIR, "models",MODEL)
+
+             self.model = BertForSequenceClassification.from_pretrained( LOCAL_PATH, local_files_only=True, device_map=None )
+             """
+             try:
+                 self.model = BertForSequenceClassification.from_pretrained( LOCAL_PATH, local_files_only=True, device_map=None )
+             except Exception as e:
+                 self.model = BertForSequenceClassification.from_pretrained( MODEL, device_map=None )
+                 self.model.save_pretrained(LOCAL_PATH)
+             """
+
+         elif self.task == 'mnli':
+             MODEL = 'textattack_bert-base-uncased-MNLI'
+
+
+             LOCAL_PATH = os.path.join(CACHE_DIR, "models",MODEL)
+
+             self.model = BertForSequenceClassification.from_pretrained( LOCAL_PATH, local_files_only=True, device_map=None )
+             """
+             try:
+                 self.model = BertForSequenceClassification.from_pretrained( LOCAL_PATH, local_files_only=True, device_map=None )
+             except Exception as e:
+                 self.model = BertForSequenceClassification.from_pretrained( MODEL, device_map=None)
+                 self.model.save_pretrained(LOCAL_PATH)
+             """
+
+
+
+         else:
+             raise ValueError(f"Unsupported task: {self.task}")
+
+
+
+         print('model found')
+         self.model.to(self.device)
+         # Force materialization of all layers (avoids meta device errors)
+         with torch.no_grad():
+             dummy_ids = torch.tensor([[0, 1]], device=self.device)
+             dummy_mask = torch.tensor([[1, 1]], device=self.device)
+             _ = self.model(input_ids=dummy_ids, attention_mask=dummy_mask)
+         self.model.eval()
+         print('self model eval')
+         self.num_attention_layers = len(self.model.bert.encoder.layer)
+         print('init finished')
+
+     def tokenize(self, text, hypothesis = ''):
+         print('TTTokenize',text,'H:', hypothesis)
+         if len(hypothesis) == 0:
+             encoded = self.tokenizer(text, return_tensors='pt', return_attention_mask=True)
+         else:
+             encoded = self.tokenizer(text, hypothesis, return_tensors='pt', return_attention_mask=True)
+         input_ids = encoded['input_ids'].to(self.device)
+         attention_mask = encoded['attention_mask'].to(self.device)
+         tokens = self.tokenizer.convert_ids_to_tokens(input_ids[0])
+         return {
+             'input_ids': input_ids,
+             'attention_mask': attention_mask,
+             'tokens': tokens
+         }
+
+
+     def predict(self, task, text, hypothesis='', maskID = None):
+
+         print(task,text,hypothesis)
+
+
+
+         if task == 'mlm':
+
+             # Tokenize and find [MASK] position
+             print('Tokenize and find [MASK] position')
+             inputs = self.tokenizer(text, return_tensors='pt', padding=False, truncation=True)
+             if maskID is not None and 0 <= maskID < inputs['input_ids'].size(1):
+                 inputs['input_ids'][0][maskID] = self.tokenizer.mask_token_id
+                 mask_index = maskID
+             else:
+                 raise ValueError(f"Invalid maskID {maskID} for input length {inputs['input_ids'].size(1)}")
+
+
+
+             # Move to device
+             inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+             # Get embeddings
+             embedding_layer = self.model.bert.embeddings.word_embeddings
+             inputs_embeds = embedding_layer(inputs['input_ids'])
+
+             # Forward through BERT encoder
+
+             hidden_states = self.model.bert(inputs_embeds=inputs_embeds,
+                                             attention_mask=inputs['attention_mask']).last_hidden_state
+
+             # Predict logits via MLM head
+             logits = self.model.cls(hidden_states)
+             mask_logits = logits[0, mask_index]
+
+             top_probs, top_indices = torch.topk(mask_logits, k=10, dim=-1)
+             top_probs = F.softmax(top_probs, dim=-1)
+             decoded = self.tokenizer.convert_ids_to_tokens(top_indices.tolist())
+
+             return decoded, top_probs
+
+         elif task == 'sst':
+             print('input')
+             inputs = self.tokenizer(text, return_tensors='pt', padding=False, truncation=True).to(self.device)
+             print('output')
+             with torch.no_grad():
+                 outputs = self.model(**inputs)
+                 logits = outputs.logits # shape: [1, 2]
+                 probs = F.softmax(logits, dim=1).squeeze()
+
+             labels = ["negative", "positive"]
+             print('ready to return')
+             return labels, probs
+
+         elif task == 'mnli':
+             inputs = self.tokenizer(text, hypothesis, return_tensors='pt', padding=True, truncation=True).to(self.device)
+
+             with torch.no_grad():
+                 outputs = self.model(**inputs)
+                 logits = outputs.logits
+                 probs = F.softmax(logits, dim=1).squeeze()
+
+             labels = ["entailment", "neutral", "contradiction"]
+             return labels, probs
+
+
+     def get_all_grad_attn_matrix(self, task, sentence, hypothesis='', maskID = 0):
+
+         print('GET GRAD:', task,'sentence',sentence, 'hypothesis', hypothesis)
+
+
+
+         print('Tokenize')
+         if task == 'mnli':
+             inputs = self.tokenizer(sentence, hypothesis, return_tensors='pt', padding=False, truncation=True)
+         elif task == 'mlm':
+             inputs = self.tokenizer(sentence, return_tensors='pt', padding=False, truncation=True)
+             if maskID is not None and 0 <= maskID < inputs['input_ids'].size(1):
+                 inputs['input_ids'][0][maskID] = self.tokenizer.mask_token_id
+             else:
+                 raise ValueError(f"Invalid maskID {maskID} for input length {inputs['input_ids'].size(1)}")
+         else:
+             inputs = self.tokenizer(sentence, return_tensors='pt', padding=False, truncation=True)
+         tokens = self.tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
+
+         inputs = {k: v.to(self.device) for k, v in inputs.items()}
+         print(inputs['input_ids'].shape)
+         print(tokens,len(tokens))
+         print('Input embeddings with grad')
+         embedding_layer = self.model.bert.embeddings.word_embeddings
+         inputs_embeds = embedding_layer(inputs["input_ids"])
+         inputs_embeds.requires_grad_()
+
+         print('Forward pass')
+         outputs = self.model.bert(
+             inputs_embeds=inputs_embeds,
+             attention_mask=inputs["attention_mask"],
+             output_attentions=True
+         )
+
+         attentions = outputs.attentions # list of [1, heads, seq, seq]
+
+         print('Average attentions per layer')
+         mean_attns = [a.squeeze(0).mean(dim=0).detach().cpu() for a in attentions]
+
+
+         def scalar_outputs(inputs_embeds):
+
+             outputs = self.model.bert(
+                 inputs_embeds=inputs_embeds,
+                 attention_mask=inputs["attention_mask"],
+                 output_attentions=True
+             )
+             attentions = outputs.attentions
+             attentions_condensed = [a.mean(dim=0).mean(dim=0).sum(dim=0) for a in attentions]
+             attentions_condensed= torch.vstack(attentions_condensed)
+             return attentions_condensed
+
+         start = time.time()
+         jac = torch.autograd.functional.jacobian(scalar_outputs, inputs_embeds).to(torch.float16)
+         print('time to get jacobian: ', time.time()-start)
+         jac = jac.norm(dim=-1).squeeze(dim=2)
+         seq_len = jac.shape[0]
+         grad_matrices_all = [jac[ii,:,:].tolist() for ii in range(seq_len)]
+
+
+         attn_matrices_all = []
+         for target_layer in range(len(attentions)):
+             #grad_matrix, attn_matrix = self.get_grad_attn_matrix(inputs_embeds, attentions, mean_attns, target_layer)
+
+             attn_matrix = mean_attns[target_layer]
+             seq_len = attn_matrix.shape[0]
+             attn_matrix = attn_matrix[:seq_len, :seq_len]
+             attn_matrices_all.append(attn_matrix.tolist())
+
+
+
+         return grad_matrices_all, attn_matrices_all
+
+
+
+
+ if __name__ == "__main__":
+     import sys
+
+     MODEL_CLASSES = {
+         "bert": BERTVisualizer,
+         "roberta": RoBERTaVisualizer,
+         "distilbert": DistilBERTVisualizer,
+         "bart": BARTVisualizer,
+     }
+
+     # Parse command-line args or fallback to default
+     model_name = sys.argv[1] if len(sys.argv) > 1 else "bert"
+     text = " ".join(sys.argv[2:]) if len(sys.argv) > 2 else "The quick brown fox jumps over the lazy dog."
+
+     if model_name.lower() not in MODEL_CLASSES:
+         print(f"Supported models: {list(MODEL_CLASSES.keys())}")
+         sys.exit(1)
+
+     # Instantiate the visualizer
+     visualizer_class = MODEL_CLASSES[model_name.lower()]
+     visualizer = visualizer_class()
+
+     # Tokenize
+     token_info = visualizer.tokenize(text)
+
+     # Report
+     print(f"\nModel: {model_name}")
+     print(f"Num attention layers: {visualizer.num_attention_layers}")
+     print(f"Tokens: {token_info['tokens']}")
+     print(f"Input IDs: {token_info['input_ids'].tolist()}")
+     print(f"Attention mask: {token_info['attention_mask'].tolist()}")
+
+
+ """
+ usage for debug:
+ python your_file.py bert "The rain in Spain falls mainly on the plain."
  """