| import os |
| import torch |
|
|
| |
| from models.LSTM import LSTM |
| from models.LSTNet import LSTNet |
| from models.Transformer import Transformer |
| from models.Autoformer import Autoformer |
| from models.Informer import Informer |
| from models.PatchTST import PatchTST |
| from models.TimesNet import TimesNet |
| from models.TimesFM import TimesFM |
|
|
| |
| from model_kwargs import * |
|
|
| |
| |
# Checkpoint configuration. lookback and lookahead are forwarded to every
# kwargs builder and, together with heterogeneity, encoded in the checkpoint
# filename. (Presumably input-window length, forecast horizon and a dataset
# variant tag -- NOTE(review): confirm against the training code.)
lookback, lookahead, heterogeneity = 512, 48, 'HET'


def weights_path(model_name, lookback, lookahead, heterogeneity, root=None):
    """Return the checkpoint path for one model configuration.

    The path has the form
    ``<root>/weights/<model_name>_L_<lookback>_T_<lookahead>_<heterogeneity>.pth``.

    Args:
        model_name: Class name of the model, e.g. ``"LSTM"``.
        lookback: Value encoded after ``_L_`` in the filename.
        lookahead: Value encoded after ``_T_`` in the filename.
        heterogeneity: Tag appended at the end of the filename.
        root: Base directory. Defaults to the current working directory,
            resolved at call time.

    Returns:
        The joined path as a string.
    """
    if root is None:
        root = os.getcwd()
    filename = f'{model_name}_L_{lookback}_T_{lookahead}_{heterogeneity}.pth'
    return os.path.join(root, 'weights', filename)


if __name__ == "__main__":

    # Each model class is paired directly with its kwargs builder. With two
    # parallel lists, zip() would silently drop the tail of the longer one
    # if they ever fell out of sync.
    model_specs = [
        (LSTM, lstm_kwargs),
        (LSTNet, lstnet_kwargs),
        (Transformer, transformer_kwargs),
        (Autoformer, autoformer_kwargs),
        (Informer, informer_kwargs),
        (PatchTST, patchtst_kwargs),
        (TimesNet, timesnet_kwargs),
        (TimesFM, timesfm_kwargs),
    ]

    failures = []
    for model_class, kw_fn in model_specs:
        name = model_class.__name__
        model = model_class(**kw_fn(lookback=lookback, lookahead=lookahead))
        path = weights_path(name, lookback, lookahead, heterogeneity)

        try:
            # NOTE(review): torch.load unpickles the file, so only load
            # trusted checkpoints. Consider weights_only=True if the
            # installed torch supports it and the file holds only tensors.
            state_dict = torch.load(path, map_location='cpu')
            # Strict loading (the default) raises RuntimeError on
            # missing/unexpected keys or shape mismatches.
            result = model.load_state_dict(state_dict)
        except (FileNotFoundError, RuntimeError) as err:
            # Record the failure and keep checking the remaining models
            # instead of aborting the whole sweep.
            failures.append(name)
            print(f"Failed to load weight for model {name} from {path}: {err}")
            continue

        print(f"Loading weight for model {model_class.__name__}, lookback {lookback}, lookahead {lookahead}, heterogeneity {heterogeneity}, and the result was: {result}.")

    if failures:
        # Exit non-zero (as the original crash did) so automation still
        # notices that some checkpoints could not be loaded.
        raise SystemExit(f"Could not load weights for: {', '.join(failures)}")
|
|