# NOTE: removed non-code scrape residue ("Spaces: / Sleeping / Sleeping")
# that preceded this module — it was page text, not Python source.
import os
import pickle

import torch
from torch import nn
from torch import optim
from torch.nn import functional as F  # fixed: nn functional lives in torch.nn, not torch
from einops import rearrange

#from modules.utils import *
from .utils import *
class Encoder(nn.Module):
    """RNN encoder mapping (bs, seq_len, z_dim) -> (bs, seq_len, hidden_dim).

    Args:
        config (dict): must provide 'z_dim', 'hidden_dim', 'num_layer'.
    """
    def __init__(self, config):
        super().__init__()
        # batch_first=True: every caller in this file passes batch-first
        # tensors (see the (bs, ts_size, z_dim) shape comments and the
        # 'b l f' rearranges downstream). The previous default (seq-first)
        # silently treated the batch axis as the time axis.
        self.rnn = nn.RNN(input_size=config['z_dim'],
                          hidden_size=config['hidden_dim'],
                          num_layers=config['num_layer'],
                          batch_first=True)
        self.fc = nn.Linear(in_features=config['hidden_dim'],
                            out_features=config['hidden_dim'])

    def forward(self, x):
        """Encode a batch of sequences.

        Args:
            x (torch.Tensor): shape (bs, seq_len, z_dim).

        Returns:
            torch.Tensor: shape (bs, seq_len, hidden_dim).
        """
        x_enc, _ = self.rnn(x)  # final hidden state is not needed
        x_enc = self.fc(x_enc)
        return x_enc
class Decoder(nn.Module):
    """RNN decoder mapping (bs, seq_len, hidden_dim) -> (bs, seq_len, z_dim).

    Args:
        config (dict): must provide 'z_dim', 'hidden_dim', 'num_layer'.
    """
    def __init__(self, config):
        super().__init__()
        # batch_first=True for the same reason as Encoder: inputs in this
        # file are batch-first, and the seq-first default mixed axes.
        self.rnn = nn.RNN(input_size=config['hidden_dim'],
                          hidden_size=config['hidden_dim'],
                          num_layers=config['num_layer'],
                          batch_first=True)
        self.fc = nn.Linear(in_features=config['hidden_dim'],
                            out_features=config['z_dim'])

    def forward(self, x_enc):
        """Decode a batch of latent sequences.

        Args:
            x_enc (torch.Tensor): shape (bs, seq_len, hidden_dim).

        Returns:
            torch.Tensor: shape (bs, seq_len, z_dim).
        """
        x_dec, _ = self.rnn(x_enc)  # final hidden state is not needed
        x_dec = self.fc(x_dec)
        return x_dec
class Interpolator(nn.Module):
    """Expands a visible-only latent sequence back to full length.

    Applies two learned linear maps: one along the time axis
    (vis_size -> ts_size) and one along the feature axis
    (hidden_dim -> hidden_dim).
    """
    def __init__(self, config):
        super().__init__()
        vis_len = config['ts_size'] - config['total_mask_size']
        self.sequence_inter = nn.Linear(in_features=vis_len,
                                        out_features=config['ts_size'])
        self.feature_inter = nn.Linear(in_features=config['hidden_dim'],
                                       out_features=config['hidden_dim'])

    def forward(self, x):
        """(bs, vis_size, hidden_dim) -> (bs, ts_size, hidden_dim)."""
        # Interpolate along time: put the time axis last, apply the linear
        # map, then restore the (batch, time, feature) layout.
        h = self.sequence_inter(x.permute(0, 2, 1))  # (bs, hidden_dim, ts_size)
        h = h.permute(0, 2, 1)                       # (bs, ts_size, hidden_dim)
        return self.feature_inter(h)
