File size: 316 Bytes
228af26 |
1 2 3 4 5 6 7 8 9 |
from backend.ml_models.tornado_predictor import TornadoSuperPredictor as Model
import torch
def load(device='cpu', path='pytorch_model.bin'):
    """Build a TornadoSuperPredictor, load saved weights into it, and return it in eval mode.

    Args:
        device: Device the model is moved to; also passed as ``map_location``
            so checkpoint tensors are loaded straight onto it.
        path: Checkpoint file holding the state dict. Defaults to
            ``'pytorch_model.bin'``; a relative path is resolved against the
            current working directory.

    Returns:
        The model with the checkpoint's state dict loaded and ``eval()`` applied.
    """
    model = Model().to(device)
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code. Only load trusted checkpoints; consider passing weights_only=True
    # if the installed torch version supports it.
    state_dict = torch.load(path, map_location=device)
    model.load_state_dict(state_dict)
    # Put the module in inference mode (affects layers such as dropout or
    # batch-norm, if the model has any).
    model.eval()
    return model
def apply_temperature(logits, T, eps=1e-6):
    """Scale logits by a softmax temperature.

    Args:
        logits: Value(s) to scale; anything supporting ``/`` by a number
            (e.g. a tensor or a Python float).
        T: Temperature. Expected to be a Python number or a 0-d tensor,
            because the built-in ``max`` cannot compare a multi-element tensor.
        eps: Lower bound applied to ``T`` so the division never divides by
            zero. Defaults to the previously hard-coded ``1e-6``.

    Returns:
        ``logits / max(T, eps)``.
    """
    # T is floored, not rejected: any T <= eps (including zero or negative
    # values) is replaced by eps, so the logits are divided by eps instead.
    return logits / max(T, eps)
|