mboss commited on
Commit
cc2ade2
·
1 Parent(s): d232006

Fix loading without cuda

Browse files
Files changed (1) hide show
  1. marble.py +4 -1
marble.py CHANGED
@@ -59,7 +59,10 @@ def setup_control_mlp(
59
 
60
  net = control_mlp(features)
61
  net.load_state_dict(
62
- torch.load(os.path.join(file_dir, f"model_weights/{material_parameter}.pt"))
 
 
 
63
  )
64
  net.to(device, dtype=dtype)
65
  net.eval()
 
59
 
60
  net = control_mlp(features)
61
  net.load_state_dict(
62
+ torch.load(
63
+ os.path.join(file_dir, f"model_weights/{material_parameter}.pt"),
64
+ map_location=device
65
+ )
66
  )
67
  net.to(device, dtype=dtype)
68
  net.eval()