yifan0sun committed on
Commit
9a86cc4
·
1 Parent(s): a44938d
Files changed (4) hide show
  1. BERTmodel.py +55 -10
  2. DISTILLBERTmodel.py +53 -6
  3. ROBERTAmodel.py +46 -5
  4. server.py +16 -12
BERTmodel.py CHANGED
@@ -10,7 +10,7 @@ from transformers import (
10
  BertForSequenceClassification,
11
  )
12
  import torch.nn.functional as F
13
-
14
 
15
  CACHE_DIR = "./hf_cache"
16
 
@@ -18,22 +18,67 @@ CACHE_DIR = "./hf_cache"
18
  class BERTVisualizer(TransformerVisualizer):
19
  def __init__(self,task):
20
  super().__init__()
21
- print(task,'BERTVIS START')
22
  self.task = task
23
- self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', cache_dir=CACHE_DIR)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  print('finding model', self.task)
25
  if self.task == 'mlm':
26
- self.model = BertForMaskedLM.from_pretrained(
27
- "bert-base-uncased",
28
- attn_implementation="eager", # fallback to standard attention
29
- cache_dir=CACHE_DIR
30
- ).to(self.device)
 
 
 
 
 
31
  elif self.task == 'sst':
32
- self.model = BertForSequenceClassification.from_pretrained("textattack/bert-base-uncased-SST-2",device_map=None, cache_dir=CACHE_DIR)
 
 
 
 
 
 
 
 
 
 
33
  elif self.task == 'mnli':
34
- self.model = BertForSequenceClassification.from_pretrained("textattack/bert-base-uncased-MNLI", device_map=None, cache_dir=CACHE_DIR)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  else:
36
  raise ValueError(f"Unsupported task: {self.task}")
 
 
 
37
  print('model found')
38
  #self.model.to(self.device)
39
  print('self device junk')
 
10
  BertForSequenceClassification,
11
  )
12
  import torch.nn.functional as F
13
+ import os
14
 
15
  CACHE_DIR = "./hf_cache"
16
 
 
18
  class BERTVisualizer(TransformerVisualizer):
19
  def __init__(self,task):
20
  super().__init__()
 
21
  self.task = task
22
+ print(task,'BERTVIS START')
23
+
24
+ TOKENIZER = 'bert-base-uncased'
25
+ LOCAL_PATH = os.path.join(CACHE_DIR, "tokenizers",TOKENIZER.replace("/", "_"))
26
+
27
+
28
+
29
+ try:
30
+ self.tokenizer = BertTokenizer.from_pretrained(LOCAL_PATH, local_files_only=True)
31
+ except Exception as e:
32
+ self.tokenizer = BertTokenizer.from_pretrained(TOKENIZER)
33
+ self.tokenizer.save_pretrained(LOCAL_PATH)
34
+
35
+
36
+
37
  print('finding model', self.task)
38
  if self.task == 'mlm':
39
+
40
+ MODEL = 'bert-base-uncased'
41
+ LOCAL_PATH = os.path.join(CACHE_DIR, "models",MODEL.replace("/", "_"))
42
+
43
+ try:
44
+ self.model = BertForMaskedLM.from_pretrained( LOCAL_PATH, local_files_only=True, attn_implementation="eager" ).to(self.device)
45
+ except Exception as e:
46
+ self.model = BertForMaskedLM.from_pretrained( MODEL, attn_implementation="eager" ).to(self.device)
47
+ self.model.save_pretrained(LOCAL_PATH)
48
+
49
  elif self.task == 'sst':
50
+ MODEL = "textattack/bert-base-uncased-SST-2"
51
+ LOCAL_PATH = os.path.join(CACHE_DIR, "models",MODEL.replace("/", "_"))
52
+
53
+
54
+ try:
55
+ self.model = BertForSequenceClassification.from_pretrained( LOCAL_PATH, local_files_only=True, device_map=None )
56
+ except Exception as e:
57
+ self.model = BertForSequenceClassification.from_pretrained( MODEL, device_map=None )
58
+ self.model.save_pretrained(LOCAL_PATH)
59
+
60
+
61
  elif self.task == 'mnli':
62
+ MODEL = 'textattack/bert-base-uncased-MNLI'
63
+
64
+
65
+ LOCAL_PATH = os.path.join(CACHE_DIR, "models",MODEL.replace("/", "_"))
66
+
67
+
68
+ try:
69
+ self.model = BertForSequenceClassification.from_pretrained( LOCAL_PATH, local_files_only=True, device_map=None )
70
+ except Exception as e:
71
+ self.model = BertForSequenceClassification.from_pretrained( MODEL, device_map=None)
72
+ self.model.save_pretrained(LOCAL_PATH)
73
+
74
+
75
+
76
+
77
  else:
78
  raise ValueError(f"Unsupported task: {self.task}")
79
+
80
+
81
+
82
  print('model found')
83
  #self.model.to(self.device)
84
  print('self device junk')
