import torch
from torch import nn
from torch import optim
from torch.nn import functional as F
from einops import rearrange
import os
import pickle
#from modules.utils import *
from .utils import *


class Encoder(nn.Module):
    """RNN encoder: maps (bs, seq_len, z_dim) to (bs, seq_len, hidden_dim)."""
    def __init__(self, config):
        super().__init__()
        self.rnn = nn.RNN(input_size=config['z_dim'],
                          hidden_size=config['hidden_dim'],
                          num_layers=config['num_layer'],
                          batch_first=True)  # inputs are batch-first: (bs, seq_len, z_dim)
        self.fc = nn.Linear(in_features=config['hidden_dim'],
                            out_features=config['hidden_dim'])

    def forward(self, x):
        x_enc, _ = self.rnn(x)
        x_enc = self.fc(x_enc)
        return x_enc


class Decoder(nn.Module):
    """RNN decoder: maps (bs, seq_len, hidden_dim) back to (bs, seq_len, z_dim)."""
    def __init__(self, config):
        super().__init__()
        self.rnn = nn.RNN(input_size=config['hidden_dim'],
                          hidden_size=config['hidden_dim'],
                          num_layers=config['num_layer'],
                          batch_first=True)  # inputs are batch-first: (bs, seq_len, hidden_dim)
        self.fc = nn.Linear(in_features=config['hidden_dim'],
                            out_features=config['z_dim'])

    def forward(self, x_enc):
        x_dec, _ = self.rnn(x_enc)
        x_dec = self.fc(x_dec)
        return x_dec


class Interpolator(nn.Module):
    """Interpolates encoder outputs from the visible length back to the full ts_size."""
    def __init__(self, config):
        super().__init__()
        self.sequence_inter = nn.Linear(in_features=(config['ts_size'] - config['total_mask_size']),
                                        out_features=config['ts_size'])
        self.feature_inter = nn.Linear(in_features=config['hidden_dim'],
                                       out_features=config['hidden_dim'])

    def forward(self, x):
        # x: (bs, vis_size, hidden_dim)
        x = rearrange(x, 'b l f -> b f l')   # (bs, hidden_dim, vis_size)
        x = self.sequence_inter(x)           # (bs, hidden_dim, ts_size)
        x = rearrange(x, 'b f l -> b l f')   # (bs, ts_size, hidden_dim)
        x = self.feature_inter(x)            # (bs, ts_size, hidden_dim)
        return x
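

# Shape sketch (illustrative; assumes ts_size=24, num_masks=3, mask_size=1, hidden_dim=12,
# matching the example config documented below): the Interpolator stretches the time axis
# from the visible length 24 - 3 = 21 back to 24, then mixes features per time step:
#
#     interp = Interpolator({'ts_size': 24, 'total_mask_size': 3, 'hidden_dim': 12})
#     interp(torch.randn(8, 21, 12)).shape   # -> torch.Size([8, 24, 12])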


class StockEmbedder(nn.Module):
    def __init__(self, cfg: dict = None) -> None:
        """
        Args:
            cfg (dict): {
                'ts_size': 24,
                'mask_size': 1,
                'num_masks': 3,
                'hidden_dim': 12,
                'embed_dim': 6,
                'num_layer': 3,
                'z_dim': 6,
                'num_embed': 32,
                'stock_features': [],
                'min_val': 0,
                'max_val': 1e6
            }
        """
        super().__init__()
        self.config = cfg
        self.config['total_mask_size'] = self.config['num_masks'] * self.config['mask_size']
        self.encoder = Encoder(config=self.config)
        self.interpolator = Interpolator(config=self.config)
        self.decoder = Decoder(config=self.config)
        print('StockEmbedder initialized')

    def mask_it(self,
                x: torch.Tensor,
                masks: torch.Tensor):
        # x: (bs, ts_size, z_dim); masks: (bs, ts_size), 1 = masked, 0 = visible
        b, l, f = x.shape
        x_visible = x[~masks.bool(), :].reshape(b, -1, f)  # (bs, vis_size, z_dim)
        return x_visible
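
    # Mask convention (sketch, inferred from mask_it above): `masks` has shape (bs, ts_size),
    # with 1 marking a masked time step and 0 a visible one. Every sample must mask exactly
    # `total_mask_size` steps so the visible part reshapes cleanly, e.g.:
    #
    #     masks = torch.zeros(bs, ts_size)
    #     idx = torch.stack([torch.randperm(ts_size)[:total_mask_size] for _ in range(bs)])
    #     masks.scatter_(1, idx, 1.0)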

    def forward_ae(self, x: torch.Tensor):
        """Autoencoder mode (equivalent to `mae_pseudo_mask`); the interpolator is not used here.

        Args:
            x (torch.Tensor): shape (bs, ts_size, z_dim)
        """
        out_encoder = self.encoder(x)
        out_decoder = self.decoder(out_encoder)
        return out_encoder, out_decoder

    def forward_mae(self,
                    x: torch.Tensor,
                    masks: torch.Tensor):
        """Masked mode: no mask tokens; interpolation happens in the latent space.

        Args:
            x (torch.Tensor): shape (bs, ts_size, z_dim)
            masks (torch.Tensor): shape (bs, ts_size); 1 marks a masked step, 0 a visible one
        """
        x_vis = self.mask_it(x, masks=masks)               # (bs, vis_size, z_dim)
        out_encoder = self.encoder(x_vis)                  # (bs, vis_size, hidden_dim)
        out_interpolator = self.interpolator(out_encoder)  # (bs, ts_size, hidden_dim)
        out_decoder = self.decoder(out_interpolator)       # (bs, ts_size, z_dim)
        return out_encoder, out_interpolator, out_decoder

    def forward(self,
                x: torch.Tensor,
                masks: torch.Tensor = None,
                mode: str = 'ae | mae'):
        x = torch.as_tensor(x, dtype=torch.float32)
        if masks is not None:
            masks = torch.as_tensor(masks, dtype=torch.float32)
        if mode == 'ae':
            out_encoder, out_decoder = self.forward_ae(x)
            return out_encoder, out_decoder
        elif mode == 'mae':
            out_encoder, out_interpolator, out_decoder = self.forward_mae(x, masks=masks)
            return out_encoder, out_interpolator, out_decoder
        raise ValueError(f"mode must be 'ae' or 'mae', got {mode!r}")

    def get_embedding(self,
                      stock_data: torch.Tensor,
                      embedding_used: str = 'encoder | decoder'):
        """Return stock embeddings from the autoencoder path.

        Args:
            stock_data (torch.Tensor): shape (batch_size, stock_days, stock_features); must be NORMALIZED
            embedding_used (str): 'encoder' or 'decoder'
        """
        with torch.no_grad():
            out_encoder, out_decoder = self.forward(stock_data, masks=None, mode='ae')
        if embedding_used == 'encoder':
            stock_embedding = out_encoder
        elif embedding_used == 'decoder':
            stock_embedding = out_decoder
        else:
            raise ValueError(f"embedding_used must be 'encoder' or 'decoder', got {embedding_used!r}")
        return stock_embedding

    def save(self, model_dir: str):
        os.makedirs(model_dir, exist_ok=True)
        # Save model:
        torch.save(obj=self.state_dict(), f=os.path.join(model_dir, 'model.pth'))
        # Save config:
        with open(file=os.path.join(model_dir, 'config.pkl'), mode='wb') as f:
            pickle.dump(obj=self.config, file=f)
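

# Minimal usage sketch (not part of the original module). It builds a StockEmbedder from the
# example config documented in __init__, runs both the 'ae' and 'mae' modes on random data,
# and extracts frozen embeddings. Batch size and the random inputs/masks are illustrative
# assumptions only. Because of the relative import above, run it as a module from the repo
# root, e.g. `python -m Models.stock_embedder`.
if __name__ == '__main__':
    cfg = {'ts_size': 24, 'mask_size': 1, 'num_masks': 3, 'hidden_dim': 12,
           'embed_dim': 6, 'num_layer': 3, 'z_dim': 6, 'num_embed': 32,
           'stock_features': [], 'min_val': 0, 'max_val': 1e6}
    model = StockEmbedder(cfg)

    bs = 8
    x = torch.randn(bs, cfg['ts_size'], cfg['z_dim'])            # (bs, ts_size, z_dim)

    # Autoencoder mode: full sequence, no interpolator.
    out_enc, out_dec = model(x, mode='ae')
    print(out_enc.shape, out_dec.shape)                          # (8, 24, 12) (8, 24, 6)

    # Masked mode: mask num_masks * mask_size steps per sample.
    total_mask = cfg['num_masks'] * cfg['mask_size']
    masks = torch.zeros(bs, cfg['ts_size'])
    idx = torch.stack([torch.randperm(cfg['ts_size'])[:total_mask] for _ in range(bs)])
    masks.scatter_(1, idx, 1.0)
    out_enc, out_inter, out_dec = model(x, masks=masks, mode='mae')
    print(out_enc.shape, out_inter.shape, out_dec.shape)         # (8, 21, 12) (8, 24, 12) (8, 24, 6)

    # Frozen embeddings from the encoder path.
    emb = model.get_embedding(x, embedding_used='encoder')
    print(emb.shape)                                             # (8, 24, 12)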