Spitfire1970 commited on
Commit
d25a262
·
1 Parent(s): ac37ff3
Files changed (1) hide show
  1. handler.py +3 -2
handler.py CHANGED
@@ -1,10 +1,11 @@
 
1
  import torch
2
  from encoder.model import Encoder
3
 
4
  class EndpointHandler():
5
- def __init__(self, path="6.pt"):
6
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
7
- checkpoint = torch.load(path, self.device, weights_only=True)
8
  self.model = Encoder(self.device)
9
  state_dict = checkpoint['model_state']
10
  self.model.load_state_dict(state_dict)
 
1
+ import os
2
  import torch
3
  from encoder.model import Encoder
4
 
5
  class EndpointHandler():
6
+ def __init__(self, model_dir):
7
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
8
+ checkpoint = torch.load(os.path.join(model_dir, "6.pt"), self.device, weights_only=True)
9
  self.model = Encoder(self.device)
10
  state_dict = checkpoint['model_state']
11
  self.model.load_state_dict(state_dict)