"""PatchTST: a channel-independent, patch-based Transformer for time series.

The module is organised bottom-up: small utility layers (``Transpose``,
``RevIN``, ``MovingAverage``), the attention stack, the per-channel encoder,
the backbone that patches the input, and finally the public ``PatchTST`` model.
"""

from typing import Optional

import torch
import torch.nn.functional as F
from huggingface_hub import PyTorchModelHubMixin
from torch import Tensor, nn


class Transpose(nn.Module):
    """``Tensor.transpose`` wrapped as a module so it can sit in ``nn.Sequential``.

    Args:
        *dims: the two dimensions forwarded to ``Tensor.transpose``.
        contiguous: when true, ``.contiguous()`` is called on the result.
    """

    def __init__(self, *dims, contiguous=False):
        super().__init__()
        self.dims, self.contiguous = dims, contiguous

    def forward(self, x):
        transposed = x.transpose(*self.dims)
        return transposed.contiguous() if self.contiguous else transposed

    def __repr__(self):
        dims = ", ".join([str(d) for d in self.dims])
        if self.contiguous:
            return f"{self.__class__.__name__}(dims={dims}).contiguous()"
        return f"{self.__class__.__name__}({dims})"


# Activation classes that can be selected by (case-insensitive) class name.
pytorch_acts = [
    nn.ELU,
    nn.LeakyReLU,
    nn.PReLU,
    nn.ReLU,
    nn.ReLU6,
    nn.SELU,
    nn.CELU,
    nn.GELU,
    nn.Sigmoid,
    nn.Softplus,
    nn.Tanh,
    nn.Softmax,
]
pytorch_act_names = [a.__name__.lower() for a in pytorch_acts]


def get_act_fn(act, **act_kwargs):
    """Resolve ``act`` into an activation module.

    Args:
        act: ``None`` (returned as ``None``), an ``nn.Module`` instance
            (returned unchanged), a callable such as an activation class
            (called with ``act_kwargs``), or the lower/upper-case name of one
            of the classes in ``pytorch_acts``.
        **act_kwargs: keyword arguments used when an activation is constructed.

    Returns:
        The activation module, or ``None`` when ``act`` is ``None``.

    Raises:
        ValueError: if ``act`` is a string that names no known activation.
    """
    if act is None:
        return None
    if isinstance(act, nn.Module):
        return act
    if callable(act):
        return act(**act_kwargs)
    try:
        idx = pytorch_act_names.index(act.lower())
    except ValueError:
        raise ValueError(
            f"Unknown activation {act!r}; expected one of {pytorch_act_names}"
        ) from None
    return pytorch_acts[idx](**act_kwargs)


class RevIN(nn.Module):
    """Reversible instance normalisation over the last dimension.

    ``normalize`` stores the statistics it removed on the module (``self.sub``
    and ``self.std``); ``denormalize`` re-applies them, so it must be called
    after ``normalize`` on the same instance.

    Args:
        c_in: number of channels; sizes the affine parameters ``(1, c_in, 1)``.
        affine: when true, learn a per-channel scale and shift.
        subtract_last: subtract the last time step instead of the mean.
        dim: stored on the module; not consulted by the code in this class.
        eps: added to the standard deviation to avoid division by zero.
    """

    def __init__(
        self,
        c_in: int,
        affine: bool = True,
        subtract_last: bool = False,
        dim: int = 2,
        eps: float = 1e-5,
    ):
        super().__init__()
        self.c_in = c_in
        self.affine = affine
        self.subtract_last = subtract_last
        self.dim = dim
        self.eps = eps
        if self.affine:
            self.weight = nn.Parameter(torch.ones(1, c_in, 1))
            self.bias = nn.Parameter(torch.zeros(1, c_in, 1))

    def forward(self, x: Tensor, mode: Tensor):
        """Normalise ``x`` when ``mode`` is truthy, otherwise denormalise it."""
        if mode:
            return self.normalize(x)
        return self.denormalize(x)

    def normalize(self, x):
        """Remove the per-series offset and scale, remembering both."""
        if self.subtract_last:
            self.sub = x[..., -1].unsqueeze(-1).detach()
        else:
            self.sub = torch.mean(x, dim=-1, keepdim=True).detach()
        self.std = (
            torch.std(x, dim=-1, keepdim=True, unbiased=False).detach() + self.eps
        )
        x = x.sub(self.sub).div(self.std)
        if self.affine:
            x = x.mul(self.weight).add(self.bias)
        return x

    def denormalize(self, x):
        """Invert ``normalize`` using the statistics stored by its last call."""
        if self.affine:
            x = x.sub(self.bias).div(self.weight)
        return x.mul(self.std).add(self.sub)


