import math
import torch
import numpy as np
import torch.nn.functional as F
from torch import nn
from einops import rearrange, reduce, repeat
from .model_utils import (
    LearnablePositionalEncoding,
    Conv_MLP,
    AdaLayerNorm,
    Transpose,
    GELU2,
    series_decomp,
)


class TrendBlock(nn.Module):
    """
    Model trend of time series using the polynomial regressor.

    Two convolutions predict ``trend_poly`` (= 3) coefficients per output
    feature. The trend is their weighted sum over the basis ``s ** p`` for
    ``p`` in 1..3, where ``s`` holds ``out_dim`` evenly spaced points
    ``k / (out_dim + 1)``, ``k = 1..out_dim``.

    Args:
        in_dim: channel count (dim 1) of the input.
        out_dim: number of basis points; becomes dim 1 of the output.
        in_feat: size of the input's last dim. It is consumed as channels by
            the second Conv1d after the transpose.
        out_feat: size of the output's last dim.
        act: activation module applied between the two convolutions.
    """

    def __init__(self, in_dim, out_dim, in_feat, out_feat, act):
        super().__init__()
        trend_poly = 3
        self.trend = nn.Sequential(
            nn.Conv1d(
                in_channels=in_dim, out_channels=trend_poly, kernel_size=3, padding=1
            ),
            act,
            Transpose(shape=(1, 2)),
            nn.Conv1d(in_feat, out_feat, 3, stride=1, padding=1),
        )
        lin_space = torch.arange(1, out_dim + 1, 1) / (out_dim + 1)
        # Non-persistent buffer: follows module.to()/.cuda()/.double(), so there
        # is no host->device copy on every forward. It stays out of the
        # state_dict, so existing checkpoints still load.
        self.register_buffer(
            "poly_space",
            torch.stack([lin_space ** float(p + 1) for p in range(trend_poly)], dim=0),
            persistent=False,
        )

    def forward(self, input):
        """Return trend values of shape (b, out_dim, out_feat) for a 3-D input.

        Raises:
            ValueError: if ``input`` is not 3-D.
        """
        if input.dim() != 3:
            raise ValueError(
                f"TrendBlock expects a 3-D input, got shape {tuple(input.shape)}"
            )
        # (b, out_feat, trend_poly), assuming Transpose(shape=(1, 2)) swaps dims 1/2
        coeffs = self.trend(input)
        trend_vals = torch.matmul(coeffs, self.poly_space.to(coeffs.device))
        return trend_vals.transpose(1, 2)


class MovingBlock(nn.Module):
    """
    Model trend of time series using the moving average.

    The window size passed to ``series_decomp`` is ``int(out_dim / 4)``
    clamped to the range [4, 24].
    """

    def __init__(self, out_dim):
        super().__init__()
        size = max(min(int(out_dim / 4), 24), 4)
        self.decomp = series_decomp(size)

    def forward(self, input):
        """Return the two tensors produced by ``series_decomp`` for a 3-D input.

        NOTE(review): the order is presumably (residual, trend), judging by the
        local names. Confirm against ``series_decomp``.

        Raises:
            ValueError: if ``input`` is not 3-D.
        """
        if input.dim() != 3:
            raise ValueError(
                f"MovingBlock expects a 3-D input, got shape {tuple(input.shape)}"
            )
        x, trend_vals = self.decomp(input)
        return x, trend_vals


