yifan0sun commited on
Commit
fb57040
·
verified ·
1 Parent(s): 013177a

Update BERTmodel.py

Browse files
Files changed (1) hide show
  1. BERTmodel.py +2 -2
BERTmodel.py CHANGED
@@ -53,7 +53,7 @@ class BERTVisualizer(TransformerVisualizer):
53
  MODEL = "textattack_bert-base-uncased-SST-2"
54
  LOCAL_PATH = os.path.join(CACHE_DIR, "models",MODEL)
55
 
56
- self.model = BertForSequenceClassification.from_pretrained( LOCAL_PATH, local_files_only=True, device_map=None )
57
  """
58
  try:
59
  self.model = BertForSequenceClassification.from_pretrained( LOCAL_PATH, local_files_only=True, device_map=None )
@@ -68,7 +68,7 @@ class BERTVisualizer(TransformerVisualizer):
68
 
69
  LOCAL_PATH = os.path.join(CACHE_DIR, "models",MODEL)
70
 
71
- self.model = BertForSequenceClassification.from_pretrained( LOCAL_PATH, local_files_only=True, device_map=None )
72
  """
73
  try:
74
  self.model = BertForSequenceClassification.from_pretrained( LOCAL_PATH, local_files_only=True, device_map=None )
 
53
  MODEL = "textattack_bert-base-uncased-SST-2"
54
  LOCAL_PATH = os.path.join(CACHE_DIR, "models",MODEL)
55
 
56
+ self.model = BertForSequenceClassification.from_pretrained( LOCAL_PATH, local_files_only=True, device_map=None ).to(self.device)
57
  """
58
  try:
59
  self.model = BertForSequenceClassification.from_pretrained( LOCAL_PATH, local_files_only=True, device_map=None )
 
68
 
69
  LOCAL_PATH = os.path.join(CACHE_DIR, "models",MODEL)
70
 
71
+ self.model = BertForSequenceClassification.from_pretrained( LOCAL_PATH, local_files_only=True, device_map=None ).to(self.device)
72
  """
73
  try:
74
  self.model = BertForSequenceClassification.from_pretrained( LOCAL_PATH, local_files_only=True, device_map=None )