CoffeBank commited on
Commit
e4082c9
·
1 Parent(s): c739e17
Files changed (1) hide show
  1. model_utils.py +3 -4
model_utils.py CHANGED
@@ -7,7 +7,6 @@ from NN_classifier.simple_binary_classifier import Medium_Binary_Network
7
  from feature_extraction import extract_features
8
  import pandas as pd
9
 
10
- DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
11
 
12
  def load_model(model_dir='models/medium_binary_classifier'):
13
  model_path = os.path.join(model_dir, 'nn_model.pt')
@@ -29,8 +28,8 @@ def load_model(model_dir='models/medium_binary_classifier'):
29
 
30
  input_size = scaler.n_features_in_
31
 
32
- model = Medium_Binary_Network(input_size, hidden_sizes=[256, 192, 128, 64], dropout=0.3).to(DEVICE)
33
- model.load_state_dict(torch.load(model_path, map_location=DEVICE))
34
  model.eval()
35
 
36
  if imputer is not None:
@@ -76,7 +75,7 @@ def classify_text(text, model, scaler, label_encoder, imputer=None, scores=None)
76
 
77
  features_scaled = scaler.transform(features)
78
 
79
- features_tensor = torch.FloatTensor(features_scaled).to(DEVICE)
80
 
81
  with torch.no_grad():
82
  outputs = model(features_tensor)
 
7
  from feature_extraction import extract_features
8
  import pandas as pd
9
 
 
10
 
11
  def load_model(model_dir='models/medium_binary_classifier'):
12
  model_path = os.path.join(model_dir, 'nn_model.pt')
 
28
 
29
  input_size = scaler.n_features_in_
30
 
31
+ model = Medium_Binary_Network(input_size, hidden_sizes=[256, 192, 128, 64], dropout=0.3)
32
+ model.load_state_dict(torch.load(model_path))
33
  model.eval()
34
 
35
  if imputer is not None:
 
75
 
76
  features_scaled = scaler.transform(features)
77
 
78
+ features_tensor = torch.FloatTensor(features_scaled)
79
 
80
  with torch.no_grad():
81
  outputs = model(features_tensor)