class FourierLayer(nn.Module):
    """
    Model seasonality of time series using the inverse DFT.

    ``forward`` takes the real FFT along dim 1. It drops the first
    ``low_freq`` bins, and also the Nyquist bin when the length is even. It
    then keeps the ``int(factor * ln(n_bins))`` largest-magnitude bins for
    each (batch, feature) pair and re-synthesises them as a sum of cosines
    over the original length.

    Args:
        d_model: stored on the module but not used by the computation.
        low_freq: number of lowest rfft bins (starting at DC) to discard.
        factor: multiplier on ``ln(n_bins)`` when choosing how many bins to keep.
    """

    def __init__(self, d_model, low_freq=1, factor=1):
        super().__init__()
        self.d_model = d_model
        self.factor = factor
        self.low_freq = low_freq

    def forward(self, x):
        """x: (b, t, d) -> seasonal component of shape (b, t, d)."""
        b, t, d = x.shape
        x_freq = torch.fft.rfft(x, dim=1)
        freqs = torch.fft.rfftfreq(t, device=x.device)

        if t % 2 == 0:
            # Even length: the last rfft bin is the Nyquist frequency; drop it.
            x_freq = x_freq[:, self.low_freq : -1]
            freqs = freqs[self.low_freq : -1]
        else:
            x_freq = x_freq[:, self.low_freq :]
            freqs = freqs[self.low_freq :]

        x_freq, index_tuple = self.topk_freq(x_freq)
        # Gather the frequency of every selected bin for each (batch, feature).
        freqs = repeat(freqs, "f -> b f d", b=b, d=d)
        freqs = rearrange(freqs[index_tuple], "b f d -> b f () d")
        return self.extrapolate(x_freq, freqs, t)

    def extrapolate(self, x_freq, f, t):
        """Synthesise ``t`` time steps from the selected frequency bins.

        Each bin and its complex conjugate (at ``-f``) contributes
        ``|X| * cos(2*pi*f*n + angle(X))``. The contributions are summed over
        bins. Amplitudes are used as-is; no ``1/t`` normalisation is applied.

        Args:
            x_freq: selected complex bins, (b, k, d).
            f: matching frequencies, (b, k, 1, d).
            t: number of time steps to generate.

        Returns:
            Tensor of shape (b, t, d).
        """
        x_freq = torch.cat([x_freq, x_freq.conj()], dim=1)
        f = torch.cat([f, -f], dim=1)
        steps = rearrange(
            torch.arange(t, dtype=torch.float, device=x_freq.device),
            "t -> () () t ()",
        )

        amp = rearrange(x_freq.abs(), "b f d -> b f () d")
        phase = rearrange(x_freq.angle(), "b f d -> b f () d")
        x_time = amp * torch.cos(2 * math.pi * f * steps + phase)
        return reduce(x_time, "b f t d -> b t d", "sum")

    def topk_freq(self, x_freq):
        """Keep the top-k bins by magnitude along dim 1 for every (batch, feature).

        ``k = int(factor * ln(n_bins))``, clamped to ``[0, n_bins]``. The clamp
        avoids ``log(0)`` when no bins survive (e.g. ``t == 2``) and a
        too-large ``k`` when ``factor`` is big.

        Returns:
            The gathered bins, shape (b, k, d), and the advanced-index tuple
            used to gather them. ``forward`` reuses the tuple to pick the
            matching frequencies.
        """
        length = x_freq.shape[1]
        if length > 0:
            top_k = max(0, min(int(self.factor * math.log(length)), length))
        else:
            top_k = 0
        _, indices = torch.topk(x_freq.abs(), top_k, dim=1, largest=True, sorted=True)
        batch_idx, feat_idx = torch.meshgrid(
            torch.arange(x_freq.size(0), device=x_freq.device),
            torch.arange(x_freq.size(2), device=x_freq.device),
            indexing="ij",
        )
        index_tuple = (batch_idx.unsqueeze(1), indices, feat_idx.unsqueeze(1))
        return x_freq[index_tuple], index_tuple