class MovingAverage(nn.Module):
    """Length-preserving moving average along the last dimension.

    The input is replication-padded so that the stride-1 average pool returns
    a tensor with the same length as the input.
    """

    def __init__(
        self,
        kernel_size: int,
    ):
        super().__init__()
        padding_left = (kernel_size - 1) // 2
        padding_right = kernel_size - padding_left - 1
        self.padding = torch.nn.ReplicationPad1d((padding_left, padding_right))
        self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=1)

    def forward(self, x: Tensor):
        return self.avg(self.padding(x))


class SeriesDecomposition(nn.Module):
    """Split a series into ``(residual, moving_mean)``.

    Args:
        kernel_size: the size of the moving-average window.
    """

    def __init__(
        self,
        kernel_size: int,
    ):
        super().__init__()
        self.moving_avg = MovingAverage(kernel_size)

    def forward(self, x: Tensor):
        moving_mean = self.moving_avg(x)
        residual = x - moving_mean
        return residual, moving_mean


class _ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention with optional residual attention scores.

    The scale is ``(d_model // n_heads) ** -0.5`` and is registered as a
    non-trainable parameter, so it is part of the state dict.
    """

    def __init__(self, d_model, n_heads, attn_dropout=0.0, res_attention=False):
        super().__init__()
        self.attn_dropout = nn.Dropout(attn_dropout)
        self.res_attention = res_attention
        head_dim = d_model // n_heads
        self.scale = nn.Parameter(torch.tensor(head_dim**-0.5), requires_grad=False)

    def forward(self, q: Tensor, k: Tensor, v: Tensor, prev: Optional[Tensor] = None):
        """Attend ``q`` over ``k``/``v``.

        ``k`` must already be transposed so that ``torch.matmul(q, k)`` yields
        the score matrix. ``prev`` (scores from a previous layer) is added to
        the scores before the softmax when given.

        Returns:
            ``(output, attn_weights)``, plus the pre-softmax ``attn_scores``
            as a third element when ``res_attention`` is enabled.
        """
        attn_scores = torch.matmul(q, k) * self.scale
        if prev is not None:
            attn_scores = attn_scores + prev

        attn_weights = F.softmax(attn_scores, dim=-1)
        attn_weights = self.attn_dropout(attn_weights)
        output = torch.matmul(attn_weights, v)

        if self.res_attention:
            return output, attn_weights, attn_scores
        return output, attn_weights


class _MultiheadAttention(nn.Module):
    """Multi Head Attention Layer.

    Args:
        d_model: width of the input and output features.
        n_heads: number of attention heads.
        d_k: per-head query/key width; defaults to ``d_model // n_heads``.
        d_v: per-head value width; defaults to ``d_model // n_heads``.
        res_attention: return (and accept) raw attention scores so they can be
            carried from one layer to the next.
        attn_dropout: dropout applied to the attention weights.
        proj_dropout: dropout applied after the output projection.
        qkv_bias: whether the Q/K/V projections have a bias term.
    """

    def __init__(
        self,
        d_model,
        n_heads,
        d_k=None,
        d_v=None,
        res_attention=False,
        attn_dropout=0.0,
        proj_dropout=0.0,
        qkv_bias=True,
    ):
        super().__init__()
        # Honour explicit head sizes; fall back to an even split of d_model.
        d_k = d_model // n_heads if d_k is None else d_k
        d_v = d_model // n_heads if d_v is None else d_v
        self.n_heads, self.d_k, self.d_v = n_heads, d_k, d_v

        self.W_Q = nn.Linear(d_model, d_k * n_heads, bias=qkv_bias)
        self.W_K = nn.Linear(d_model, d_k * n_heads, bias=qkv_bias)
        self.W_V = nn.Linear(d_model, d_v * n_heads, bias=qkv_bias)

        # Scaled Dot-Product Attention (multiple heads)
        self.res_attention = res_attention
        self.sdp_attn = _ScaledDotProductAttention(
            d_model,
            n_heads,
            attn_dropout=attn_dropout,
            res_attention=self.res_attention,
        )

        # Project output
        self.to_out = nn.Sequential(
            nn.Linear(n_heads * d_v, d_model), nn.Dropout(proj_dropout)
        )

    def forward(
        self,
        Q: Tensor,
        K: Optional[Tensor] = None,
        V: Optional[Tensor] = None,
        prev: Optional[Tensor] = None,
    ):
        """Run attention; ``K`` and ``V`` default to ``Q`` (self-attention).

        ``prev`` is only used when ``res_attention`` is enabled.

        Returns:
            ``(output, attn_weights)``, plus ``attn_scores`` as a third
            element when ``res_attention`` is enabled.
        """
        bs = Q.size(0)
        if K is None:
            K = Q
        if V is None:
            V = Q

        # Linear (+ split in multiple heads)
        # q_s: [bs x n_heads x max_q_len x d_k]
        q_s = self.W_Q(Q).view(bs, -1, self.n_heads, self.d_k).transpose(1, 2)
        # k_s: [bs x n_heads x d_k x q_len] - transpose(1,2) + transpose(2,3)
        k_s = self.W_K(K).view(bs, -1, self.n_heads, self.d_k).permute(0, 2, 3, 1)
        # v_s: [bs x n_heads x q_len x d_v]
        v_s = self.W_V(V).view(bs, -1, self.n_heads, self.d_v).transpose(1, 2)

        # Apply Scaled Dot-Product Attention (multiple heads)
        # output: [bs x n_heads x q_len x d_v]
        # attn:   [bs x n_heads x q_len x q_len]
        # scores: [bs x n_heads x max_q_len x q_len]
        if self.res_attention:
            output, attn_weights, attn_scores = self.sdp_attn(q_s, k_s, v_s, prev=prev)
        else:
            output, attn_weights = self.sdp_attn(q_s, k_s, v_s)

        # back to the original inputs dimensions
        # output: [bs x q_len x n_heads * d_v]
        output = (
            output.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * self.d_v)
        )
        output = self.to_out(output)

        if self.res_attention:
            return output, attn_weights, attn_scores
        return output, attn_weights


class Flatten_Head(nn.Module):
    """Flatten the last two dimensions and project them to ``pred_dim``.

    Args:
        individual: use a separate linear head per variable instead of one
            shared head.
        n_vars: number of variables (only used when ``individual`` is true).
        nf: number of flattened input features (``d_model * n_patch``).
        pred_dim: output width; if a tuple/list is given its last item is used.
    """

    def __init__(self, individual, n_vars, nf, pred_dim):
        super().__init__()

        if isinstance(pred_dim, (tuple, list)):
            pred_dim = pred_dim[-1]
        self.individual = individual
        self.n = n_vars if individual else 1
        self.nf, self.pred_dim = nf, pred_dim

        if individual:
            self.layers = nn.ModuleList(
                [
                    nn.Sequential(nn.Flatten(start_dim=-2), nn.Linear(nf, pred_dim))
                    for _ in range(self.n)
                ]
            )
        else:
            self.layer = nn.Sequential(
                nn.Flatten(start_dim=-2), nn.Linear(nf, pred_dim)
            )

    def forward(self, x: Tensor):
        """
        Args:
            x: [bs x nvars x d_model x n_patch]
            output: [bs x nvars x pred_dim]
        """
        if not self.individual:
            return self.layer(x)
        per_var = [layer(x[:, i]) for i, layer in enumerate(self.layers)]
        return torch.stack(per_var, dim=1)


def _make_norm(norm, d_model):
    """Build the normalisation layer used after a sublayer.

    Any ``norm`` containing "batch" (case-insensitive) yields a BatchNorm1d
    over the feature dimension (wrapped in transposes because the input is
    ``[bs x q_len x d_model]``); everything else yields a LayerNorm.
    """
    if "batch" in norm.lower():
        return nn.Sequential(
            Transpose(1, 2), nn.BatchNorm1d(d_model), Transpose(1, 2)
        )
    return nn.LayerNorm(d_model)


class _TSTiEncoderLayer(nn.Module):
    """One Transformer encoder layer: self-attention + position-wise FFN.

    Args:
        q_len: accepted for interface compatibility; not used in this class.
        d_model: feature width.
        n_heads: number of attention heads; must divide ``d_model``.
        d_k, d_v: per-head widths, defaulting to ``d_model // n_heads``.
        d_ff: hidden width of the feed-forward block.
        store_attn: keep the last attention weights on ``self.attn``.
        norm: "BatchNorm"-like names select BatchNorm1d, otherwise LayerNorm.
        attn_dropout: dropout on the attention weights.
        dropout: dropout used for the projections, FFN and residual branches.
        bias: whether the feed-forward linears have a bias term.
        activation: activation spec resolved through ``get_act_fn``.
        res_attention: pass attention scores between layers.
        pre_norm: apply the norm before each sublayer instead of after it.
    """

    def __init__(
        self,
        q_len,
        d_model,
        n_heads,
        d_k=None,
        d_v=None,
        d_ff=256,
        store_attn=False,
        norm="BatchNorm",
        attn_dropout=0,
        dropout=0.0,
        bias=True,
        activation="gelu",
        res_attention=False,
        pre_norm=False,
    ):
        super().__init__()
        assert (
            not d_model % n_heads
        ), f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
        d_k = d_model // n_heads if d_k is None else d_k
        d_v = d_model // n_heads if d_v is None else d_v

        # Multi-Head attention
        self.res_attention = res_attention
        self.self_attn = _MultiheadAttention(
            d_model,
            n_heads,
            d_k,
            d_v,
            attn_dropout=attn_dropout,
            proj_dropout=dropout,
            res_attention=res_attention,
        )

        # Add & Norm
        self.dropout_attn = nn.Dropout(dropout)
        self.norm_attn = _make_norm(norm, d_model)

        # Position-wise Feed-Forward
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff, bias=bias),
            get_act_fn(activation),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model, bias=bias),
        )

        # Add & Norm
        self.dropout_ffn = nn.Dropout(dropout)
        self.norm_ffn = _make_norm(norm, d_model)

        self.pre_norm = pre_norm
        self.store_attn = store_attn

    def forward(self, src: Tensor, prev: Optional[Tensor] = None):
        """
        Args:
            src: [bs x q_len x d_model]
            prev: attention scores from the previous layer (only used when
                ``res_attention`` is enabled).

        Returns:
            The transformed ``src``; with ``res_attention`` enabled a tuple
            ``(src, scores)``.
        """
        # Multi-Head attention sublayer
        if self.pre_norm:
            src = self.norm_attn(src)
        ## Multi-Head attention
        if self.res_attention:
            src2, attn, scores = self.self_attn(src, src, src, prev)
        else:
            src2, attn = self.self_attn(src, src, src)
        if self.store_attn:
            self.attn = attn
        ## Add & Norm
        # Add: residual connection with residual dropout
        src = src + self.dropout_attn(src2)
        if not self.pre_norm:
            src = self.norm_attn(src)

        # Feed-forward sublayer
        if self.pre_norm:
            src = self.norm_ffn(src)
        ## Position-wise Feed-Forward
        src2 = self.ff(src)
        ## Add & Norm
        # Add: residual connection with residual dropout
        src = src + self.dropout_ffn(src2)
        if not self.pre_norm:
            src = self.norm_ffn(src)

        if self.res_attention:
            return src, scores
        return src


class _TSTiEncoder(nn.Module):  # i means channel-independent
    """Channel-independent patch encoder.

    Each variable's patches are projected to ``d_model``, given a learned
    positional encoding, and passed through ``n_layers`` encoder layers with
    the batch and variable dimensions merged.
    """

    def __init__(
        self,
        c_in,
        patch_num,
        patch_len,
        n_layers=3,
        d_model=128,
        n_heads=16,
        d_k=None,
        d_v=None,
        d_ff=256,
        norm="BatchNorm",
        attn_dropout=0.0,
        dropout=0.0,
        act="gelu",
        store_attn=False,
        res_attention=True,
        pre_norm=False,
    ):
        super().__init__()

        self.patch_num = patch_num
        self.patch_len = patch_len

        # Input encoding
        q_len = patch_num
        # Eq 1: projection of feature vectors onto a d-dim vector space
        self.W_P = nn.Linear(patch_len, d_model)
        self.seq_len = q_len

        # Positional encoding
        W_pos = torch.empty((q_len, d_model))
        nn.init.uniform_(W_pos, -0.02, 0.02)
        self.W_pos = nn.Parameter(W_pos)

        # Residual dropout
        self.dropout = nn.Dropout(dropout)

        # Encoder
        self.layers = nn.ModuleList(
            [
                _TSTiEncoderLayer(
                    q_len,
                    d_model,
                    n_heads=n_heads,
                    d_k=d_k,
                    d_v=d_v,
                    d_ff=d_ff,
                    norm=norm,
                    attn_dropout=attn_dropout,
                    dropout=dropout,
                    activation=act,
                    res_attention=res_attention,
                    pre_norm=pre_norm,
                    store_attn=store_attn,
                )
                for _ in range(n_layers)
            ]
        )
        self.res_attention = res_attention

    def forward(self, x: Tensor):
        """
        Args:
            x: [bs x nvars x patch_len x patch_num]

        Returns:
            Tensor of shape [bs x nvars x d_model x patch_num].
        """
        n_vars = x.shape[1]

        # Input encoding
        x = x.permute(0, 1, 3, 2)  # x: [bs x nvars x patch_num x patch_len]
        x = self.W_P(x)  # x: [bs x nvars x patch_num x d_model]

        # Merge batch and variable dims so every channel is encoded on its own.
        # x: [bs * nvars x patch_num x d_model]
        x = torch.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3]))
        x = self.dropout(x + self.W_pos)  # x: [bs * nvars x patch_num x d_model]

        # Encoder
        if self.res_attention:
            scores = None
            for mod in self.layers:
                x, scores = mod(x, prev=scores)
        else:
            for mod in self.layers:
                x = mod(x)

        # x: [bs x nvars x patch_num x d_model]
        x = torch.reshape(x, (-1, n_vars, x.shape[-2], x.shape[-1]))
        x = x.permute(0, 1, 3, 2)  # x: [bs x nvars x d_model x patch_num]
        return x


class _PatchTST_backbone(nn.Module):
    """RevIN -> patching -> channel-independent encoder -> flatten head.

    NOTE(review): ``padding_patch`` is stored but never consulted; the input
    is always left-padded by ``stride`` and ``patch_num`` always counts the
    extra patch that padding creates. Confirm whether a no-padding mode is
    intended before relying on this flag.
    """

    def __init__(
        self,
        c_in,
        seq_len,
        pred_dim,
        patch_len,
        stride,
        n_layers=3,
        d_model=128,
        n_heads=16,
        d_k=None,
        d_v=None,
        d_ff=256,
        norm="BatchNorm",
        attn_dropout=0.0,
        dropout=0.0,
        act="gelu",
        res_attention=True,
        pre_norm=False,
        store_attn=False,
        padding_patch=True,
        individual=False,
        revin=True,
        affine=True,
        subtract_last=False,
    ):
        super().__init__()

        # The RevIN layer is always built; ``revin`` only gates its use.
        self.revin = revin
        self.revin_layer = RevIN(c_in, affine=affine, subtract_last=subtract_last)

        self.patch_len = patch_len
        self.stride = stride
        self.padding_patch = padding_patch
        # Patches over the unpadded series, plus one from the ``stride`` padding.
        patch_num = int((seq_len - patch_len) / stride + 1) + 1
        self.patch_num = patch_num
        self.padding_patch_layer = nn.ReplicationPad1d((stride, 0))

        # Each series is viewed as a 1 x length image so Unfold can cut patches.
        self.unfold = nn.Unfold(kernel_size=(1, patch_len), stride=stride)

        self.backbone = _TSTiEncoder(
            c_in,
            patch_num=patch_num,
            patch_len=patch_len,
            n_layers=n_layers,
            d_model=d_model,
            n_heads=n_heads,
            d_k=d_k,
            d_v=d_v,
            d_ff=d_ff,
            # Forward the requested norm; it was previously dropped here, so
            # the encoder always fell back to its "BatchNorm" default.
            norm=norm,
            attn_dropout=attn_dropout,
            dropout=dropout,
            act=act,
            res_attention=res_attention,
            pre_norm=pre_norm,
            store_attn=store_attn,
        )

        # Head
        self.head_nf = d_model * patch_num
        self.n_vars = c_in
        self.individual = individual
        self.head = Flatten_Head(self.individual, self.n_vars, self.head_nf, pred_dim)

    def forward(self, z: Tensor):
        """
        Args:
            z: [bs x c_in x seq_len]

        Returns:
            Tensor of shape [bs x c_in x pred_dim].
        """
        if self.revin:
            z = self.revin_layer(z, torch.tensor(True, dtype=torch.bool))

        z = self.padding_patch_layer(z)
        b, c, s = z.size()
        z = z.reshape(-1, 1, 1, s)
        z = self.unfold(z)  # z: [bs * c_in x patch_len x patch_num]
        # z: [bs x c_in x patch_len x patch_num]
        z = z.permute(0, 2, 1).reshape(b, c, -1, self.patch_len).permute(0, 1, 3, 2)

        z = self.backbone(z)
        z = self.head(z)

        if self.revin:
            z = self.revin_layer(z, torch.tensor(False, dtype=torch.bool))
        return z


class PatchTST(nn.Module, PyTorchModelHubMixin):
    """PatchTST model, optionally with trend/residual series decomposition.

    Args:
        c_in: number of input variables.
        c_out: accepted for interface compatibility; not used by this class.
        seq_len: length of the input window.
        pred_dim: output width per variable; defaults to ``seq_len``.
        n_layers, n_heads, d_model, d_ff: encoder size settings.
        dropout, attn_dropout: dropout rates forwarded to the encoder.
        patch_len, stride: patch length and step between patches.
        padding_patch: forwarded to the backbone (see its review note).
        revin, affine, subtract_last: RevIN settings.
        individual: use one output head per variable.
        decomposition: when true, split the input with a moving average of
            width ``kernel_size`` and sum the outputs of two backbones.
        activation: activation spec for the feed-forward blocks.
        norm: "BatchNorm"-like names select BatchNorm1d, otherwise LayerNorm.
        pre_norm, res_attention, store_attn: encoder layer options.
        classification: when true, ``squeeze(-2)`` is applied to the output.
    """

    def __init__(
        self,
        c_in,
        c_out,
        seq_len,
        pred_dim=None,
        n_layers=2,
        n_heads=8,
        d_model=512,
        d_ff=2048,
        dropout=0.05,
        attn_dropout=0.0,
        patch_len=16,
        stride=8,
        padding_patch=True,
        revin=True,
        affine=False,
        individual=False,
        subtract_last=False,
        decomposition=False,
        kernel_size=25,
        activation="gelu",
        norm="BatchNorm",
        pre_norm=False,
        res_attention=True,
        store_attn=False,
        classification=False,
    ):
        super().__init__()

        if pred_dim is None:
            pred_dim = seq_len

        # Every backbone in this model is built from the same settings.
        backbone_kwargs = dict(
            c_in=c_in,
            seq_len=seq_len,
            pred_dim=pred_dim,
            patch_len=patch_len,
            stride=stride,
            n_layers=n_layers,
            d_model=d_model,
            n_heads=n_heads,
            d_ff=d_ff,
            norm=norm,
            attn_dropout=attn_dropout,
            dropout=dropout,
            act=activation,
            res_attention=res_attention,
            pre_norm=pre_norm,
            store_attn=store_attn,
            padding_patch=padding_patch,
            individual=individual,
            revin=revin,
            affine=affine,
            subtract_last=subtract_last,
        )

        self.decomposition = decomposition
        if self.decomposition:
            self.decomp_module = SeriesDecomposition(kernel_size)
            self.model_trend = _PatchTST_backbone(**backbone_kwargs)
            self.model_res = _PatchTST_backbone(**backbone_kwargs)
            self.patch_num = self.model_trend.patch_num
        else:
            self.model = _PatchTST_backbone(**backbone_kwargs)
            self.patch_num = self.model.patch_num
        self.classification = classification

    def forward(self, x):
        """Map ``x`` of shape [bs x c_in x seq_len] to [bs x c_in x pred_dim]."""
        if self.decomposition:
            res_init, trend_init = self.decomp_module(x)
            res = self.model_res(res_init)
            trend = self.model_trend(trend_init)
            x = res + trend
        else:
            x = self.model(x)

        if self.classification:
            x = x.squeeze(-2)
        return x