class StockEmbedder(nn.Module):
    """Masked autoencoder over stock time series.

    Composed of an Encoder, an Interpolator (latent-space in-filling of
    masked steps), and a Decoder. Supports plain autoencoding ('ae') and
    masked autoencoding ('mae').
    """

    def __init__(self, cfg: dict = None) -> None:
        """
        Args:
            cfg (dict): {
                'ts_size': 24,
                'mask_size': 1,
                'num_masks': 3,
                'hidden_dim': 12,
                'embed_dim': 6,
                'num_layer': 3,
                'z_dim': 6,
                'num_embed': 32,
                'stock_features': [],
                'min_val': 0,
                'max_val': 1e6
            }
        """
        super().__init__()
        # NOTE(review): the caller's dict is kept (and mutated below) —
        # 'total_mask_size' is derived once, then shared with the
        # sub-modules and the config pickled by save().
        self.config = cfg
        self.config['total_mask_size'] = self.config['num_masks'] * self.config['mask_size']
        self.encoder = Encoder(config=self.config)
        self.interpolator = Interpolator(config=self.config)
        self.decoder = Decoder(config=self.config)
        print('StockEmbedder initialized')

    def mask_it(self,
                x: torch.Tensor,
                masks: torch.Tensor):
        """Drop masked time steps, keeping only the visible ones.

        Args:
            x (torch.Tensor): shape (bs, ts_size, z_dim).
            masks (torch.Tensor): shape (bs, ts_size); nonzero marks a
                masked step. Assumes every row masks the same number of
                steps, so the result reshapes cleanly.

        Returns:
            torch.Tensor: shape (bs, vis_size, z_dim).
        """
        b, l, f = x.shape
        x_visible = x[~masks.bool(), :].reshape(b, -1, f)  # (bs, vis_size, z_dim)
        return x_visible

    def forward_ae(self, x: torch.Tensor):
        """Plain autoencoder pass (no masking, no interpolator).

        Args:
            x (torch.Tensor): shape (bs, ts_size, z_dim).

        Returns:
            tuple: (out_encoder, out_decoder).
        """
        out_encoder = self.encoder(x)
        out_decoder = self.decoder(out_encoder)
        return out_encoder, out_decoder

    def forward_mae(self,
                    x: torch.Tensor,
                    masks: torch.Tensor):
        """Masked-autoencoder pass: encode only visible steps, then
        interpolate the full sequence in latent space (no mask tokens).

        Args:
            x (torch.Tensor): shape (bs, ts_size, z_dim).
            masks (torch.Tensor): shape (bs, ts_size); nonzero = masked.

        Returns:
            tuple: (out_encoder, out_interpolator, out_decoder).
        """
        x_vis = self.mask_it(x, masks=masks)               # (bs, vis_size, z_dim)
        out_encoder = self.encoder(x_vis)                  # (bs, vis_size, hidden_dim)
        out_interpolator = self.interpolator(out_encoder)  # (bs, ts_size, hidden_dim)
        out_decoder = self.decoder(out_interpolator)       # (bs, ts_size, z_dim)
        return out_encoder, out_interpolator, out_decoder

    def forward(self,
                x: torch.Tensor,
                masks: torch.Tensor = None,
                mode: str = 'ae | mae'):
        """Dispatch to the autoencoder or masked-autoencoder pass.

        Args:
            x: array-like or tensor, shape (bs, ts_size, z_dim).
            masks: optional mask tensor, required for mode='mae'.
            mode (str): 'ae' or 'mae' (the default is a placeholder and
                must be overridden by the caller).

        Returns:
            'ae'  -> (out_encoder, out_decoder)
            'mae' -> (out_encoder, out_interpolator, out_decoder)

        Raises:
            ValueError: if mode is neither 'ae' nor 'mae' (previously this
                fell through and returned None).
        """
        # as_tensor avoids the copy + UserWarning torch.tensor() emits
        # when x is already a tensor; array-likes are still converted.
        x = torch.as_tensor(x, dtype=torch.float32)
        if masks is not None:
            masks = torch.as_tensor(masks, dtype=torch.float32)
        if mode == 'ae':
            return self.forward_ae(x)
        if mode == 'mae':
            return self.forward_mae(x, masks=masks)
        raise ValueError(f"mode must be 'ae' or 'mae', got {mode!r}")

    def get_embedding(self,
                      stock_data: torch.Tensor,
                      embedding_used: str = 'encoder | decoder'):
        """Get a stock embedding from a gradient-free 'ae' pass.

        Args:
            stock_data (torch.Tensor): shape (batch_size, stock_days,
                stock_features); NORMALIZED.
            embedding_used (str): 'encoder' or 'decoder' (the default is a
                placeholder and must be overridden by the caller).

        Returns:
            torch.Tensor: the chosen embedding.

        Raises:
            ValueError: if embedding_used is unrecognized (previously this
                raised an opaque UnboundLocalError).
        """
        with torch.no_grad():
            out_encoder, out_decoder = self.forward(stock_data, masks=None, mode='ae')
        if embedding_used == 'encoder':
            return out_encoder
        if embedding_used == 'decoder':
            return out_decoder
        raise ValueError(f"embedding_used must be 'encoder' or 'decoder', got {embedding_used!r}")

    def save(self, model_dir: str):
        """Persist weights (model.pth) and config (config.pkl) to model_dir."""
        os.makedirs(model_dir, exist_ok=True)
        # Save model:
        torch.save(obj=self.state_dict(), f=os.path.join(model_dir, 'model.pth'))
        # Save config:
        with open(file=os.path.join(model_dir, 'config.pkl'), mode='wb') as f:
            pickle.dump(obj=self.config, file=f)