import os
import pickle

import torch
from torch import nn
from einops import rearrange

from .utils import *


class Encoder(nn.Module):
    """Maps a (bs, seq_len, z_dim) sequence to (bs, seq_len, hidden_dim)."""

    def __init__(self, config):
        super().__init__()
        # batch_first=True: all tensors in this module are (batch, seq, features).
        self.rnn = nn.RNN(input_size=config['z_dim'],
                          hidden_size=config['hidden_dim'],
                          num_layers=config['num_layer'],
                          batch_first=True)
        self.fc = nn.Linear(in_features=config['hidden_dim'],
                            out_features=config['hidden_dim'])

    def forward(self, x):
        x_enc, _ = self.rnn(x)
        x_enc = self.fc(x_enc)
        return x_enc


class Decoder(nn.Module):
    """Maps a (bs, seq_len, hidden_dim) sequence back to (bs, seq_len, z_dim)."""

    def __init__(self, config):
        super().__init__()
        self.rnn = nn.RNN(input_size=config['hidden_dim'],
                          hidden_size=config['hidden_dim'],
                          num_layers=config['num_layer'],
                          batch_first=True)
        self.fc = nn.Linear(in_features=config['hidden_dim'],
                            out_features=config['z_dim'])

    def forward(self, x_enc):
        x_dec, _ = self.rnn(x_enc)
        x_dec = self.fc(x_dec)
        return x_dec


class Interpolator(nn.Module):
    """Interpolates the visible latent sequence back to full length, in latent space."""

    def __init__(self, config):
        super().__init__()
        # Expands the time axis from vis_size = ts_size - total_mask_size to ts_size.
        self.sequence_inter = nn.Linear(
            in_features=(config['ts_size'] - config['total_mask_size']),
            out_features=config['ts_size'])
        self.feature_inter = nn.Linear(in_features=config['hidden_dim'],
                                       out_features=config['hidden_dim'])

    def forward(self, x):
        # x: (bs, vis_size, hidden_dim)
        x = rearrange(x, 'b l f -> b f l')   # (bs, hidden_dim, vis_size)
        x = self.sequence_inter(x)           # (bs, hidden_dim, ts_size)
        x = rearrange(x, 'b f l -> b l f')   # (bs, ts_size, hidden_dim)
        x = self.feature_inter(x)            # (bs, ts_size, hidden_dim)
        return x


class StockEmbedder(nn.Module):
    def __init__(self, cfg: dict = None) -> None:
        """
        Args:
            cfg (dict): {
                'ts_size': 24,
                'mask_size': 1,
                'num_masks': 3,
                'hidden_dim': 12,
                'embed_dim': 6,
                'num_layer': 3,
                'z_dim': 6,
                'num_embed': 32,
                'stock_features': [],
                'min_val': 0,
                'max_val': 1e6
            }
        """
        super().__init__()
        self.config = cfg
        self.config['total_mask_size'] = self.config['num_masks'] * self.config['mask_size']

        self.encoder = Encoder(config=self.config)
        self.interpolator = Interpolator(config=self.config)
        self.decoder = Decoder(config=self.config)
        print('StockEmbedder initialized')

    def mask_it(self, x: torch.Tensor, masks: torch.Tensor):
        """Drop the masked timesteps, keeping only the visible ones.

        Args:
            x (torch.Tensor): shape (bs, ts_size, z_dim)
            masks (torch.Tensor): shape (bs, ts_size); 1 marks a masked timestep
        """
        b, l, f = x.shape
        x_visible = x[~masks.bool(), :].reshape(b, -1, f)  # (bs, vis_size, z_dim)
        return x_visible

    def forward_ae(self, x: torch.Tensor):
        """Plain autoencoder pass (the 'mae_pseudo_mask' mode is equivalent to this).
        There is no interpolator in this mode.

        Args:
            x (torch.Tensor): shape (bs, ts_size, z_dim)
        """
        out_encoder = self.encoder(x)
        out_decoder = self.decoder(out_encoder)
        return out_encoder, out_decoder

    def forward_mae(self, x: torch.Tensor, masks: torch.Tensor):
        """Masked-autoencoder pass: no mask tokens; interpolation happens in latent space.

        Args:
            x (torch.Tensor): shape (bs, ts_size, z_dim)
            masks (torch.Tensor): shape (bs, ts_size); 1 marks a masked timestep
        """
        x_vis = self.mask_it(x, masks=masks)               # (bs, vis_size, z_dim)
        out_encoder = self.encoder(x_vis)                  # (bs, vis_size, hidden_dim)
        out_interpolator = self.interpolator(out_encoder)  # (bs, ts_size, hidden_dim)
        out_decoder = self.decoder(out_interpolator)       # (bs, ts_size, z_dim)
        return out_encoder, out_interpolator, out_decoder

    def forward(self, x: torch.Tensor, masks: torch.Tensor = None, mode: str = 'ae'):
        # Accept numpy arrays or tensors without needlessly copying tensors.
        x = torch.as_tensor(x, dtype=torch.float32)
        if masks is not None:
            masks = torch.as_tensor(masks, dtype=torch.float32)
        if mode == 'ae':
            return self.forward_ae(x)
        elif mode == 'mae':
            return self.forward_mae(x, masks=masks)
        else:
            raise ValueError(f"Unknown mode '{mode}', expected 'ae' or 'mae'")

    def
 get_embedding(self, stock_data: torch.Tensor, embedding_used: str = 'encoder'):
        """Return the stock embedding from an autoencoder ('ae') pass.

        Args:
            stock_data (torch.Tensor): shape (batch_size, stock_days, stock_features);
                must already be NORMALIZED
            embedding_used (str): 'encoder' or 'decoder'
        """
        with torch.no_grad():
            out_encoder, out_decoder = self.forward(stock_data, masks=None, mode='ae')
        if embedding_used == 'encoder':
            stock_embedding = out_encoder
        elif embedding_used == 'decoder':
            stock_embedding = out_decoder
        else:
            raise ValueError(f"Unknown embedding_used '{embedding_used}', "
                             f"expected 'encoder' or 'decoder'")
        return stock_embedding

    def save(self, model_dir: str):
        os.makedirs(model_dir, exist_ok=True)
        # Save model weights:
        torch.save(obj=self.state_dict(), f=os.path.join(model_dir, 'model.pth'))
        # Save config:
        with open(file=os.path.join(model_dir, 'config.pkl'), mode='wb') as f:
            pickle.dump(obj=self.config, file=f)
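

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the module's API): a minimal
# smoke test of both forward modes, assuming only the config keys this module
# actually reads. The toy mask below, which marks the first
# num_masks * mask_size timesteps of every sample, is a hypothetical choice
# for demonstration. Because of the relative import above, run it as
# `python -m <package>.<module>` rather than directly.
if __name__ == '__main__':
    cfg = {'ts_size': 24, 'mask_size': 1, 'num_masks': 3,
           'hidden_dim': 12, 'num_layer': 3, 'z_dim': 6}
    model = StockEmbedder(cfg=cfg)

    x = torch.randn(4, cfg['ts_size'], cfg['z_dim'])      # (bs, ts_size, z_dim)
    masks = torch.zeros(4, cfg['ts_size'])
    masks[:, :cfg['num_masks'] * cfg['mask_size']] = 1    # 1 = masked timestep

    enc, dec = model(x, mode='ae')                        # (4, 24, 12), (4, 24, 6)
    enc_vis, interp, rec = model(x, masks=masks, mode='mae')
    # enc_vis: (4, 21, 12), interp: (4, 24, 12), rec: (4, 24, 6)
    emb = model.get_embedding(x)                          # (4, 24, 12) encoder embedding
    print(enc.shape, dec.shape, enc_vis.shape, interp.shape, rec.shape, emb.shape)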