Spitfire1970
commited on
Commit
·
d25a262
1
Parent(s):
ac37ff3
fix
Browse files- handler.py +3 -2
handler.py
CHANGED
@@ -1,10 +1,11 @@
|
|
|
|
1 |
import torch
|
2 |
from encoder.model import Encoder
|
3 |
|
4 |
class EndpointHandler():
|
5 |
-
def __init__(self,
|
6 |
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
7 |
-
checkpoint = torch.load(path, self.device, weights_only=True)
|
8 |
self.model = Encoder(self.device)
|
9 |
state_dict = checkpoint['model_state']
|
10 |
self.model.load_state_dict(state_dict)
|
|
|
1 |
+
import os
|
2 |
import torch
|
3 |
from encoder.model import Encoder
|
4 |
|
5 |
class EndpointHandler():
|
6 |
+
def __init__(self, model_dir):
|
7 |
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
8 |
+
checkpoint = torch.load(os.path.join(model_dir, "6.pt"), self.device, weights_only=True)
|
9 |
self.model = Encoder(self.device)
|
10 |
state_dict = checkpoint['model_state']
|
11 |
self.model.load_state_dict(state_dict)
|