yifan0sun committed
Commit 04ccab0 · verified · 1 Parent(s): 27ab4ec

Upload 5 files

Files changed (5):
  1. BERTmodel.py +294 -289
  2. DISTILLBERTmodel.py +257 -253
  3. ROBERTAmodel.py +207 -199
  4. models.py +15 -15
  5. server.py +349 -370
BERTmodel.py CHANGED
@@ -1,290 +1,295 @@
- import torch
- import torch.nn as nn
- from transformers import BertTokenizer
-
- from models import TransformerVisualizer
-
- from transformers import (
-     BertTokenizer,
-     BertForMaskedLM,
-     BertForSequenceClassification,
- )
- import torch.nn.functional as F
- import os
-
- CACHE_DIR = "/data/hf_cache"
-
-
- class BERTVisualizer(TransformerVisualizer):
-     def __init__(self,task):
-         super().__init__()
-         self.task = task
-         print(task,'BERTVIS START')
-
-         TOKENIZER = 'bert-base-uncased'
-         LOCAL_PATH = os.path.join(CACHE_DIR, "tokenizers",TOKENIZER.replace("/", "_"))
-
-         try:
-             self.tokenizer = BertTokenizer.from_pretrained(LOCAL_PATH, local_files_only=True)
-         except Exception as e:
-             self.tokenizer = BertTokenizer.from_pretrained(TOKENIZER)
-             self.tokenizer.save_pretrained(LOCAL_PATH)
-
-         print('finding model', self.task)
-         if self.task == 'mlm':
-
-             MODEL = 'bert-base-uncased'
-             LOCAL_PATH = os.path.join(CACHE_DIR, "models",MODEL.replace("/", "_"))
-
-             try:
-                 self.model = BertForMaskedLM.from_pretrained( LOCAL_PATH, local_files_only=True, attn_implementation="eager" ).to(self.device)
-             except Exception as e:
-                 self.model = BertForMaskedLM.from_pretrained( MODEL, attn_implementation="eager" ).to(self.device)
-                 self.model.save_pretrained(LOCAL_PATH)
-
-         elif self.task == 'sst':
-             MODEL = "textattack/bert-base-uncased-SST-2"
-             LOCAL_PATH = os.path.join(CACHE_DIR, "models",MODEL.replace("/", "_"))
-
-             try:
-                 self.model = BertForSequenceClassification.from_pretrained( LOCAL_PATH, local_files_only=True, device_map=None )
-             except Exception as e:
-                 self.model = BertForSequenceClassification.from_pretrained( MODEL, device_map=None )
-                 self.model.save_pretrained(LOCAL_PATH)
-
-         elif self.task == 'mnli':
-             MODEL = 'textattack/bert-base-uncased-MNLI'
-
-             LOCAL_PATH = os.path.join(CACHE_DIR, "models",MODEL.replace("/", "_"))
-
-             try:
-                 self.model = BertForSequenceClassification.from_pretrained( LOCAL_PATH, local_files_only=True, device_map=None )
-             except Exception as e:
-                 self.model = BertForSequenceClassification.from_pretrained( MODEL, device_map=None)
-                 self.model.save_pretrained(LOCAL_PATH)
-
-         else:
-             raise ValueError(f"Unsupported task: {self.task}")
-
-         print('model found')
-         #self.model.to(self.device)
-         print('self device junk')
-         self.model.eval()
-         print('self model eval')
-         self.num_attention_layers = len(self.model.bert.encoder.layer)
-         print('init finished')
-
-     def tokenize(self, text, hypothesis = ''):
-         print('TTTokenize',text,'H:', hypothesis)
-         if len(hypothesis) == 0:
-             encoded = self.tokenizer(text, return_tensors='pt', return_attention_mask=True)
-         else:
-             encoded = self.tokenizer(text, hypothesis, return_tensors='pt', return_attention_mask=True)
-         input_ids = encoded['input_ids'].to(self.device)
-         attention_mask = encoded['attention_mask'].to(self.device)
-         tokens = self.tokenizer.convert_ids_to_tokens(input_ids[0])
-         return {
-             'input_ids': input_ids,
-             'attention_mask': attention_mask,
-             'tokens': tokens
-         }
-
-     def predict(self, task, text, hypothesis='', maskID = None):
-
-         print(task,text,hypothesis)
-
-         if task == 'mlm':
-
-             # Tokenize and find [MASK] position
-             print('Tokenize and find [MASK] position')
-             inputs = self.tokenizer(text, return_tensors='pt', padding=False, truncation=True)
-             if maskID is not None and 0 <= maskID < inputs['input_ids'].size(1):
-                 inputs['input_ids'][0][maskID] = self.tokenizer.mask_token_id
-                 mask_index = maskID
-             else:
-                 raise ValueError(f"Invalid maskID {maskID} for input length {inputs['input_ids'].size(1)}")
-
-             # Move to device
-             inputs = {k: v.to(self.device) for k, v in inputs.items()}
-
-             # Get embeddings
-             embedding_layer = self.model.bert.embeddings.word_embeddings
-             inputs_embeds = embedding_layer(inputs['input_ids'])
-
-             # Forward through BERT encoder
-             hidden_states = self.model.bert(inputs_embeds=inputs_embeds,
-                 attention_mask=inputs['attention_mask']).last_hidden_state
-
-             # Predict logits via MLM head
-             logits = self.model.cls(hidden_states)
-             mask_logits = logits[0, mask_index]
-
-             top_probs, top_indices = torch.topk(mask_logits, k=10, dim=-1)
-             top_probs = F.softmax(top_probs, dim=-1)
-             decoded = self.tokenizer.convert_ids_to_tokens(top_indices.tolist())
-
-             return decoded, top_probs
-
-         elif task == 'sst':
-             print('input')
-             inputs = self.tokenizer(text, return_tensors='pt', padding=False, truncation=True).to(self.device)
-             print('output')
-             with torch.no_grad():
-                 outputs = self.model(**inputs)
-                 logits = outputs.logits  # shape: [1, 2]
-                 probs = F.softmax(logits, dim=1).squeeze()
-
-             labels = ["negative", "positive"]
-             print('ready to return')
-             return labels, probs
-
-         elif task == 'mnli':
-             inputs = self.tokenizer(text, hypothesis, return_tensors='pt', padding=True, truncation=True).to(self.device)
-
-             with torch.no_grad():
-                 outputs = self.model(**inputs)
-                 logits = outputs.logits
-                 probs = F.softmax(logits, dim=1).squeeze()
-
-             labels = ["entailment", "neutral", "contradiction"]
-             return labels, probs
-
-     def get_all_grad_attn_matrix(self, task, sentence, hypothesis='', maskID = 0):
-
-         print('GET GRAD:', task,'sentence',sentence, 'hypothesis', hypothesis)
-
-         print('Tokenize')
-         if task == 'mnli':
-             inputs = self.tokenizer(sentence, hypothesis, return_tensors='pt', padding=False, truncation=True)
-         elif task == 'mlm':
-             inputs = self.tokenizer(sentence, return_tensors='pt', padding=False, truncation=True)
-             if maskID is not None and 0 <= maskID < inputs['input_ids'].size(1):
-                 inputs['input_ids'][0][maskID] = self.tokenizer.mask_token_id
-             else:
-                 raise ValueError(f"Invalid maskID {maskID} for input length {inputs['input_ids'].size(1)}")
-         else:
-             inputs = self.tokenizer(sentence, return_tensors='pt', padding=False, truncation=True)
-         tokens = self.tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
-
-         inputs = {k: v.to(self.device) for k, v in inputs.items()}
-         print(inputs['input_ids'].shape)
-         print(tokens,len(tokens))
-         print('Input embeddings with grad')
-         embedding_layer = self.model.bert.embeddings.word_embeddings
-         inputs_embeds = embedding_layer(inputs["input_ids"])
-         inputs_embeds.requires_grad_()
-
-         print('Forward pass')
-         outputs = self.model.bert(
-             inputs_embeds=inputs_embeds,
-             attention_mask=inputs["attention_mask"],
-             output_attentions=True
-         )
-         attentions = outputs.attentions  # list of [1, heads, seq, seq]
-
-         print('Optional: store average attentions per layer')
-         mean_attns = [a.squeeze(0).mean(dim=0).detach().cpu() for a in attentions]
-
-         attn_matrices_all = []
-         grad_matrices_all = []
-         for target_layer in range(len(attentions)):
-             grad_matrix, attn_matrix = self.get_grad_attn_matrix(inputs_embeds, attentions, mean_attns, target_layer)
-             grad_matrices_all.append(grad_matrix.tolist())
-             attn_matrices_all.append(attn_matrix.tolist())
-         return grad_matrices_all, attn_matrices_all
-
-     def get_grad_attn_matrix(self,inputs_embeds, attentions, mean_attns, target_layer):
-
-         attn_matrix = mean_attns[target_layer]
-         seq_len = attn_matrix.shape[0]
-         attn_layer = attentions[target_layer].squeeze(0).mean(dim=0)  # [seq, seq]
-
-         print('computing gradnorms now')
-
-         grad_norms_list = []
-
-         for k in range(seq_len):
-             scalar = attn_layer[:, k].sum()  # ✅ total attention received by token k
-
-             # Compute gradient: d scalar / d inputs_embeds
-             grad = torch.autograd.grad(scalar, inputs_embeds, retain_graph=True)[0].squeeze(0)  # shape: [seq, hidden]
-
-             grad_norms = grad.norm(dim=1)  # shape: [seq]
-
-             grad_norms_list.append(grad_norms.unsqueeze(1))  # shape: [seq, 1]
-
-         grad_matrix = torch.cat(grad_norms_list, dim=1)  # shape: [seq, seq]
-         print('ready to send!')
-
-         grad_matrix = grad_matrix[:seq_len, :seq_len]
-         attn_matrix = attn_matrix[:seq_len, :seq_len]
-
-         #tokens = self.tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
-
-         return grad_matrix, attn_matrix
-
-
- if __name__ == "__main__":
-     import sys
-
-     MODEL_CLASSES = {
-         "bert": BERTVisualizer,
-         "roberta": RoBERTaVisualizer,
-         "distilbert": DistilBERTVisualizer,
-         "bart": BARTVisualizer,
-     }
-
-     # Parse command-line args or fallback to default
-     model_name = sys.argv[1] if len(sys.argv) > 1 else "bert"
-     text = " ".join(sys.argv[2:]) if len(sys.argv) > 2 else "The quick brown fox jumps over the lazy dog."
-
-     if model_name.lower() not in MODEL_CLASSES:
-         print(f"Supported models: {list(MODEL_CLASSES.keys())}")
-         sys.exit(1)
-
-     # Instantiate the visualizer
-     visualizer_class = MODEL_CLASSES[model_name.lower()]
-     visualizer = visualizer_class()
-
-     # Tokenize
-     token_info = visualizer.tokenize(text)
-
-     # Report
-     print(f"\nModel: {model_name}")
-     print(f"Num attention layers: {visualizer.num_attention_layers}")
-     print(f"Tokens: {token_info['tokens']}")
-     print(f"Input IDs: {token_info['input_ids'].tolist()}")
-     print(f"Attention mask: {token_info['attention_mask'].tolist()}")
-
- """
- usage for debug:
- python your_file.py bert "The rain in Spain falls mainly on the plain."
- """
+ import torch
+ import torch.nn as nn
+ from transformers import BertTokenizer
+
+ from models import TransformerVisualizer
+
+ from transformers import (
+     BertTokenizer,
+     BertForMaskedLM,
+     BertForSequenceClassification,
+ )
+ import torch.nn.functional as F
+ import os
+
+ CACHE_DIR = "/data/hf_cache"
+
+
+ class BERTVisualizer(TransformerVisualizer):
+     def __init__(self,task):
+         super().__init__()
+         self.task = task
+         print(task,'BERT VIS START')
+
+         TOKENIZER = 'bert-base-uncased'
+         LOCAL_PATH = os.path.join(CACHE_DIR, "tokenizers",TOKENIZER)
+
+         self.tokenizer = BertTokenizer.from_pretrained(LOCAL_PATH, local_files_only=True)
+         """
+         try:
+             self.tokenizer = BertTokenizer.from_pretrained(LOCAL_PATH, local_files_only=True)
+         except Exception as e:
+             self.tokenizer = BertTokenizer.from_pretrained(TOKENIZER)
+             self.tokenizer.save_pretrained(LOCAL_PATH)
+         """
+
+         print('finding model', self.task)
+         if self.task == 'mlm':
+
+             MODEL = 'bert-base-uncased'
+             LOCAL_PATH = os.path.join(CACHE_DIR, "models",MODEL)
+
+             self.model = BertForMaskedLM.from_pretrained( LOCAL_PATH, local_files_only=True, attn_implementation="eager" ).to(self.device)
+             """
+             try:
+                 self.model = BertForMaskedLM.from_pretrained( LOCAL_PATH, local_files_only=True, attn_implementation="eager" ).to(self.device)
+             except Exception as e:
+                 self.model = BertForMaskedLM.from_pretrained( MODEL, attn_implementation="eager" ).to(self.device)
+                 self.model.save_pretrained(LOCAL_PATH)
+             """
+         elif self.task == 'sst':
+             MODEL = "textattack_bert-base-uncased-SST-2"
+             LOCAL_PATH = os.path.join(CACHE_DIR, "models",MODEL)
+
+             self.model = BertForSequenceClassification.from_pretrained( LOCAL_PATH, local_files_only=True, device_map=None )
+             """
+             try:
+                 self.model = BertForSequenceClassification.from_pretrained( LOCAL_PATH, local_files_only=True, device_map=None )
+             except Exception as e:
+                 self.model = BertForSequenceClassification.from_pretrained( MODEL, device_map=None )
+                 self.model.save_pretrained(LOCAL_PATH)
+             """
+
+         elif self.task == 'mnli':
+             MODEL = 'textattack_bert-base-uncased-MNLI'
+
+             LOCAL_PATH = os.path.join(CACHE_DIR, "models",MODEL)
+
+             self.model = BertForSequenceClassification.from_pretrained( LOCAL_PATH, local_files_only=True, device_map=None )
+             """
+             try:
+                 self.model = BertForSequenceClassification.from_pretrained( LOCAL_PATH, local_files_only=True, device_map=None )
+             except Exception as e:
+                 self.model = BertForSequenceClassification.from_pretrained( MODEL, device_map=None)
+                 self.model.save_pretrained(LOCAL_PATH)
+             """
+
+         else:
+             raise ValueError(f"Unsupported task: {self.task}")
+
+         print('model found')
+         #self.model.to(self.device)
+         print('self device junk')
+         self.model.eval()
+         print('self model eval')
+         self.num_attention_layers = len(self.model.bert.encoder.layer)
+         print('init finished')
+
+     def tokenize(self, text, hypothesis = ''):
+         print('TTTokenize',text,'H:', hypothesis)
+         if len(hypothesis) == 0:
+             encoded = self.tokenizer(text, return_tensors='pt', return_attention_mask=True)
+         else:
+             encoded = self.tokenizer(text, hypothesis, return_tensors='pt', return_attention_mask=True)
+         input_ids = encoded['input_ids'].to(self.device)
+         attention_mask = encoded['attention_mask'].to(self.device)
+         tokens = self.tokenizer.convert_ids_to_tokens(input_ids[0])
+         return {
+             'input_ids': input_ids,
+             'attention_mask': attention_mask,
+             'tokens': tokens
+         }
+
+     def predict(self, task, text, hypothesis='', maskID = None):
+
+         print(task,text,hypothesis)
+
+         if task == 'mlm':
+
+             # Tokenize and find [MASK] position
+             print('Tokenize and find [MASK] position')
+             inputs = self.tokenizer(text, return_tensors='pt', padding=False, truncation=True)
+             if maskID is not None and 0 <= maskID < inputs['input_ids'].size(1):
+                 inputs['input_ids'][0][maskID] = self.tokenizer.mask_token_id
+                 mask_index = maskID
+             else:
+                 raise ValueError(f"Invalid maskID {maskID} for input length {inputs['input_ids'].size(1)}")
+
+             # Move to device
+             inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+             # Get embeddings
+             embedding_layer = self.model.bert.embeddings.word_embeddings
+             inputs_embeds = embedding_layer(inputs['input_ids'])
+
+             # Forward through BERT encoder
+             hidden_states = self.model.bert(inputs_embeds=inputs_embeds,
+                 attention_mask=inputs['attention_mask']).last_hidden_state
+
+             # Predict logits via MLM head
+             logits = self.model.cls(hidden_states)
+             mask_logits = logits[0, mask_index]
+
+             top_probs, top_indices = torch.topk(mask_logits, k=10, dim=-1)
+             top_probs = F.softmax(top_probs, dim=-1)
+             decoded = self.tokenizer.convert_ids_to_tokens(top_indices.tolist())
+
+             return decoded, top_probs
+
+         elif task == 'sst':
+             print('input')
+             inputs = self.tokenizer(text, return_tensors='pt', padding=False, truncation=True).to(self.device)
+             print('output')
+             with torch.no_grad():
+                 outputs = self.model(**inputs)
+                 logits = outputs.logits  # shape: [1, 2]
+                 probs = F.softmax(logits, dim=1).squeeze()
+
+             labels = ["negative", "positive"]
+             print('ready to return')
+             return labels, probs
+
+         elif task == 'mnli':
+             inputs = self.tokenizer(text, hypothesis, return_tensors='pt', padding=True, truncation=True).to(self.device)
+
+             with torch.no_grad():
+                 outputs = self.model(**inputs)
+                 logits = outputs.logits
+                 probs = F.softmax(logits, dim=1).squeeze()
+
+             labels = ["entailment", "neutral", "contradiction"]
+             return labels, probs
+
+     def get_all_grad_attn_matrix(self, task, sentence, hypothesis='', maskID = 0):
+
+         print('GET GRAD:', task,'sentence',sentence, 'hypothesis', hypothesis)
+
+         print('Tokenize')
+         if task == 'mnli':
+             inputs = self.tokenizer(sentence, hypothesis, return_tensors='pt', padding=False, truncation=True)
+         elif task == 'mlm':
+             inputs = self.tokenizer(sentence, return_tensors='pt', padding=False, truncation=True)
+             if maskID is not None and 0 <= maskID < inputs['input_ids'].size(1):
+                 inputs['input_ids'][0][maskID] = self.tokenizer.mask_token_id
+             else:
+                 raise ValueError(f"Invalid maskID {maskID} for input length {inputs['input_ids'].size(1)}")
+         else:
+             inputs = self.tokenizer(sentence, return_tensors='pt', padding=False, truncation=True)
+         tokens = self.tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
+
+         inputs = {k: v.to(self.device) for k, v in inputs.items()}
+         print(inputs['input_ids'].shape)
+         print(tokens,len(tokens))
+         print('Input embeddings with grad')
+         embedding_layer = self.model.bert.embeddings.word_embeddings
+         inputs_embeds = embedding_layer(inputs["input_ids"])
+         inputs_embeds.requires_grad_()
+
+         print('Forward pass')
+         outputs = self.model.bert(
+             inputs_embeds=inputs_embeds,
+             attention_mask=inputs["attention_mask"],
+             output_attentions=True
+         )
+         attentions = outputs.attentions  # list of [1, heads, seq, seq]
+
+         print('Optional: store average attentions per layer')
+         mean_attns = [a.squeeze(0).mean(dim=0).detach().cpu() for a in attentions]
+
+         attn_matrices_all = []
+         grad_matrices_all = []
+         for target_layer in range(len(attentions)):
+             grad_matrix, attn_matrix = self.get_grad_attn_matrix(inputs_embeds, attentions, mean_attns, target_layer)
+             grad_matrices_all.append(grad_matrix.tolist())
+             attn_matrices_all.append(attn_matrix.tolist())
+         return grad_matrices_all, attn_matrices_all
+
+     def get_grad_attn_matrix(self,inputs_embeds, attentions, mean_attns, target_layer):
+
+         attn_matrix = mean_attns[target_layer]
+         seq_len = attn_matrix.shape[0]
+         attn_layer = attentions[target_layer].squeeze(0).mean(dim=0)  # [seq, seq]
+
+         print('computing gradnorms now')
+
+         grad_norms_list = []
+
+         for k in range(seq_len):
+             scalar = attn_layer[:, k].sum()  # ✅ total attention received by token k
+
+             # Compute gradient: d scalar / d inputs_embeds
+             grad = torch.autograd.grad(scalar, inputs_embeds, retain_graph=True)[0].squeeze(0)  # shape: [seq, hidden]
+
+             grad_norms = grad.norm(dim=1)  # shape: [seq]
+
+             grad_norms_list.append(grad_norms.unsqueeze(1))  # shape: [seq, 1]
+
+         grad_matrix = torch.cat(grad_norms_list, dim=1)  # shape: [seq, seq]
+         print('ready to send!')
+
+         grad_matrix = grad_matrix[:seq_len, :seq_len]
+         attn_matrix = attn_matrix[:seq_len, :seq_len]
+
+         #tokens = self.tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
+
+         return grad_matrix, attn_matrix
+
+
+ if __name__ == "__main__":
+     import sys
+
+     MODEL_CLASSES = {
+         "bert": BERTVisualizer,
+         "roberta": RoBERTaVisualizer,
+         "distilbert": DistilBERTVisualizer,
+         "bart": BARTVisualizer,
+     }
+
+     # Parse command-line args or fallback to default
+     model_name = sys.argv[1] if len(sys.argv) > 1 else "bert"
+     text = " ".join(sys.argv[2:]) if len(sys.argv) > 2 else "The quick brown fox jumps over the lazy dog."
+
+     if model_name.lower() not in MODEL_CLASSES:
+         print(f"Supported models: {list(MODEL_CLASSES.keys())}")
+         sys.exit(1)
+
+     # Instantiate the visualizer
+     visualizer_class = MODEL_CLASSES[model_name.lower()]
+     visualizer = visualizer_class()
+
+     # Tokenize
+     token_info = visualizer.tokenize(text)
+
+     # Report
+     print(f"\nModel: {model_name}")
+     print(f"Num attention layers: {visualizer.num_attention_layers}")
+     print(f"Tokens: {token_info['tokens']}")
+     print(f"Input IDs: {token_info['input_ids'].tolist()}")
+     print(f"Attention mask: {token_info['attention_mask'].tolist()}")
+
+ """
+ usage for debug:
+ python your_file.py bert "The rain in Spain falls mainly on the plain."
+ """
DISTILLBERTmodel.py CHANGED
@@ -1,254 +1,258 @@
- import torch
- import torch.nn.functional as F
-
- import os
- from models import TransformerVisualizer
-
- from transformers import (
-     DistilBertTokenizer,
-     DistilBertForMaskedLM, DistilBertForSequenceClassification
- )
-
- CACHE_DIR = "/data/hf_cache"
- class DistilBERTVisualizer(TransformerVisualizer):
-     def __init__(self, task):
-         super().__init__()
-         self.task = task
-
-         TOKENIZER = 'distilbert-base-uncased'
-         LOCAL_PATH = os.path.join(CACHE_DIR, "tokenizers",TOKENIZER.replace("/", "_"))
-
-         try:
-             self.tokenizer = DistilBertTokenizer.from_pretrained(LOCAL_PATH, local_files_only=True)
-         except Exception as e:
-             self.tokenizer = DistilBertTokenizer.from_pretrained(TOKENIZER)
-             self.tokenizer.save_pretrained(LOCAL_PATH)
-
-         print('finding model', self.task)
-         if self.task == 'mlm':
-
-             MODEL = 'distilbert-base-uncased'
-             LOCAL_PATH = os.path.join(CACHE_DIR, "models",MODEL.replace("/", "_"))
-
-             try:
-                 self.model = DistilBertForMaskedLM.from_pretrained( LOCAL_PATH, local_files_only=True )
-             except Exception as e:
-                 self.model = DistilBertForMaskedLM.from_pretrained( MODEL )
-                 self.model.save_pretrained(LOCAL_PATH)
-
-         elif self.task == 'sst':
-             MODEL = 'distilbert-base-uncased-finetuned-sst-2-english'
-             LOCAL_PATH = os.path.join(CACHE_DIR, "models",MODEL.replace("/", "_"))
-
-             try:
-                 self.model = DistilBertForSequenceClassification.from_pretrained( LOCAL_PATH, local_files_only=True )
-             except Exception as e:
-                 self.model = DistilBertForSequenceClassification.from_pretrained( MODEL )
-                 self.model.save_pretrained(LOCAL_PATH)
-
-         elif self.task == 'mnli':
-             MODEL = "textattack/distilbert-base-uncased-MNLI"
-
-             LOCAL_PATH = os.path.join(CACHE_DIR, "models",MODEL.replace("/", "_"))
-
-             try:
-                 self.model = DistilBertForSequenceClassification.from_pretrained( LOCAL_PATH, local_files_only=True)
-             except Exception as e:
-                 self.model = DistilBertForSequenceClassification.from_pretrained( MODEL)
-                 self.model.save_pretrained(LOCAL_PATH)
-
-         else:
-             raise ValueError(f"Unsupported task: {self.task}")
-
-         self.model.eval()
-         self.num_attention_layers = len(self.model.distilbert.transformer.layer)
-
-         self.model.to(self.device)
-
-     def tokenize(self, text, hypothesis = ''):
-
-         if len(hypothesis) == 0:
-             encoded = self.tokenizer(text, return_tensors='pt', return_attention_mask=True,padding=False, truncation=True)
-         else:
-             encoded = self.tokenizer(text, hypothesis, return_tensors='pt', return_attention_mask=True,padding=False, truncation=True)
-
-         input_ids = encoded['input_ids'].to(self.device)
-         attention_mask = encoded['attention_mask'].to(self.device)
-         tokens = self.tokenizer.convert_ids_to_tokens(input_ids[0])
-         return {
-             'input_ids': input_ids,
-             'attention_mask': attention_mask,
-             'tokens': tokens
-         }
-
-     def predict(self, task, text, hypothesis='', maskID = 0):
-
-         if task == 'mlm':
-             inputs = self.tokenizer(text, return_tensors='pt', padding=False, truncation=True)
-             if maskID is not None and 0 <= maskID < inputs['input_ids'].size(1):
-                 inputs['input_ids'][0][maskID] = self.tokenizer.mask_token_id
-                 mask_index = maskID
-             else:
-                 raise ValueError(f"Invalid maskID {maskID} for input of length {inputs['input_ids'].size(1)}")
-             inputs = {k: v.to(self.device) for k, v in inputs.items()}
-
-             with torch.no_grad():
-                 outputs = self.model(**inputs)
-                 logits = outputs.logits
-
-             mask_logits = logits[0, mask_index]
-             top_probs, top_indices = torch.topk(F.softmax(mask_logits, dim=-1), 10)
-             decoded = self.tokenizer.convert_ids_to_tokens(top_indices.tolist())
-             return decoded, top_probs
-
-         elif task == 'sst':
-             inputs = self.tokenizer(text, return_tensors='pt', padding=False, truncation=True).to(self.device)
-
-             with torch.no_grad():
-                 outputs = self.model(**inputs)
-                 logits = outputs.logits
-                 probs = F.softmax(logits, dim=1).squeeze()
-
-             labels = ["negative", "positive"]
-             return labels, probs
-         elif task == 'mnli':
-             inputs = self.tokenizer(text, hypothesis, return_tensors='pt', padding=True, truncation=True).to(self.device)
-
-             with torch.no_grad():
-                 outputs = self.model(**inputs)
-                 logits = outputs.logits
-                 probs = F.softmax(logits, dim=1).squeeze()
-
-             labels = ["entailment", "neutral", "contradiction"]
-             return labels, probs
-
-         else:
-             raise NotImplementedError(f"Task '{task}' not supported for DistilBERT")
-
-     def get_all_grad_attn_matrix(self, task, sentence, hypothesis='', maskID = 0):
-         print(task, sentence,hypothesis)
-
-         print('Tokenize')
-         if task == 'mnli':
-             inputs = self.tokenizer(sentence, hypothesis, return_tensors='pt', padding=False, truncation=True)
-         elif task == 'mlm':
-             inputs = self.tokenizer(sentence, return_tensors='pt', padding=False, truncation=True)
-             if maskID is not None and 0 <= maskID < inputs['input_ids'].size(1):
-                 inputs['input_ids'][0][maskID] = self.tokenizer.mask_token_id
-             else:
-                 print(f"Invalid maskID {maskID} for input of length {inputs['input_ids'].size(1)}")
-                 raise ValueError(f"Invalid maskID {maskID} for input of length {inputs['input_ids'].size(1)}")
-             inputs = {k: v.to(self.device) for k, v in inputs.items()}
-         else:
-             inputs = self.tokenizer(sentence, return_tensors='pt', padding=False, truncation=True)
-         tokens = self.tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
-         print(tokens)
-         inputs = {k: v.to(self.device) for k, v in inputs.items()}
-
-         print('Input embeddings with grad')
-         embedding_layer = self.model.distilbert.embeddings.word_embeddings
-         inputs_embeds = embedding_layer(inputs["input_ids"])
-         inputs_embeds.requires_grad_()
-
-         print('Forward pass')
-         outputs = self.model.distilbert(
-             inputs_embeds=inputs_embeds,
-             attention_mask=inputs["attention_mask"],
-             output_attentions=True,
-         )
-         attentions = outputs.attentions  # list of [1, heads, seq, seq]
-
-         print('Mean attentions per layer')
-         mean_attns = [a.squeeze(0).mean(dim=0).detach().cpu() for a in attentions]
-
-         attn_matrices_all = []
-         grad_matrices_all = []
-         for target_layer in range(len(attentions)):
-             grad_matrix, attn_matrix = self.get_grad_attn_matrix(inputs_embeds, attentions, mean_attns, target_layer)
-             grad_matrices_all.append(grad_matrix.tolist())
-             attn_matrices_all.append(attn_matrix.tolist())
-         return grad_matrices_all, attn_matrices_all
-
-     def get_grad_attn_matrix(self,inputs_embeds, attentions, mean_attns, target_layer):
-         attn_matrix = mean_attns[target_layer]
-         seq_len = attn_matrix.shape[0]
-         attn_layer = attentions[target_layer].squeeze(0).mean(dim=0)
-
-         print('Computing grad norms')
-         grad_norms_list = []
-         for k in range(seq_len):
-             scalar = attn_layer[:, k].sum()
-             grad = torch.autograd.grad(scalar, inputs_embeds, retain_graph=True)[0].squeeze(0)
-             grad_norms = grad.norm(dim=1)
-             grad_norms_list.append(grad_norms.unsqueeze(1))
-
-         grad_matrix = torch.cat(grad_norms_list, dim=1)
-         grad_matrix = grad_matrix[:seq_len, :seq_len]
-         attn_matrix = attn_matrix[:seq_len, :seq_len]
-
-         return grad_matrix, attn_matrix
-
-
- if __name__ == "__main__":
-     import sys
-
-     MODEL_CLASSES = {
-         "bert": BERTVisualizer,
-         "roberta": RoBERTaVisualizer,
-         "distilbert": DistilBERTVisualizer,
-         "bart": BARTVisualizer,
-     }
-
-     # Parse command-line args or fallback to default
-     model_name = sys.argv[1] if len(sys.argv) > 1 else "bert"
-     text = " ".join(sys.argv[2:]) if len(sys.argv) > 2 else "The quick brown fox jumps over the lazy dog."
-
-     if model_name.lower() not in MODEL_CLASSES:
-         print(f"Supported models: {list(MODEL_CLASSES.keys())}")
-         sys.exit(1)
-
-     # Instantiate the visualizer
-     visualizer_class = MODEL_CLASSES[model_name.lower()]
-     visualizer = visualizer_class()
-
-     # Tokenize
-     token_info = visualizer.tokenize(text)
-
-     # Report
-     print(f"\nModel: {model_name}")
-     print(f"Num attention layers: {visualizer.num_attention_layers}")
-     print(f"Tokens: {token_info['tokens']}")
-     print(f"Input IDs: {token_info['input_ids'].tolist()}")
-     print(f"Attention mask: {token_info['attention_mask'].tolist()}")
-
- """
- usage for debug:
- python your_file.py bert "The rain in Spain falls mainly on the plain."
- """
+ import torch
+ import torch.nn.functional as F
+
+ import os
+ from models import TransformerVisualizer
+
+ from transformers import (
+     DistilBertTokenizer,
+     DistilBertForMaskedLM, DistilBertForSequenceClassification
+ )
+
+ CACHE_DIR = "/data/hf_cache"
+ class DistilBERTVisualizer(TransformerVisualizer):
+     def __init__(self, task):
+         super().__init__()
+         self.task = task
+
+         TOKENIZER = 'distilbert-base-uncased'
+         LOCAL_PATH = os.path.join(CACHE_DIR, "tokenizers",TOKENIZER.replace("/", "_"))
+
+         self.tokenizer = DistilBertTokenizer.from_pretrained(LOCAL_PATH, local_files_only=True)
+         """
+         try:
+             self.tokenizer = DistilBertTokenizer.from_pretrained(LOCAL_PATH, local_files_only=True)
+         except Exception as e:
+             self.tokenizer = DistilBertTokenizer.from_pretrained(TOKENIZER)
+             self.tokenizer.save_pretrained(LOCAL_PATH)
+         """
+
+         print('finding model', self.task)
+         if self.task == 'mlm':
+
+             MODEL = 'distilbert-base-uncased'
+             LOCAL_PATH = os.path.join(CACHE_DIR, "models",MODEL)
+
+             self.model = DistilBertForMaskedLM.from_pretrained( LOCAL_PATH, local_files_only=True )
+             """
+             try:
+             except Exception as e:
+                 self.model = DistilBertForMaskedLM.from_pretrained( MODEL )
+                 self.model.save_pretrained(LOCAL_PATH)
+             """
+         elif self.task == 'sst':
+             MODEL = 'distilbert-base-uncased-finetuned-sst-2-english'
+             LOCAL_PATH = os.path.join(CACHE_DIR, "models",MODEL)
+
+             self.model = DistilBertForSequenceClassification.from_pretrained( LOCAL_PATH, local_files_only=True )
+             """
+             try:
+                 self.model = DistilBertForSequenceClassification.from_pretrained( LOCAL_PATH, local_files_only=True )
+             except Exception as e:
+                 self.model = DistilBertForSequenceClassification.from_pretrained( MODEL )
+                 self.model.save_pretrained(LOCAL_PATH)
+             """
+
+         elif self.task == 'mnli':
+             MODEL = "textattack_distilbert-base-uncased-MNLI"
+             LOCAL_PATH = os.path.join(CACHE_DIR, "models",MODEL)
+
+             self.model = DistilBertForSequenceClassification.from_pretrained( LOCAL_PATH, local_files_only=True)
+             """
+             try:
+                 self.model = DistilBertForSequenceClassification.from_pretrained( LOCAL_PATH, local_files_only=True)
+             except Exception as e:
+                 self.model = DistilBertForSequenceClassification.from_pretrained( MODEL)
+                 self.model.save_pretrained(LOCAL_PATH)
+             """
+
+         else:
+             raise ValueError(f"Unsupported task: {self.task}")
+
+         self.model.eval()
+         self.num_attention_layers = len(self.model.distilbert.transformer.layer)
+
+         self.model.to(self.device)
+
+     def tokenize(self, text, hypothesis = ''):
+
+         if len(hypothesis) == 0:
+             encoded = self.tokenizer(text, return_tensors='pt', return_attention_mask=True,padding=False, truncation=True)
+         else:
+             encoded = self.tokenizer(text, hypothesis, return_tensors='pt', return_attention_mask=True,padding=False, truncation=True)
+
+         input_ids = encoded['input_ids'].to(self.device)
+         attention_mask = encoded['attention_mask'].to(self.device)
+         tokens = self.tokenizer.convert_ids_to_tokens(input_ids[0])
+         return {
+             'input_ids': input_ids,
+             'attention_mask': attention_mask,
+             'tokens': tokens
+         }
+
+     def predict(self, task, text, hypothesis='', maskID = 0):
+
+         if task == 'mlm':
+             inputs = self.tokenizer(text, return_tensors='pt', padding=False, truncation=True)
+             if maskID is not None and 0 <= maskID < inputs['input_ids'].size(1):
+                 inputs['input_ids'][0][maskID] = self.tokenizer.mask_token_id
+                 mask_index = maskID
+             else:
+                 raise ValueError(f"Invalid maskID {maskID} for input of length {inputs['input_ids'].size(1)}")
+             inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+             with torch.no_grad():
+                 outputs = self.model(**inputs)
+                 logits = outputs.logits
+
+             mask_logits = logits[0, mask_index]
+             top_probs, top_indices = torch.topk(F.softmax(mask_logits, dim=-1), 10)
+             decoded = self.tokenizer.convert_ids_to_tokens(top_indices.tolist())
+             return decoded, top_probs
+
+         elif task == 'sst':
+             inputs = self.tokenizer(text, return_tensors='pt', padding=False, truncation=True).to(self.device)
+
+             with torch.no_grad():
+                 outputs = self.model(**inputs)
+                 logits = outputs.logits
+                 probs = F.softmax(logits, dim=1).squeeze()
+
+             labels = ["negative", "positive"]
+             return labels, probs
+         elif task == 'mnli':
+             inputs = self.tokenizer(text, hypothesis, return_tensors='pt', padding=True, truncation=True).to(self.device)
+
+             with torch.no_grad():
+                 outputs = self.model(**inputs)
+                 logits = outputs.logits
+                 probs = F.softmax(logits, dim=1).squeeze()
+
+             labels = ["entailment", "neutral", "contradiction"]
+             return labels, probs
+
+         else:
+             raise NotImplementedError(f"Task '{task}' not supported for DistilBERT")
+
+     def get_all_grad_attn_matrix(self, task, sentence, hypothesis='', maskID = 0):
+         print(task, sentence,hypothesis)
+
+         print('Tokenize')
+         if task == 'mnli':
+             inputs = self.tokenizer(sentence, hypothesis, return_tensors='pt', padding=False, truncation=True)
+         elif task == 'mlm':
+             inputs = self.tokenizer(sentence, return_tensors='pt', padding=False, truncation=True)
+             if maskID is not None and 0 <= maskID < inputs['input_ids'].size(1):
+                 inputs['input_ids'][0][maskID] = self.tokenizer.mask_token_id
+             else:
+                 print(f"Invalid maskID {maskID} for input of length {inputs['input_ids'].size(1)}")
+                 raise ValueError(f"Invalid maskID {maskID} for input of length {inputs['input_ids'].size(1)}")
+             inputs = {k: v.to(self.device) for k, v in inputs.items()}
+         else:
+             inputs = self.tokenizer(sentence, return_tensors='pt', padding=False, truncation=True)
+         tokens = self.tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
+         print(tokens)
+         inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+         print('Input embeddings with grad')
+         embedding_layer = self.model.distilbert.embeddings.word_embeddings
+         inputs_embeds = embedding_layer(inputs["input_ids"])
+         inputs_embeds.requires_grad_()
+
+         print('Forward pass')
+         outputs = self.model.distilbert(
+             inputs_embeds=inputs_embeds,
+             attention_mask=inputs["attention_mask"],
+             output_attentions=True,
+         )
+         attentions = outputs.attentions  # list of [1, heads, seq, seq]
+
+         print('Mean attentions per layer')
+         mean_attns = [a.squeeze(0).mean(dim=0).detach().cpu() for a in attentions]
+
+         attn_matrices_all = []
+         grad_matrices_all = []
+         for target_layer in range(len(attentions)):
+             grad_matrix, attn_matrix = self.get_grad_attn_matrix(inputs_embeds, attentions, mean_attns, target_layer)
+             grad_matrices_all.append(grad_matrix.tolist())
+             attn_matrices_all.append(attn_matrix.tolist())
+         return grad_matrices_all, attn_matrices_all
+
+     def get_grad_attn_matrix(self,inputs_embeds, attentions, mean_attns, target_layer):
+         attn_matrix = mean_attns[target_layer]
+         seq_len = attn_matrix.shape[0]
+         attn_layer = attentions[target_layer].squeeze(0).mean(dim=0)
+
+         print('Computing grad norms')
+         grad_norms_list = []
+         for k in range(seq_len):
+             scalar = attn_layer[:, k].sum()
+             grad = torch.autograd.grad(scalar, inputs_embeds, retain_graph=True)[0].squeeze(0)
+             grad_norms = grad.norm(dim=1)
+             grad_norms_list.append(grad_norms.unsqueeze(1))
+
+         grad_matrix = torch.cat(grad_norms_list, dim=1)
+         grad_matrix = grad_matrix[:seq_len, :seq_len]
+         attn_matrix = attn_matrix[:seq_len, :seq_len]
+
+         return grad_matrix, attn_matrix
+
+
+ if __name__ == "__main__":
+     import sys
+
+     MODEL_CLASSES = {
+         "bert": BERTVisualizer,
+         "roberta": RoBERTaVisualizer,
+         "distilbert": DistilBERTVisualizer,
+         "bart": BARTVisualizer,
+     }
+
+     # Parse command-line args or fallback to default
+     model_name = sys.argv[1] if len(sys.argv) > 1 else "bert"
+     text = " ".join(sys.argv[2:]) if len(sys.argv) > 2 else "The quick brown fox jumps over the lazy dog."
+
+     if model_name.lower() not in MODEL_CLASSES:
+         print(f"Supported models: {list(MODEL_CLASSES.keys())}")
+         sys.exit(1)
+
+     # Instantiate the visualizer
+     visualizer_class = MODEL_CLASSES[model_name.lower()]
+     visualizer = visualizer_class()
+
+     # Tokenize
+     token_info = visualizer.tokenize(text)
+
+     # Report
+     print(f"\nModel: {model_name}")
+     print(f"Num attention layers: {visualizer.num_attention_layers}")
+     print(f"Tokens: {token_info['tokens']}")
+     print(f"Input IDs: {token_info['input_ids'].tolist()}")
+     print(f"Attention mask: {token_info['attention_mask'].tolist()}")
+
+ """
+ usage for debug:
+ python your_file.py bert "The rain in Spain falls mainly on the plain."
+ """
ROBERTAmodel.py CHANGED
@@ -1,199 +1,207 @@
- from transformers import RobertaTokenizer, RobertaForMaskedLM
- import torch
- import torch.nn.functional as F
- from models import TransformerVisualizer
- from transformers import (
-     RobertaForMaskedLM, RobertaForSequenceClassification
- )
- import os
-
- CACHE_DIR = "/data/hf_cache"
-
- class RoBERTaVisualizer(TransformerVisualizer):
-     def __init__(self, task):
-         super().__init__()
-         self.task = task
-
-         TOKENIZER = 'roberta-base'
-         LOCAL_PATH = os.path.join(CACHE_DIR, "tokenizers",TOKENIZER.replace("/", "_"))
-
-         try:
-             self.tokenizer = RobertaTokenizer.from_pretrained(LOCAL_PATH, local_files_only=True)
-         except Exception as e:
-             self.tokenizer = RobertaTokenizer.from_pretrained(TOKENIZER)
-             self.tokenizer.save_pretrained(LOCAL_PATH)
-
-         if self.task == 'mlm':
-
-             MODEL = "roberta-base"
-             LOCAL_PATH = os.path.join(CACHE_DIR, "models",MODEL.replace("/", "_"))
-
-             try:
-                 self.model = RobertaForMaskedLM.from_pretrained( LOCAL_PATH, local_files_only=True )
-             except Exception as e:
-                 self.model = RobertaForMaskedLM.from_pretrained( MODEL )
-                 self.model.save_pretrained(LOCAL_PATH)
-
-         elif self.task == 'sst':
-
-             MODEL = 'textattack/roberta-base-SST-2'
-             LOCAL_PATH = os.path.join(CACHE_DIR, "models",MODEL.replace("/", "_"))
-
-             try:
-                 self.model = RobertaForSequenceClassification.from_pretrained( LOCAL_PATH, local_files_only=True )
-             except Exception as e:
-                 self.model = RobertaForSequenceClassification.from_pretrained( MODEL )
-                 self.model.save_pretrained(LOCAL_PATH)
-
-         elif self.task == 'mnli':
-             MODEL = "roberta-large-mnli"
-             LOCAL_PATH = os.path.join(CACHE_DIR, "models",MODEL.replace("/", "_"))
-
-             try:
-                 self.model = RobertaForSequenceClassification.from_pretrained( LOCAL_PATH, local_files_only=True)
-             except Exception as e:
-                 self.model = RobertaForSequenceClassification.from_pretrained( MODEL)
-                 self.model.save_pretrained(LOCAL_PATH)
-
-         self.model.to(self.device)
-         self.model.eval()
-         self.num_attention_layers = self.model.config.num_hidden_layers
-
-     def tokenize(self, text, hypothesis = ''):
-
-         if len(hypothesis) == 0:
-             encoded = self.tokenizer(text, return_tensors='pt', return_attention_mask=True,padding=False, truncation=True)
-         else:
-             encoded = self.tokenizer(text, hypothesis, return_tensors='pt', return_attention_mask=True,padding=False, truncation=True)
-
-         input_ids = encoded['input_ids'].to(self.device)
-         attention_mask = encoded['attention_mask'].to(self.device)
-         tokens = self.tokenizer.convert_ids_to_tokens(input_ids[0])
-         print('First time tokenizing:', tokens, len(tokens))
-
-         response = {
-             'input_ids': input_ids,
-             'attention_mask': attention_mask,
-             'tokens': tokens
-         }
-         print(response)
-         return response
-
-     def predict(self, task, text, hypothesis='', maskID = None):
-
-         if task == 'mlm':
-             inputs = self.tokenizer(text, return_tensors='pt', padding=False, truncation=True)
-             if maskID is not None and 0 <= maskID < inputs['input_ids'].size(1):
-                 inputs['input_ids'][0][maskID] = self.tokenizer.mask_token_id
-                 mask_index = maskID
-             else:
-                 raise ValueError(f"Invalid maskID {maskID} for input of length {inputs['input_ids'].size(1)}")
-             inputs = {k: v.to(self.device) for k, v in inputs.items()}
-
-             with torch.no_grad():
-                 outputs = self.model(**inputs)
-                 logits = outputs.logits
-
-             mask_logits = logits[0, mask_index]
-             top_probs, top_indices = torch.topk(F.softmax(mask_logits, dim=-1), 10)
-             decoded = self.tokenizer.convert_ids_to_tokens(top_indices.tolist())
-             return decoded, top_probs
-
-         elif task == 'sst':
-             inputs = self.tokenizer(text, return_tensors='pt', padding=False, truncation=True).to(self.device)
-
-             with torch.no_grad():
-                 outputs = self.model(**inputs)
-                 logits = outputs.logits
-                 probs = F.softmax(logits, dim=1).squeeze()
-
-             labels = ["negative", "positive"]
-             return labels, probs
-
-         elif task == 'mnli':
-             inputs = self.tokenizer(text, hypothesis, return_tensors='pt', padding=True, truncation=True).to(self.device)
-
-             with torch.no_grad():
-                 outputs = self.model(**inputs)
-                 logits = outputs.logits
-                 probs = F.softmax(logits, dim=1).squeeze()
-
-             labels = ["entailment", "neutral", "contradiction"]
-             return labels, probs
-
-         else:
-             raise NotImplementedError(f"Task '{task}' not supported for RoBERTa")
-
-     def get_all_grad_attn_matrix(self, task, sentence, hypothesis='', maskID = None):
-         print(task, sentence, hypothesis)
-         print('Tokenize')
-         if task == 'mnli':
-             inputs = self.tokenizer(sentence, hypothesis, return_tensors='pt', padding=False, truncation=True)
-         elif task == 'mlm':
-             inputs = self.tokenizer(sentence, return_tensors='pt', padding=False, truncation=True)
-             if maskID is not None and 0 <= maskID < inputs['input_ids'].size(1):
-                 inputs['input_ids'][0][maskID] = self.tokenizer.mask_token_id
-         else:
-             inputs = self.tokenizer(sentence, return_tensors='pt', padding=False, truncation=True)
-         tokens = self.tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
-         print(tokens)
-         inputs = {k: v.to(self.device) for k, v in inputs.items()}
-
-         print('Input embeddings with grad')
-         embedding_layer = self.model.roberta.embeddings.word_embeddings
-         inputs_embeds = embedding_layer(inputs["input_ids"])
-         inputs_embeds.requires_grad_()
-
-         print('Forward pass')
-         outputs = self.model.roberta(
-             inputs_embeds=inputs_embeds,
-             attention_mask=inputs["attention_mask"],
-             output_attentions=True
-         )
-         attentions = outputs.attentions  # list of [1, heads, seq, seq]
-
-         print('Average attentions per layer')
-         mean_attns = [a.squeeze(0).mean(dim=0).detach().cpu() for a in attentions]
-
-         attn_matrices_all = []
-         grad_matrices_all = []
-         for target_layer in range(len(attentions)):
-             grad_matrix, attn_matrix = self.get_grad_attn_matrix(inputs_embeds, attentions, mean_attns, target_layer)
-             grad_matrices_all.append(grad_matrix.tolist())
-             attn_matrices_all.append(attn_matrix.tolist())
-         return grad_matrices_all, attn_matrices_all
-
-     def get_grad_attn_matrix(self,inputs_embeds, attentions, mean_attns, target_layer):
-
-         attn_matrix = mean_attns[target_layer]
-         seq_len = attn_matrix.shape[0]
-         attn_layer = attentions[target_layer].squeeze(0).mean(dim=0)  # [seq, seq]
-
-         print('Computing grad norms')
-         grad_norms_list = []
-         for k in range(seq_len):
-             scalar = attn_layer[:, k].sum()
-             grad = torch.autograd.grad(scalar, inputs_embeds, retain_graph=True)[0].squeeze(0)
-             grad_norms = grad.norm(dim=1)
-             grad_norms_list.append(grad_norms.unsqueeze(1))
-
-         grad_matrix = torch.cat(grad_norms_list, dim=1)
-         grad_matrix = grad_matrix[:seq_len, :seq_len]
-         attn_matrix = attn_matrix[:seq_len, :seq_len]
-
-         return grad_matrix, attn_matrix
+ from transformers import RobertaTokenizer, RobertaForMaskedLM
+ import torch
+ import torch.nn.functional as F
+ from models import TransformerVisualizer
+ from transformers import (
+     RobertaForMaskedLM, RobertaForSequenceClassification
+ )
+ import os
+
+ CACHE_DIR = "/data/hf_cache"
+
+ class RoBERTaVisualizer(TransformerVisualizer):
+     def __init__(self, task):
+         super().__init__()
+         self.task = task
+
+         TOKENIZER = 'roberta-base'
+         LOCAL_PATH = os.path.join(CACHE_DIR, "tokenizers",TOKENIZER)
+
+         self.tokenizer = RobertaTokenizer.from_pretrained(LOCAL_PATH, local_files_only=True)
+         """
+         try:
+             self.tokenizer = RobertaTokenizer.from_pretrained(LOCAL_PATH, local_files_only=True)
+         except Exception as e:
+             self.tokenizer = RobertaTokenizer.from_pretrained(TOKENIZER)
+             self.tokenizer.save_pretrained(LOCAL_PATH)
+         """
+         if self.task == 'mlm':
+
+             MODEL = "roberta-base"
+             LOCAL_PATH = os.path.join(CACHE_DIR, "models",MODEL)
+
+             self.model = RobertaForMaskedLM.from_pretrained( LOCAL_PATH, local_files_only=True )
+             """
+             try:
+                 self.model = RobertaForMaskedLM.from_pretrained( LOCAL_PATH, local_files_only=True )
+             except Exception as e:
+                 self.model = RobertaForMaskedLM.from_pretrained( MODEL )
+                 self.model.save_pretrained(LOCAL_PATH)
+             """
+         elif self.task == 'sst':
+
+             MODEL = 'textattack_roberta-base-SST-2'
+             LOCAL_PATH = os.path.join(CACHE_DIR, "models",MODEL)
+
+             self.model = RobertaForSequenceClassification.from_pretrained( LOCAL_PATH, local_files_only=True )
+             """
+             try:
+                 self.model = RobertaForSequenceClassification.from_pretrained( LOCAL_PATH, local_files_only=True )
+             except Exception as e:
+                 self.model = RobertaForSequenceClassification.from_pretrained( MODEL )
+                 self.model.save_pretrained(LOCAL_PATH)
+             """
+
+         elif self.task == 'mnli':
+             MODEL = "roberta-large-mnli"
+             LOCAL_PATH = os.path.join(CACHE_DIR, "models",MODEL)
+
+             self.model = RobertaForSequenceClassification.from_pretrained( LOCAL_PATH, local_files_only=True)
+             """
+             try:
+                 self.model = RobertaForSequenceClassification.from_pretrained( LOCAL_PATH, local_files_only=True)
+             except Exception as e:
+                 self.model = RobertaForSequenceClassification.from_pretrained( MODEL)
+                 self.model.save_pretrained(LOCAL_PATH)
+             """
+
+         self.model.to(self.device)
+         self.model.eval()
+         self.num_attention_layers = self.model.config.num_hidden_layers
+
+     def tokenize(self, text, hypothesis = ''):
+
+         if len(hypothesis) == 0:
+             encoded = self.tokenizer(text, return_tensors='pt', return_attention_mask=True,padding=False, truncation=True)
+         else:
+             encoded = self.tokenizer(text, hypothesis, return_tensors='pt', return_attention_mask=True,padding=False, truncation=True)
+
+         input_ids = encoded['input_ids'].to(self.device)
+         attention_mask = encoded['attention_mask'].to(self.device)
+         tokens = self.tokenizer.convert_ids_to_tokens(input_ids[0])
+         print('First time tokenizing:', tokens, len(tokens))
+
+         response = {
+             'input_ids': input_ids,
+             'attention_mask': attention_mask,
+             'tokens': tokens
+         }
+         print(response)
+         return response
+
+     def predict(self, task, text, hypothesis='', maskID = None):
+
+         if task == 'mlm':
+             inputs = self.tokenizer(text, return_tensors='pt', padding=False, truncation=True)
+             if maskID is not None and 0 <= maskID < inputs['input_ids'].size(1):
+                 inputs['input_ids'][0][maskID] = self.tokenizer.mask_token_id
+                 mask_index = maskID
+             else:
+                 raise ValueError(f"Invalid maskID {maskID} for input of length {inputs['input_ids'].size(1)}")
+             inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+             with torch.no_grad():
+                 outputs = self.model(**inputs)
+                 logits = outputs.logits
+
+             mask_logits = logits[0, mask_index]
+             top_probs, top_indices = torch.topk(F.softmax(mask_logits, dim=-1), 10)
+             decoded = self.tokenizer.convert_ids_to_tokens(top_indices.tolist())
+             return decoded, top_probs
+
+         elif task == 'sst':
+             inputs = self.tokenizer(text, return_tensors='pt', padding=False, truncation=True).to(self.device)
+
+             with torch.no_grad():
+                 outputs = self.model(**inputs)
+                 logits = outputs.logits
+                 probs = F.softmax(logits, dim=1).squeeze()
+
+             labels = ["negative", "positive"]
+             return labels, probs
+
+         elif task == 'mnli':
+             inputs = self.tokenizer(text, hypothesis, return_tensors='pt', padding=True, truncation=True).to(self.device)
+
+             with torch.no_grad():
+                 outputs = self.model(**inputs)
+                 logits = outputs.logits
+                 probs = F.softmax(logits, dim=1).squeeze()
+
+             labels = ["entailment", "neutral", "contradiction"]
+             return labels, probs
+
+         else:
+             raise NotImplementedError(f"Task '{task}' not supported for RoBERTa")
+
+     def get_all_grad_attn_matrix(self, task, sentence, hypothesis='', maskID = None):
+         print(task, sentence, hypothesis)
+         print('Tokenize')
+         if task == 'mnli':
+             inputs = self.tokenizer(sentence, hypothesis, return_tensors='pt', padding=False, truncation=True)
+         elif task == 'mlm':
+             inputs = self.tokenizer(sentence, return_tensors='pt', padding=False, truncation=True)
+             if maskID is not None and 0 <= maskID < inputs['input_ids'].size(1):
+                 inputs['input_ids'][0][maskID] = self.tokenizer.mask_token_id
+         else:
+             inputs = self.tokenizer(sentence, return_tensors='pt', padding=False, truncation=True)
+         tokens = self.tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
+         print(tokens)
+         inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+         print('Input embeddings with grad')
+         embedding_layer = self.model.roberta.embeddings.word_embeddings
+         inputs_embeds = embedding_layer(inputs["input_ids"])
+         inputs_embeds.requires_grad_()
+
+         print('Forward pass')
+         outputs = self.model.roberta(
+             inputs_embeds=inputs_embeds,
+             attention_mask=inputs["attention_mask"],
+             output_attentions=True
+         )
+         attentions = outputs.attentions  # list of [1, heads, seq, seq]
+
+         print('Average attentions per layer')
+         mean_attns = [a.squeeze(0).mean(dim=0).detach().cpu() for a in attentions]
+
+         attn_matrices_all = []
+         grad_matrices_all = []
+         for target_layer in range(len(attentions)):
+             grad_matrix, attn_matrix = self.get_grad_attn_matrix(inputs_embeds, attentions, mean_attns, target_layer)
+             grad_matrices_all.append(grad_matrix.tolist())
+             attn_matrices_all.append(attn_matrix.tolist())
+         return grad_matrices_all, attn_matrices_all
+
+     def get_grad_attn_matrix(self,inputs_embeds, attentions, mean_attns, target_layer):
+
+         attn_matrix = mean_attns[target_layer]
+         seq_len = attn_matrix.shape[0]
+         attn_layer = attentions[target_layer].squeeze(0).mean(dim=0)  # [seq, seq]
+
+         print('Computing grad norms')
+         grad_norms_list = []
+         for k in range(seq_len):
+             scalar = attn_layer[:, k].sum()
+             grad = torch.autograd.grad(scalar, inputs_embeds, retain_graph=True)[0].squeeze(0)
+             grad_norms = grad.norm(dim=1)
+             grad_norms_list.append(grad_norms.unsqueeze(1))
+
+         grad_matrix = torch.cat(grad_norms_list, dim=1)
+         grad_matrix = grad_matrix[:seq_len, :seq_len]
+         attn_matrix = attn_matrix[:seq_len, :seq_len]
+
+         return grad_matrix, attn_matrix
models.py CHANGED
@@ -1,16 +1,16 @@
- import torch
-
-
- class TransformerVisualizer():
-     def __init__(self):
-         self.device = torch.device('cpu')
-
-     def predict(self, task, text):
-         return task, text,1
-
-     def get_attention_gradient_matrix(self, task, text, target_layer):
-         return task, text,target_layer,1
+ import torch
+
+
+ class TransformerVisualizer():
+     def __init__(self):
+         self.device = torch.device('cpu')
+
+     def predict(self, task, text):
+         return task, text,1
+
+     def get_attention_gradient_matrix(self, task, text, target_layer):
+         return task, text,target_layer,1
server.py CHANGED
@@ -1,370 +1,349 @@
- from fastapi import FastAPI, Request
- from pydantic import BaseModel
- from pathlib import Path
-
- import torch
- from fastapi import UploadFile, File
- import os
- from fastapi.middleware.cors import CORSMiddleware
-
- from ROBERTAmodel import *
- from BERTmodel import *
- from DISTILLBERTmodel import *
-
- import os
- import zipfile
- import shutil
- from fastapi import Form
- from fastapi import UploadFile, File, Form
- from pathlib import Path
-
- VISUALIZER_CLASSES = {
-     "BERT": BERTVisualizer,
-     "RoBERTa": RoBERTaVisualizer,
-     "DistilBERT": DistilBERTVisualizer,
- }
-
- VISUALIZER_CACHE = {}
- app = FastAPI()
-
- app.add_middleware(
-     CORSMiddleware,
-     allow_origins=["*"],
-     allow_credentials=True,
-     allow_methods=["*"],
-     allow_headers=["*"],
- )
-
- MODEL_MAP = {
-     "BERT": "bert-base-uncased",
-     "RoBERTa": "roberta-base",
-     "DistilBERT": "distilbert-base-uncased",
- }
-
- class LoadModelRequest(BaseModel):
-     model: str
-     sentence: str
-     task: str
-     hypothesis: str
-
- class GradAttnModelRequest(BaseModel):
-     model: str
-     task: str
-     sentence: str
-     hypothesis: str
-     maskID: int | None = None
-
- class PredModelRequest(BaseModel):
-     model: str
-     sentence: str
-     task: str
-     hypothesis: str
-     maskID: int | None = None
-
- @app.get("/ping")
- def ping():
-     return {"message": "pong"}
-
-
- @app.post("/upload_to_path")
- async def upload_to_path(
-     file: UploadFile = File(...),
-     dest_path: str = Form(...)
- ):
-     base_path = Path("/data")
-     target_path = base_path / dest_path
-
-     # If the path ends with "/", or is a directory, treat it as a folder
-     if str(dest_path).endswith("/") or target_path.is_dir():
-         target_path = target_path / file.filename
-
-     # Ensure parent directories exist
-     target_path.parent.mkdir(parents=True, exist_ok=True)
-
-     # Write file
-     with open(target_path, "wb") as f:
-         f.write(await file.read())
-
-     return {"status": "uploaded", "path": str(target_path)}
-
-
- @app.post("/make_dir")
- def make_directory(
-     dir_path: str = Form(...)  # e.g., "logs/test_run"
- ):
-     full_dir = Path("/data") / dir_path
-     full_dir.mkdir(parents=True, exist_ok=True)
-     return {"status": "created", "directory": str(full_dir)}
-
-
- @app.get("/list_data")
- def list_data():
-     base_path = Path("/data")
-     all_items = []
-
-     for path in base_path.rglob("*"):  # recursive glob
-         all_items.append({
-             "path": str(path.relative_to(base_path)),
-             "type": "dir" if path.is_dir() else "file",
-             "size": path.stat().st_size if path.is_file() else None
-         })
-
-     return {"items": all_items}
-
-
- @app.post("/purge_data_123456789")
- def purge_data():
-     base_path = Path("/data")
-     if not base_path.exists():
-         return {"status": "error", "message": "/data does not exist"}
-
-     deleted = []
-
-     for child in base_path.iterdir():
-         try:
-             if child.is_file() or child.is_symlink():
-                 child.unlink()
-             elif child.is_dir():
-                 shutil.rmtree(child)
-             deleted.append(str(child.name))
-         except Exception as e:
-             deleted.append(f"FAILED: {child.name} ({e})")
-
-     return {
-         "status": "done",
-         "deleted": deleted,
-         "total": len(deleted)
-     }
-
-
- ##############################################################
-
-
- @app.post("/load_model")
- def load_model(req: LoadModelRequest):
-     print(f"\n--- /load_model request received ---")
-     print(f"Model: {req.model}")
-     print(f"Sentence: {req.sentence}")
-     print(f"Task: {req.task}")
-     print(f"Hypothesis: {req.hypothesis}")
-
-     if req.model in VISUALIZER_CACHE:
-         del VISUALIZER_CACHE[req.model]
-         torch.cuda.empty_cache()
-
-     vis_class = VISUALIZER_CLASSES.get(req.model)
-     if vis_class is None:
-         return {"error": f"Unknown model: {req.model}"}
-
-     print("instantiating visualizer")
-     try:
-         vis = vis_class(task=req.task.lower())
-         print(vis)
-         VISUALIZER_CACHE[req.model] = vis
-         print("Visualizer instantiated")
-     except Exception as e:
-         print("Visualizer init failed:", e)
-         return {"error": f"Instantiation failed: {str(e)}"}
-
-     print('tokenizing')
-     try:
-         if req.task.lower() == 'mnli':
-             token_output = vis.tokenize(req.sentence, hypothesis=req.hypothesis)
-         else:
-             token_output = vis.tokenize(req.sentence)
-         print("Tokenization successful:", token_output["tokens"])
-     except Exception as e:
-         print("Tokenization failed:", e)
-         return {"error": f"Tokenization failed: {str(e)}"}
-
-     print('response ready')
-     response = {
-         "model": req.model,
-         "tokens": token_output['tokens'],
-         "num_layers": vis.num_attention_layers,
-     }
-     print("load model successful")
-     print(response)
-     return response
-
-
- @app.post("/predict_model")
- def predict_model(req: PredModelRequest):
-     print(f"\n--- /predict_model request received ---")
-     print(f"predict: Model: {req.model}")
-     print(f"predict: Task: {req.task}")
-     print(f"predict: sentence: {req.sentence}")
-     print(f"predict: hypothesis: {req.hypothesis}")
-     print(f"predict: maskID: {req.maskID}")
-
-     print('predict: instantiating')
-     try:
-         vis_class = VISUALIZER_CLASSES.get(req.model)
-         if vis_class is None:
-             return {"error": f"Unknown model: {req.model}"}
-         #if any(p.device.type == 'meta' for p in vis.model.parameters()):
-         #    vis.model = torch.nn.Module.to_empty(vis.model, device=torch.device("cpu"))
-
-         vis = vis_class(task=req.task.lower())
-         VISUALIZER_CACHE[req.model] = vis
-         print("Model reloaded and cached.")
-     except Exception as e:
-         return {"error": f"Failed to reload model: {str(e)}"}
-
-     print('predict: meta stuff')
-
-     print('predict: Run prediction')
-     try:
-         if req.task.lower() == 'mnli':
-             decoded, top_probs = vis.predict(req.task.lower(), req.sentence, hypothesis=req.hypothesis)
-         elif req.task.lower() == 'mlm':
-             decoded, top_probs = vis.predict(req.task.lower(), req.sentence, maskID=req.maskID)
-         else:
-             decoded, top_probs = vis.predict(req.task.lower(), req.sentence)
-     except Exception as e:
-         decoded, top_probs = "error", e
-         print(e)
-
-     print('predict: response ready')
-     response = {
-         "decoded": decoded,
-         "top_probs": top_probs.tolist(),
-     }
-     print("predict: predict model successful")
-     if len(decoded) > 5:
-         print([(k, v[:5]) for k, v in response.items()])
-     else:
-         print(response)
-     return response
-
-
- @app.post("/get_grad_attn_matrix")
- def get_grad_attn_matrix(req: GradAttnModelRequest):
-     try:
-         print(f"\n--- /get_grad_matrix request received ---")
-         print(f"grad: Model: {req.model}")
-         print(f"grad: Task: {req.task}")
-         print(f"grad: sentence: {req.sentence}")
-         print(f"grad: hypothesis: {req.hypothesis}")
-         print(f"grad: maskID: {req.maskID}")
-
-         try:
-             vis_class = VISUALIZER_CLASSES.get(req.model)
-             if vis_class is None:
-                 return {"error": f"Unknown model: {req.model}"}
-             #if any(p.device.type == 'meta' for p in vis.model.parameters()):
-             #    vis.model = torch.nn.Module.to_empty(vis.model, device=torch.device("cpu"))
-             vis = vis_class(task=req.task.lower())
-             VISUALIZER_CACHE[req.model] = vis
-             print("Model reloaded and cached.")
-         except Exception as e:
-             return {"error": f"Failed to reload model: {str(e)}"}
-
-         print("run function")
-         try:
-             if req.task.lower() == 'mnli':
-                 grad_matrix, attn_matrix = vis.get_all_grad_attn_matrix(req.task.lower(), req.sentence, hypothesis=req.hypothesis)
-             elif req.task.lower() == 'mlm':
-                 grad_matrix, attn_matrix = vis.get_all_grad_attn_matrix(req.task.lower(), req.sentence, maskID=req.maskID)
-             else:
-                 grad_matrix, attn_matrix = vis.get_all_grad_attn_matrix(req.task.lower(), req.sentence)
-         except Exception as e:
-             print("Exception during grad/attn computation:", e)
-             grad_matrix, attn_matrix = e, e
-
-         response = {
-             "grad_matrix": grad_matrix,
-             "attn_matrix": attn_matrix,
-         }
-         print('grad attn successful')
-         return response
-     except Exception as e:
-         print("SERVER EXCEPTION:", e)
-         return {"error": str(e)}
-
-
- @app.post("/load_all_files")
- def load_all_files():
-     print('load BERTmlm')
-     BERTVisualizer('mlm')
-     print('load BERTmnli')
-     BERTVisualizer('mnli')
-     print('load BERTsst')
-     BERTVisualizer('sst')
-
-     print('load RoBERTamlm')
-     RoBERTaVisualizer('mlm')
-     print('load RoBERTamnli')
-     RoBERTaVisualizer('mnli')
-     print('load RoBERTasst')
-     RoBERTaVisualizer('sst')
-
-     print('load DistilBERTmlm')
-     DistilBERTVisualizer('mlm')
-     print('load DistilBERTmnli')
-     DistilBERTVisualizer('mnli')
-     print('load DistilBERTsst')
-     DistilBERTVisualizer('sst')
 
+ from fastapi import FastAPI, Request
+ from pydantic import BaseModel
+ from pathlib import Path
+
+ import torch
+ from fastapi import UploadFile, File
+ import os
+ from fastapi.middleware.cors import CORSMiddleware
+
+ from ROBERTAmodel import *
+ from BERTmodel import *
+ from DISTILLBERTmodel import *
+
+ import os
+ import zipfile
+ import shutil
+ from fastapi import Form
+ from fastapi import UploadFile, File, Form
+ from pathlib import Path
+
+ VISUALIZER_CLASSES = {
+     "BERT": BERTVisualizer,
+     "RoBERTa": RoBERTaVisualizer,
+     "DistilBERT": DistilBERTVisualizer,
+ }
+
+ VISUALIZER_CACHE = {}
+ app = FastAPI()
+
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ MODEL_MAP = {
+     "BERT": "bert-base-uncased",
+     "RoBERTa": "roberta-base",
+     "DistilBERT": "distilbert-base-uncased",
+ }
+
+ class LoadModelRequest(BaseModel):
+     model: str
+     sentence: str
+     task: str
+     hypothesis: str
+
+ class GradAttnModelRequest(BaseModel):
+     model: str
+     task: str
+     sentence: str
+     hypothesis: str
+     maskID: int | None = None
+
+ class PredModelRequest(BaseModel):
+     model: str
+     sentence: str
+     task: str
+     hypothesis: str
+     maskID: int | None = None
+
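For reference, these request models accept JSON bodies of the following shape (values illustrative; hypothesis matters only for mnli, and maskID, presumably the index of the masked token, only for mlm):

    payload = {
        "model": "BERT",                      # key into VISUALIZER_CLASSES
        "sentence": "The movie was [MASK].",
        "task": "mlm",                        # "mlm", "sst", or "mnli"
        "hypothesis": "",
        "maskID": 4,
    }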
+ @app.post("/upload_model")
+ async def upload_model(file: UploadFile = File(...)):
+     save_path = f"/data/models/{file.filename}"  # or wherever your disk is mounted
+     os.makedirs(os.path.dirname(save_path), exist_ok=True)
+     with open(save_path, "wb") as f:
+         f.write(await file.read())
+     return {"status": "uploaded", "path": save_path}
+
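A minimal client sketch for this endpoint, assuming the server runs at http://localhost:8000 (host and filename illustrative):

    import requests

    with open("model.pt", "rb") as fh:
        r = requests.post("http://localhost:8000/upload_model",
                          files={"file": fh})
    print(r.json())  # {"status": "uploaded", "path": "/data/models/model.pt"}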
+ @app.post("/load_model")
+ def load_model(req: LoadModelRequest):
+     print(f"\n--- /load_model request received ---")
+     print(f"Model: {req.model}")
+     print(f"Sentence: {req.sentence}")
+     print(f"Task: {req.task}")
+     print(f"Hypothesis: {req.hypothesis}")
+
+     if req.model in VISUALIZER_CACHE:
+         del VISUALIZER_CACHE[req.model]
+         torch.cuda.empty_cache()
+
+     vis_class = VISUALIZER_CLASSES.get(req.model)
+     if vis_class is None:
+         return {"error": f"Unknown model: {req.model}"}
+
+     print("instantiating visualizer")
+     try:
+         vis = vis_class(task=req.task.lower())
+         print(vis)
+         VISUALIZER_CACHE[req.model] = vis
+         print("Visualizer instantiated")
+     except Exception as e:
+         print("Visualizer init failed:", e)
+         return {"error": f"Instantiation failed: {str(e)}"}
+
+     print('tokenizing')
+     try:
+         if req.task.lower() == 'mnli':
+             token_output = vis.tokenize(req.sentence, hypothesis=req.hypothesis)
+         else:
+             token_output = vis.tokenize(req.sentence)
+         print("Tokenization successful:", token_output["tokens"])
+     except Exception as e:
+         print("Tokenization failed:", e)
+         return {"error": f"Tokenization failed: {str(e)}"}
+
+     print('response ready')
+     response = {
+         "model": req.model,
+         "tokens": token_output['tokens'],
+         "num_layers": vis.num_attention_layers,
+     }
+     print("load model successful")
+     print(response)
+     return response
+
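Calling it against the same illustrative local server:

    import requests

    r = requests.post("http://localhost:8000/load_model", json={
        "model": "RoBERTa",
        "sentence": "The movie was great.",
        "task": "sst",
        "hypothesis": "",
    })
    print(r.json())  # {"model": "RoBERTa", "tokens": [...], "num_layers": 12}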
+ @app.post("/predict_model")
+ def predict_model(req: PredModelRequest):
+     print(f"\n--- /predict_model request received ---")
+     print(f"predict: Model: {req.model}")
+     print(f"predict: Task: {req.task}")
+     print(f"predict: sentence: {req.sentence}")
+     print(f"predict: hypothesis: {req.hypothesis}")
+     print(f"predict: maskID: {req.maskID}")
+
+     print('predict: instantiating')
+     try:
+         vis_class = VISUALIZER_CLASSES.get(req.model)
+         if vis_class is None:
+             return {"error": f"Unknown model: {req.model}"}
+         #if any(p.device.type == 'meta' for p in vis.model.parameters()):
+         #    vis.model = torch.nn.Module.to_empty(vis.model, device=torch.device("cpu"))
+
+         vis = vis_class(task=req.task.lower())
+         VISUALIZER_CACHE[req.model] = vis
+         print("Model reloaded and cached.")
+     except Exception as e:
+         return {"error": f"Failed to reload model: {str(e)}"}
+
+     print('predict: Run prediction')
+     try:
+         if req.task.lower() == 'mnli':
+             decoded, top_probs = vis.predict(req.task.lower(), req.sentence, hypothesis=req.hypothesis)
+         elif req.task.lower() == 'mlm':
+             decoded, top_probs = vis.predict(req.task.lower(), req.sentence, maskID=req.maskID)
+         else:
+             decoded, top_probs = vis.predict(req.task.lower(), req.sentence)
+     except Exception as e:
+         # Returning the raw exception object would crash the .tolist() call
+         # below; surface it as an error payload instead.
+         print(e)
+         return {"error": f"Prediction failed: {str(e)}"}
+
+     print('predict: response ready')
+     response = {
+         "decoded": decoded,
+         "top_probs": top_probs.tolist(),
+     }
+     print("predict: predict model successful")
+     if len(decoded) > 5:
+         print([(k, v[:5]) for k, v in response.items()])
+     else:
+         print(response)
+     return response
+
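An MLM-style call, with maskID assumed to index the masked token in the tokenized sentence (all values illustrative):

    import requests

    r = requests.post("http://localhost:8000/predict_model", json={
        "model": "BERT",
        "sentence": "The capital of France is [MASK].",
        "task": "mlm",
        "hypothesis": "",
        "maskID": 6,
    })
    out = r.json()
    print(out["decoded"], out["top_probs"][:5])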
+ @app.post("/get_grad_attn_matrix")
+ def get_grad_attn_matrix(req: GradAttnModelRequest):
+     try:
+         print(f"\n--- /get_grad_matrix request received ---")
+         print(f"grad: Model: {req.model}")
+         print(f"grad: Task: {req.task}")
+         print(f"grad: sentence: {req.sentence}")
+         print(f"grad: hypothesis: {req.hypothesis}")
+         print(f"grad: maskID: {req.maskID}")
+
+         try:
+             vis_class = VISUALIZER_CLASSES.get(req.model)
+             if vis_class is None:
+                 return {"error": f"Unknown model: {req.model}"}
+             #if any(p.device.type == 'meta' for p in vis.model.parameters()):
+             #    vis.model = torch.nn.Module.to_empty(vis.model, device=torch.device("cpu"))
+             vis = vis_class(task=req.task.lower())
+             VISUALIZER_CACHE[req.model] = vis
+             print("Model reloaded and cached.")
+         except Exception as e:
+             return {"error": f"Failed to reload model: {str(e)}"}
+
+         print("run function")
+         try:
+             if req.task.lower() == 'mnli':
+                 grad_matrix, attn_matrix = vis.get_all_grad_attn_matrix(req.task.lower(), req.sentence, hypothesis=req.hypothesis)
+             elif req.task.lower() == 'mlm':
+                 grad_matrix, attn_matrix = vis.get_all_grad_attn_matrix(req.task.lower(), req.sentence, maskID=req.maskID)
+             else:
+                 grad_matrix, attn_matrix = vis.get_all_grad_attn_matrix(req.task.lower(), req.sentence)
+         except Exception as e:
+             # Returning the exception object itself would not JSON-serialize;
+             # report it as an error payload instead.
+             print("Exception during grad/attn computation:", e)
+             return {"error": str(e)}
+
+         response = {
+             "grad_matrix": grad_matrix,
+             "attn_matrix": attn_matrix,
+         }
+         print('grad attn successful')
+         return response
+     except Exception as e:
+         print("SERVER EXCEPTION:", e)
+         return {"error": str(e)}
+
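A client-side sketch (same illustrative host; note the matrices must arrive JSON-encodable, e.g. converted with .tolist() on the visualizer side):

    import requests

    r = requests.post("http://localhost:8000/get_grad_attn_matrix", json={
        "model": "DistilBERT",
        "sentence": "A quietly moving film.",
        "task": "sst",
        "hypothesis": "",
    })
    out = r.json()
    # grad_matrix[i][k]: gradient-norm influence of token i on attention into k
    # attn_matrix[i][k]: head-averaged attention weight from query i to key k
    print(len(out["attn_matrix"]))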
+ ##################################################
+
+
+ @app.get("/ping")
+ def ping():
+     return {"message": "pong"}
+
+
+ @app.post("/upload_to_path")
+ async def upload_to_path(
+     file: UploadFile = File(...),
+     dest_path: str = Form(...)  # e.g., "models/model.pt"
+ ):
+     full_path = Path("/data") / dest_path
+     full_path.parent.mkdir(parents=True, exist_ok=True)
+
+     with open(full_path, "wb") as f:
+         f.write(await file.read())
+
+     return {"status": "uploaded", "path": str(full_path)}
+
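Unlike /upload_model, this endpoint takes the destination as a form field (illustrative client):

    import requests

    with open("weights.bin", "rb") as fh:
        r = requests.post("http://localhost:8000/upload_to_path",
                          files={"file": fh},
                          data={"dest_path": "models/weights.bin"})
    print(r.json())  # {"status": "uploaded", "path": "/data/models/weights.bin"}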
+ @app.post("/make_dir")
+ def make_directory(
+     dir_path: str = Form(...)  # e.g., "logs/test_run"
+ ):
+     full_dir = Path("/data") / dir_path
+     full_dir.mkdir(parents=True, exist_ok=True)
+     return {"status": "created", "directory": str(full_dir)}
+
+
+ @app.get("/list_data")
+ def list_data():
+     base_path = Path("/data")
+     all_items = []
+
+     for path in base_path.rglob("*"):  # recursive glob
+         all_items.append({
+             "path": str(path.relative_to(base_path)),
+             "type": "dir" if path.is_dir() else "file",
+             "size": path.stat().st_size if path.is_file() else None
+         })
+
+     return {"items": all_items}
+
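Quick health check and storage browse (illustrative host):

    import requests

    assert requests.get("http://localhost:8000/ping").json() == {"message": "pong"}
    for item in requests.get("http://localhost:8000/list_data").json()["items"]:
        print(item["type"], item["path"], item["size"])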
+ @app.post("/purge_data_123456789")
+ def purge_data():
+     base_path = Path("/data")
+     if not base_path.exists():
+         return {"status": "error", "message": "/data does not exist"}
+
+     deleted = []
+
+     for child in base_path.iterdir():
+         try:
+             if child.is_file() or child.is_symlink():
+                 child.unlink()
+             elif child.is_dir():
+                 shutil.rmtree(child)
+             deleted.append(str(child.name))
+         except Exception as e:
+             deleted.append(f"FAILED: {child.name} ({e})")
+
+     return {
+         "status": "done",
+         "deleted": deleted,
+         "total": len(deleted)
+     }
+
+ """
+ if __name__ == "__main__":
+     # Warm-up pass (left commented out in this commit): instantiating every
+     # visualizer downloads and caches all model/task pairs before serving.
+     BERTVisualizer('mlm')
+     BERTVisualizer('mnli')
+     BERTVisualizer('sst')
+
+     RoBERTaVisualizer('mlm')
+     RoBERTaVisualizer('mnli')
+     RoBERTaVisualizer('sst')
+
+     DistilBERTVisualizer('mlm')
+     DistilBERTVisualizer('mnli')
+     DistilBERTVisualizer('sst')
+ """