class SeasonBlock(nn.Module):
    """
    Model seasonality of time series using the Fourier series.

    A 1x1 convolution maps the ``in_dim`` input channels to ``season_poly``
    coefficients. These weight a fixed cosine/sine basis sampled at
    ``out_dim`` points ``k / out_dim``, ``k = 0..out_dim-1``.

    Args:
        in_dim: channel count (dim 1) of the input.
        out_dim: number of basis sample points; becomes dim 1 of the output.
        factor: multiplier on the basis size ``min(32, out_dim // 2)``.
    """

    def __init__(self, in_dim, out_dim, factor=1):
        super().__init__()
        season_poly = factor * min(32, int(out_dim // 2))
        self.season = nn.Conv1d(
            in_channels=in_dim, out_channels=season_poly, kernel_size=1, padding=0
        )
        fourier_space = torch.arange(0, out_dim, 1) / out_dim
        # p1 cosines + p2 sines == season_poly; an odd count gives the extra one to sines.
        p1, p2 = (
            (season_poly // 2, season_poly // 2)
            if season_poly % 2 == 0
            else (season_poly // 2, season_poly // 2 + 1)
        )
        s1 = torch.stack(
            [torch.cos(2 * np.pi * p * fourier_space) for p in range(1, p1 + 1)], dim=0
        )
        s2 = torch.stack(
            [torch.sin(2 * np.pi * p * fourier_space) for p in range(1, p2 + 1)], dim=0
        )
        # Non-persistent buffer: follows module.to()/.cuda()/.double() without
        # entering the state_dict, so existing checkpoints still load.
        self.register_buffer("poly_space", torch.cat([s1, s2]), persistent=False)

    def forward(self, input):
        """input: (b, in_dim, h) -> seasonal values of shape (b, out_dim, h).

        Raises:
            ValueError: if ``input`` is not 3-D.
        """
        if input.dim() != 3:
            raise ValueError(
                f"SeasonBlock expects a 3-D input, got shape {tuple(input.shape)}"
            )
        coeffs = self.season(input)  # (b, season_poly, h)
        season_vals = torch.matmul(
            coeffs.transpose(1, 2), self.poly_space.to(coeffs.device)
        )  # (b, h, out_dim)
        return season_vals.transpose(1, 2)


class FullAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    ``forward`` returns the projected output (B, T, C) and the post-dropout
    attention weights averaged over heads (B, T, T).
    """

    def __init__(
        self,
        n_embd,  # the embed dim
        n_head,  # the number of heads
        attn_pdrop=0.1,  # attention dropout prob
        resid_pdrop=0.1,  # residual attention dropout prob
    ):
        super().__init__()
        assert n_embd % n_head == 0
        # key, query, value projections for all heads
        self.key = nn.Linear(n_embd, n_embd)
        self.query = nn.Linear(n_embd, n_embd)
        self.value = nn.Linear(n_embd, n_embd)
        # regularization
        self.attn_drop = nn.Dropout(attn_pdrop)
        self.resid_drop = nn.Dropout(resid_pdrop)
        # output projection
        self.proj = nn.Linear(n_embd, n_embd)
        self.n_head = n_head

    def forward(self, x, mask=None):
        # NOTE(review): `mask` is accepted for interface compatibility but is
        # never applied -- every position attends to every other position.
        B, T, C = x.size()
        hs = C // self.n_head
        # project, then split heads: (B, T, C) -> (B, nh, T, hs)
        k = self.key(x).view(B, T, self.n_head, hs).transpose(1, 2)
        q = self.query(x).view(B, T, self.n_head, hs).transpose(1, 2)
        v = self.value(x).view(B, T, self.n_head, hs).transpose(1, 2)

        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))  # (B, nh, T, T)
        att = F.softmax(att, dim=-1)
        att = self.attn_drop(att)
        y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        # re-assemble all head outputs side by side: (B, T, C)
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        att = att.mean(dim=1, keepdim=False)  # head-averaged, (B, T, T)

        # output projection
        y = self.resid_drop(self.proj(y))
        return y, att


class CrossAttention(nn.Module):
    """Multi-head attention from ``x`` (queries) to ``encoder_output`` (keys/values).

    ``forward`` returns the projected output (B, T, C) and the post-dropout
    attention weights averaged over heads (B, T, T_E).
    """

    def __init__(
        self,
        n_embd,  # the embed dim
        condition_embd,  # condition dim
        n_head,  # the number of heads
        attn_pdrop=0.1,  # attention dropout prob
        resid_pdrop=0.1,  # residual attention dropout prob
    ):
        super().__init__()
        assert n_embd % n_head == 0
        # key/value project the condition; query projects x
        self.key = nn.Linear(condition_embd, n_embd)
        self.query = nn.Linear(n_embd, n_embd)
        self.value = nn.Linear(condition_embd, n_embd)
        # regularization
        self.attn_drop = nn.Dropout(attn_pdrop)
        self.resid_drop = nn.Dropout(resid_pdrop)
        # output projection
        self.proj = nn.Linear(n_embd, n_embd)
        self.n_head = n_head

    def forward(self, x, encoder_output, mask=None):
        # NOTE(review): `mask` is accepted for interface compatibility but is
        # never applied.
        B, T, C = x.size()
        B, T_E, _ = encoder_output.size()
        hs = C // self.n_head
        # calculate query, key, values for all heads; move heads next to batch
        k = (
            self.key(encoder_output).view(B, T_E, self.n_head, hs).transpose(1, 2)
        )  # (B, nh, T_E, hs)
        q = self.query(x).view(B, T, self.n_head, hs).transpose(1, 2)  # (B, nh, T, hs)
        v = (
            self.value(encoder_output).view(B, T_E, self.n_head, hs).transpose(1, 2)
        )  # (B, nh, T_E, hs)

        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))  # (B, nh, T, T_E)
        att = F.softmax(att, dim=-1)
        att = self.attn_drop(att)
        y = att @ v  # (B, nh, T, T_E) x (B, nh, T_E, hs) -> (B, nh, T, hs)
        # re-assemble all head outputs side by side: (B, T, C)
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        att = att.mean(dim=1, keepdim=False)  # head-averaged, (B, T, T_E)

        # output projection
        y = self.resid_drop(self.proj(y))
        return y, att


class EncoderBlock(nn.Module):
    """Pre-norm Transformer encoder block: self-attention, then an MLP.

    Attention input is normalised by ``AdaLayerNorm``, called with the
    timestep and optional label embedding. The MLP input uses a plain
    ``LayerNorm``. Both sub-layers are residual.
    """

    def __init__(
        self,
        n_embd=1024,
        n_head=16,
        attn_pdrop=0.1,
        resid_pdrop=0.1,
        mlp_hidden_times=4,
        activate="GELU",
    ):
        super().__init__()
        self.ln1 = AdaLayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)
        self.attn = FullAttention(
            n_embd=n_embd,
            n_head=n_head,
            attn_pdrop=attn_pdrop,
            resid_pdrop=resid_pdrop,
        )
        assert activate in ["GELU", "GELU2"]
        act = nn.GELU() if activate == "GELU" else GELU2()
        self.mlp = nn.Sequential(
            nn.Linear(n_embd, mlp_hidden_times * n_embd),
            act,
            nn.Linear(mlp_hidden_times * n_embd, n_embd),
            nn.Dropout(resid_pdrop),
        )

    def forward(self, x, timestep, mask=None, label_emb=None):
        """Return the updated hidden states and the head-averaged attention map."""
        a, att = self.attn(self.ln1(x, timestep, label_emb), mask=mask)
        x = x + a
        x = x + self.mlp(self.ln2(x))
        return x, att


class Encoder(nn.Module):
    """A stack of ``n_layer`` ``EncoderBlock``s applied in sequence."""

    def __init__(
        self,
        n_layer=14,
        n_embd=1024,
        n_head=16,
        attn_pdrop=0.0,
        resid_pdrop=0.0,
        mlp_hidden_times=4,
        block_activate="GELU",
    ):
        super().__init__()
        self.blocks = nn.Sequential(
            *[
                EncoderBlock(
                    n_embd=n_embd,
                    n_head=n_head,
                    attn_pdrop=attn_pdrop,
                    resid_pdrop=resid_pdrop,
                    mlp_hidden_times=mlp_hidden_times,
                    activate=block_activate,
                )
                for _ in range(n_layer)
            ]
        )

    def forward(self, input, t, padding_masks=None, label_emb=None):
        """Run every block on ``input`` in order and return the final hidden states.

        The per-block attention maps are discarded.
        """
        x = input
        # Blocks take extra arguments, so call them one by one rather than
        # through nn.Sequential.forward.
        for block in self.blocks:
            x, _ = block(x, t, mask=padding_masks, label_emb=label_emb)
        return x


class DecoderBlock(nn.Module):
    """Transformer decoder block that also emits trend and season components.

    Runs self-attention, then cross-attention to the encoder output, both
    residual. ``proj`` then splits the result into two halves along dim 1:
    one feeds ``TrendBlock`` and the other ``FourierLayer``. Finally a
    residual MLP is applied, and the mean over dim 1 is split off.
    """

    def __init__(
        self,
        n_channel,
        n_feat,
        n_embd=1024,
        n_head=16,
        attn_pdrop=0.1,
        resid_pdrop=0.1,
        mlp_hidden_times=4,
        activate="GELU",
        condition_dim=1024,
    ):
        super().__init__()
        self.ln1 = AdaLayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)
        self.attn1 = FullAttention(
            n_embd=n_embd,
            n_head=n_head,
            attn_pdrop=attn_pdrop,
            resid_pdrop=resid_pdrop,
        )
        self.attn2 = CrossAttention(
            n_embd=n_embd,
            condition_embd=condition_dim,
            n_head=n_head,
            attn_pdrop=attn_pdrop,
            resid_pdrop=resid_pdrop,
        )
        self.ln1_1 = AdaLayerNorm(n_embd)
        assert activate in ["GELU", "GELU2"]
        act = nn.GELU() if activate == "GELU" else GELU2()
        # Alternatives: MovingBlock(n_channel) for the trend path and
        # SeasonBlock(n_channel, n_channel) for the season path.
        self.trend = TrendBlock(n_channel, n_channel, n_embd, n_feat, act=act)
        self.seasonal = FourierLayer(d_model=n_embd)
        self.mlp = nn.Sequential(
            nn.Linear(n_embd, mlp_hidden_times * n_embd),
            act,
            nn.Linear(mlp_hidden_times * n_embd, n_embd),
            nn.Dropout(resid_pdrop),
        )
        self.proj = nn.Conv1d(n_channel, n_channel * 2, 1)
        self.linear = nn.Linear(n_embd, n_feat)

    def forward(self, x, encoder_output, timestep, mask=None, label_emb=None):
        """Return ``(x - m, linear(m), trend, season)``, where ``m`` is x's mean over dim 1."""
        a, _ = self.attn1(self.ln1(x, timestep, label_emb), mask=mask)
        x = x + a
        # NOTE(review): label_emb is not forwarded to ln1_1 -- confirm intended.
        a, _ = self.attn2(self.ln1_1(x, timestep), encoder_output, mask=mask)
        x = x + a
        x1, x2 = self.proj(x).chunk(2, dim=1)
        trend, season = self.trend(x1), self.seasonal(x2)
        x = x + self.mlp(self.ln2(x))
        m = torch.mean(x, dim=1, keepdim=True)
        return x - m, self.linear(m), trend, season


class Decoder(nn.Module):
    """A stack of ``n_layer`` ``DecoderBlock``s whose side outputs are accumulated."""

    def __init__(
        self,
        n_channel,
        n_feat,
        n_embd=1024,
        n_head=16,
        n_layer=10,
        attn_pdrop=0.1,
        resid_pdrop=0.1,
        mlp_hidden_times=4,
        block_activate="GELU",
        condition_dim=512,
    ):
        super().__init__()
        self.d_model = n_embd
        self.n_feat = n_feat
        self.blocks = nn.Sequential(
            *[
                DecoderBlock(
                    n_feat=n_feat,
                    n_channel=n_channel,
                    n_embd=n_embd,
                    n_head=n_head,
                    attn_pdrop=attn_pdrop,
                    resid_pdrop=resid_pdrop,
                    mlp_hidden_times=mlp_hidden_times,
                    activate=block_activate,
                    condition_dim=condition_dim,
                )
                for _ in range(n_layer)
            ]
        )

    def forward(self, x, t, enc, padding_masks=None, label_emb=None):
        """Run all blocks and return ``(x, mean, trend, season)``.

        ``trend`` (b, c, n_feat) and ``season`` (b, c, d_model) are summed over
        blocks. ``mean`` concatenates the blocks' mean outputs along dim 1.
        """
        b, c, _ = x.shape
        mean = []
        season = torch.zeros((b, c, self.d_model), device=x.device)
        trend = torch.zeros((b, c, self.n_feat), device=x.device)
        for block in self.blocks:
            x, residual_mean, residual_trend, residual_season = block(
                x, enc, t, mask=padding_masks, label_emb=label_emb
            )
            season += residual_season
            trend += residual_trend
            mean.append(residual_mean)

        mean = torch.cat(mean, dim=1)
        return x, mean, trend, season


class Transformer(nn.Module):
    """Encoder-decoder network whose output is split into trend and season parts.

    ``forward`` returns ``(trend, season + residual)``. With
    ``return_res=True`` it returns ``(trend, season, residual)`` instead.
    Extra ``**kwargs`` are accepted and ignored.
    """

    def __init__(
        self,
        n_feat,
        n_channel,
        n_layer_enc=5,
        n_layer_dec=14,
        n_embd=1024,
        n_heads=16,
        attn_pdrop=0.1,
        resid_pdrop=0.1,
        mlp_hidden_times=4,
        block_activate="GELU",
        max_len=2048,
        conv_params=None,
        **kwargs
    ):
        super().__init__()
        self.emb = Conv_MLP(n_feat, n_embd, resid_pdrop=resid_pdrop)
        self.inverse = Conv_MLP(n_embd, n_feat, resid_pdrop=resid_pdrop)

        # Kernel for combine_s: taken from conv_params when given; otherwise
        # 1x1 for small problems and width 5 (padding 2) for larger ones.
        if conv_params is None or conv_params[0] is None:
            if n_feat < 32 and n_channel < 64:
                kernel_size, padding = 1, 0
            else:
                kernel_size, padding = 5, 2
        else:
            kernel_size, padding = conv_params

        self.combine_s = nn.Conv1d(
            n_embd,
            n_feat,
            kernel_size=kernel_size,
            stride=1,
            padding=padding,
            padding_mode="circular",
            bias=False,
        )
        # Mixes the n_layer_dec per-block means into one.
        self.combine_m = nn.Conv1d(
            n_layer_dec,
            1,
            kernel_size=1,
            stride=1,
            padding=0,
            padding_mode="circular",
            bias=False,
        )

        self.encoder = Encoder(
            n_layer_enc,
            n_embd,
            n_heads,
            attn_pdrop,
            resid_pdrop,
            mlp_hidden_times,
            block_activate,
        )
        self.pos_enc = LearnablePositionalEncoding(
            n_embd, dropout=resid_pdrop, max_len=max_len
        )

        self.decoder = Decoder(
            n_channel,
            n_feat,
            n_embd,
            n_heads,
            n_layer_dec,
            attn_pdrop,
            resid_pdrop,
            mlp_hidden_times,
            block_activate,
            condition_dim=n_embd,
        )
        self.pos_dec = LearnablePositionalEncoding(
            n_embd, dropout=resid_pdrop, max_len=max_len
        )

    def forward(self, input, t, padding_masks=None, return_res=False):
        emb = self.emb(input)
        inp_enc = self.pos_enc(emb)
        enc_cond = self.encoder(inp_enc, t, padding_masks=padding_masks)

        inp_dec = self.pos_dec(emb)
        output, mean, trend, season = self.decoder(
            inp_dec, t, enc_cond, padding_masks=padding_masks
        )

        res = self.inverse(output)
        res_m = torch.mean(res, dim=1, keepdim=True)
        # combine_s runs once and is shared by both return paths (it used to
        # be recomputed when return_res=True).
        season_out = self.combine_s(season.transpose(1, 2)).transpose(1, 2)
        trend = self.combine_m(mean) + res_m + trend

        if return_res:
            return trend, season_out, res - res_m
        return trend, season_out + res - res_m


if __name__ == "__main__":
    pass