suryadev1 commited on
Commit
61a85a4
·
verified ·
1 Parent(s): 168746a
Files changed (1) hide show
  1. new_test_saved_finetuned_model.py +2 -4
new_test_saved_finetuned_model.py CHANGED
@@ -162,10 +162,8 @@ class BERTFineTuneTrainer:
162
  logits = self.model.forward(data["input"], data["segment_label"], data["feat"])
163
 
164
  logits = logits.cpu()
165
- devic = logits.device # or self.model.device if available
166
- labels = data["label"].to(devic)
167
-
168
- loss = self.criterion(logits, data["label"])
169
  # if torch.cuda.device_count() > 1:
170
  # loss = loss.mean()
171
 
 
162
  logits = self.model.forward(data["input"], data["segment_label"], data["feat"])
163
 
164
  logits = logits.cpu()
165
+ labels = data["label"].to(logits.device)
166
+ loss = self.criterion(logits, labels)
 
 
167
  # if torch.cuda.device_count() > 1:
168
  # loss = loss.mean()
169