yifan0sun commited on
Commit
2e03d42
·
verified ·
1 Parent(s): fb57040

Update DISTILLBERTmodel.py

Browse files
Files changed (1) hide show
  1. DISTILLBERTmodel.py +3 -3
DISTILLBERTmodel.py CHANGED
@@ -37,7 +37,7 @@ class DistilBERTVisualizer(TransformerVisualizer):
37
  MODEL = 'distilbert-base-uncased'
38
  LOCAL_PATH = os.path.join(CACHE_DIR, "models",MODEL)
39
 
40
- self.model = DistilBertForMaskedLM.from_pretrained( LOCAL_PATH, local_files_only=True )
41
  """
42
  try:
43
  except Exception as e:
@@ -48,7 +48,7 @@ class DistilBERTVisualizer(TransformerVisualizer):
48
  MODEL = 'distilbert-base-uncased-finetuned-sst-2-english'
49
  LOCAL_PATH = os.path.join(CACHE_DIR, "models",MODEL)
50
 
51
- self.model = DistilBertForSequenceClassification.from_pretrained( LOCAL_PATH, local_files_only=True )
52
  """
53
  try:
54
  self.model = DistilBertForSequenceClassification.from_pretrained( LOCAL_PATH, local_files_only=True )
@@ -62,7 +62,7 @@ class DistilBERTVisualizer(TransformerVisualizer):
62
  LOCAL_PATH = os.path.join(CACHE_DIR, "models",MODEL)
63
 
64
 
65
- self.model = DistilBertForSequenceClassification.from_pretrained( LOCAL_PATH, local_files_only=True)
66
  """
67
  try:
68
  self.model = DistilBertForSequenceClassification.from_pretrained( LOCAL_PATH, local_files_only=True)
 
37
  MODEL = 'distilbert-base-uncased'
38
  LOCAL_PATH = os.path.join(CACHE_DIR, "models",MODEL)
39
 
40
+ self.model = DistilBertForMaskedLM.from_pretrained( LOCAL_PATH, local_files_only=True ).to(self.device)
41
  """
42
  try:
43
  except Exception as e:
 
48
  MODEL = 'distilbert-base-uncased-finetuned-sst-2-english'
49
  LOCAL_PATH = os.path.join(CACHE_DIR, "models",MODEL)
50
 
51
+ self.model = DistilBertForSequenceClassification.from_pretrained( LOCAL_PATH, local_files_only=True ).to(self.device)
52
  """
53
  try:
54
  self.model = DistilBertForSequenceClassification.from_pretrained( LOCAL_PATH, local_files_only=True )
 
62
  LOCAL_PATH = os.path.join(CACHE_DIR, "models",MODEL)
63
 
64
 
65
+ self.model = DistilBertForSequenceClassification.from_pretrained( LOCAL_PATH, local_files_only=True).to(self.device)
66
  """
67
  try:
68
  self.model = DistilBertForSequenceClassification.from_pretrained( LOCAL_PATH, local_files_only=True)