DISTILLBERTmodel.py CHANGED
@@ -4,7 +4,6 @@ import torch.nn.functional as F
4
 
5
 
6
  import os
7
- from transformers import DistilBertModel, DistilBertTokenizer
8
  from models import TransformerVisualizer
9
 
10
  from transformers import (
@@ -17,17 +16,65 @@ class DistilBERTVisualizer(TransformerVisualizer):
17
  def __init__(self, task):
18
  super().__init__()
19
  self.task = task
20
- self.tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased', cache_dir=CACHE_DIR)
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  if self.task == 'mlm':
22
- self.model = DistilBertForMaskedLM.from_pretrained('distilbert-base-uncased', cache_dir=CACHE_DIR)
 
 
 
 
 
 
 
 
 
23
  elif self.task == 'sst':
24
- self.model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased-finetuned-sst-2-english', cache_dir=CACHE_DIR)
 
 
 
 
 
 
 
 
 
 
25
  elif self.task == 'mnli':
26
- self.model = DistilBertForSequenceClassification.from_pretrained("textattack/distilbert-base-uncased-MNLI", cache_dir=CACHE_DIR)
 
 
 
 
 
 
 
 
 
 
 
27
 
 
28
 
29
  else:
30
- raise NotImplementedError("Task not supported for DistilBERT")
 
 
 
 
31
 
32
 
33
  self.model.eval()
 
4
 
5
 
6
  import os
 
7
  from models import TransformerVisualizer
8
 
9
  from transformers import (
 
16
  def __init__(self, task):
17
  super().__init__()
18
  self.task = task
19
+
20
+
21
+ TOKENIZER = 'distilbert-base-uncased'
22
+ LOCAL_PATH = os.path.join(CACHE_DIR, "tokenizers",TOKENIZER.replace("/", "_"))
23
+
24
+ try:
25
+ self.tokenizer = DistilBertTokenizer.from_pretrained(LOCAL_PATH, local_files_only=True)
26
+ except Exception as e:
27
+ self.tokenizer = DistilBertTokenizer.from_pretrained(TOKENIZER)
28
+ self.tokenizer.save_pretrained(LOCAL_PATH)
29
+
30
+
31
+
32
+ print('finding model', self.task)
33
  if self.task == 'mlm':
34
+
35
+ MODEL = 'distilbert-base-uncased'
36
+ LOCAL_PATH = os.path.join(CACHE_DIR, "models",MODEL.replace("/", "_"))
37
+
38
+ try:
39
+ self.model = DistilBertForMaskedLM.from_pretrained( LOCAL_PATH, local_files_only=True )
40
+ except Exception as e:
41
+ self.model = DistilBertForMaskedLM.from_pretrained( MODEL )
42
+ self.model.save_pretrained(LOCAL_PATH)
43
+
44
  elif self.task == 'sst':
45
+ MODEL = 'distilbert-base-uncased-finetuned-sst-2-english'
46
+ LOCAL_PATH = os.path.join(CACHE_DIR, "models",MODEL.replace("/", "_"))
47
+
48
+
49
+ try:
50
+ self.model = DistilBertForSequenceClassification.from_pretrained( LOCAL_PATH, local_files_only=True )
51
+ except Exception as e:
52
+ self.model = DistilBertForSequenceClassification.from_pretrained( MODEL )
53
+ self.model.save_pretrained(LOCAL_PATH)
54
+
55
+
56
  elif self.task == 'mnli':
57
+ MODEL = "textattack/distilbert-base-uncased-MNLI"
58
+
59
+
60
+ LOCAL_PATH = os.path.join(CACHE_DIR, "models",MODEL.replace("/", "_"))
61
+
62
+
63
+ try:
64
+ self.model = DistilBertForSequenceClassification.from_pretrained( LOCAL_PATH, local_files_only=True)
65
+ except Exception as e:
66
+ self.model = DistilBertForSequenceClassification.from_pretrained( MODEL)
67
+ self.model.save_pretrained(LOCAL_PATH)
68
+
69
 
70
+
71
 
72
  else:
73
+ raise ValueError(f"Unsupported task: {self.task}")
74
+
75
+
76
+
77
+
78
 
79
 
80
  self.model.eval()
ROBERTAmodel.py CHANGED
@@ -5,19 +5,60 @@ from models import TransformerVisualizer
5
  from transformers import (
6
  RobertaForMaskedLM, RobertaForSequenceClassification
7
  )
8
-
9
  CACHE_DIR = "./hf_cache"
10
  class RoBERTaVisualizer(TransformerVisualizer):
11
  def __init__(self, task):
12
  super().__init__()
13
  self.task = task
14
- self.tokenizer = RobertaTokenizer.from_pretrained('roberta-base', cache_dir=CACHE_DIR)
 
 
 
 
 
 
 
 
 
 
 
15
  if self.task == 'mlm':
16
- self.model = RobertaForMaskedLM.from_pretrained("roberta-base", cache_dir=CACHE_DIR)
 
 
 
 
 
 
 
 
 
17
  elif self.task == 'sst':
18
- self.model = RobertaForSequenceClassification.from_pretrained('textattack/roberta-base-SST-2', cache_dir=CACHE_DIR)
 
 
 
 
 
 
 
 
 
 
 
19
  elif self.task == 'mnli':
20
- self.model = RobertaForSequenceClassification.from_pretrained("roberta-large-mnli", cache_dir=CACHE_DIR)
 
 
 
 
 
 
 
 
 
 
21
 
22
 
23
  self.model.to(self.device)
 
5
  from transformers import (
6
  RobertaForMaskedLM, RobertaForSequenceClassification
7
  )
8
+ import os
9
  CACHE_DIR = "./hf_cache"
10
  class RoBERTaVisualizer(TransformerVisualizer):
11
  def __init__(self, task):
12
  super().__init__()
13
  self.task = task
14
+
15
+
16
+
17
+ TOKENIZER = 'roberta-base'
18
+ LOCAL_PATH = os.path.join(CACHE_DIR, "tokenizers",TOKENIZER.replace("/", "_"))
19
+
20
+ try:
21
+ self.tokenizer = RobertaTokenizer.from_pretrained(LOCAL_PATH, local_files_only=True)
22
+ except Exception as e:
23
+ self.tokenizer = RobertaTokenizer.from_pretrained(TOKENIZER)
24
+ self.tokenizer.save_pretrained(LOCAL_PATH)
25
+
26
  if self.task == 'mlm':
27
+
28
+ MODEL = "roberta-base"
29
+ LOCAL_PATH = os.path.join(CACHE_DIR, "models",MODEL.replace("/", "_"))
30
+
31
+ try:
32
+ self.model = RobertaForMaskedLM.from_pretrained( LOCAL_PATH, local_files_only=True )
33
+ except Exception as e:
34
+ self.model = RobertaForMaskedLM.from_pretrained( MODEL )
35
+ self.model.save_pretrained(LOCAL_PATH)
36
+
37
  elif self.task == 'sst':
38
+
39
+
40
+ MODEL = 'textattack/roberta-base-SST-2'
41
+ LOCAL_PATH = os.path.join(CACHE_DIR, "models",MODEL.replace("/", "_"))
42
+
43
+ try:
44
+ self.model = RobertaForSequenceClassification.from_pretrained( LOCAL_PATH, local_files_only=True )
45
+ except Exception as e:
46
+ self.model = RobertaForSequenceClassification.from_pretrained( MODEL )
47
+ self.model.save_pretrained(LOCAL_PATH)
48
+
49
+
50
  elif self.task == 'mnli':
51
+ MODEL = "roberta-large-mnli"
52
+ LOCAL_PATH = os.path.join(CACHE_DIR, "models",MODEL.replace("/", "_"))
53
+
54
+
55
+ try:
56
+ self.model = RobertaForSequenceClassification.from_pretrained( LOCAL_PATH, local_files_only=True)
57
+ except Exception as e:
58
+ self.model = RobertaForSequenceClassification.from_pretrained( MODEL)
59
+ self.model.save_pretrained(LOCAL_PATH)
60
+
61
+
62
 
63
 
64
  self.model.to(self.device)
server.py CHANGED
@@ -9,18 +9,6 @@ from ROBERTAmodel import *
9
  from BERTmodel import *
10
  from DISTILLBERTmodel import *
11
 
12
- import shutil
13
- import os
14
-
15
- CACHE_DIR = "./hf_cache"
16
- if os.path.exists(CACHE_DIR):
17
- try:
18
- shutil.rmtree(CACHE_DIR)
19
- print("✅ Cleared hf_cache directory")
20
- except Exception as e:
21
- print("❌ Failed to clear hf_cache:", e)
22
-
23
-
24
  VISUALIZER_CLASSES = {
25
  "BERT": BERTVisualizer,
26
  "RoBERTa": RoBERTaVisualizer,
@@ -229,3 +217,19 @@ def get_grad_attn_matrix(req: GradAttnModelRequest):
229
  return {"error": str(e)}
230
 
231
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  from BERTmodel import *
10
  from DISTILLBERTmodel import *
11
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  VISUALIZER_CLASSES = {
13
  "BERT": BERTVisualizer,
14
  "RoBERTa": RoBERTaVisualizer,
 
217
  return {"error": str(e)}
218
 
219
 
220
+ if __name__ == "__main__":
221
+
222
+ print('rim ')
223
+ BERTVisualizer('mlm')
224
+ BERTVisualizer('mnli')
225
+ BERTVisualizer('sst')
226
+
227
+
228
+ RoBERTaVisualizer('mlm')
229
+ RoBERTaVisualizer('mnli')
230
+ RoBERTaVisualizer('sst')
231
+
232
+
233
+ DistilBERTVisualizer('mlm')
234
+ DistilBERTVisualizer('mnli')
235
+ DistilBERTVisualizer('sst')