| from typing import Optional |
|
|
| import torch |
| from torch import nn, Tensor |
| import torch.nn.functional as F |
|
|
| from huggingface_hub import PyTorchModelHubMixin |
|
|
|
|
class Transpose(nn.Module):
    """Swap two tensor dimensions as an ``nn.Module`` (usable in ``nn.Sequential``)."""

    def __init__(self, *dims, contiguous=False):
        super().__init__()
        self.dims = dims
        self.contiguous = contiguous

    def forward(self, x):
        # Apply the transpose; optionally force a contiguous memory layout.
        out = x.transpose(*self.dims)
        return out.contiguous() if self.contiguous else out

    def __repr__(self):
        dims_txt = ', '.join(str(d) for d in self.dims)
        if self.contiguous:
            return f"{self.__class__.__name__}(dims={dims_txt}).contiguous()"
        return f"{self.__class__.__name__}({dims_txt})"
|
|
|
|
pytorch_acts = [
    nn.ELU,
    nn.LeakyReLU,
    nn.PReLU,
    nn.ReLU,
    nn.ReLU6,
    nn.SELU,
    nn.CELU,
    nn.GELU,
    nn.Sigmoid,
    nn.Softplus,
    nn.Tanh,
    nn.Softmax,
]
# Lower-cased class names ("elu", "relu", ...) used for string lookup below.
pytorch_act_names = [a.__name__.lower() for a in pytorch_acts]


def get_act_fn(act, **act_kwargs):
    """Resolve an activation spec into an ``nn.Module`` instance.

    Args:
        act: ``None`` (returns ``None``), an ``nn.Module`` instance (returned
            as-is), a callable/class (instantiated with ``act_kwargs``), or a
            case-insensitive activation name such as ``"relu"`` or ``"gelu"``.
        **act_kwargs: Keyword arguments forwarded to the activation constructor.

    Returns:
        The resolved activation module, or ``None`` if ``act`` is ``None``.

    Raises:
        ValueError: If ``act`` is a string that does not name a known activation.
    """
    if act is None:
        return None
    if isinstance(act, nn.Module):
        return act
    if callable(act):
        return act(**act_kwargs)
    try:
        idx = pytorch_act_names.index(act.lower())
    except ValueError:
        # The bare list.index failure ("'x' is not in list") is unhelpful;
        # name the valid options instead.
        raise ValueError(
            f"Unknown activation {act!r}; expected one of {pytorch_act_names}"
        ) from None
    return pytorch_acts[idx](**act_kwargs)
|
|
|
|
class RevIN(nn.Module):
    """Reversible instance normalization for time series.

    Normalizes each series along its last (time) dimension and stores the
    statistics on the module so the transform can later be inverted on the
    model output. Input is expected shaped ``[bs x c_in x seq_len]``.
    """

    def __init__(
        self,
        c_in: int,
        affine: bool = True,
        subtract_last: bool = False,
        dim: int = 2,
        eps: float = 1e-5,
    ):
        super().__init__()
        self.c_in = c_in
        self.affine = affine
        self.subtract_last = subtract_last
        self.dim = dim  # NOTE(review): stored but never read below — confirm intent
        self.eps = eps
        if self.affine:
            # Per-channel learnable scale/shift, broadcast over batch and time.
            self.weight = nn.Parameter(torch.ones(1, c_in, 1))
            self.bias = nn.Parameter(torch.zeros(1, c_in, 1))

    def forward(self, x: Tensor, mode: Tensor):
        # Truthy ``mode`` -> normalize; falsy -> denormalize.
        return self.normalize(x) if mode else self.denormalize(x)

    def normalize(self, x):
        # Capture (detached) statistics for the later denormalize pass.
        if self.subtract_last:
            self.sub = x[..., -1].unsqueeze(-1).detach()
        else:
            self.sub = torch.mean(x, dim=-1, keepdim=True).detach()
        self.std = (
            torch.std(x, dim=-1, keepdim=True, unbiased=False).detach() + self.eps
        )
        x = x.sub(self.sub).div(self.std)
        if self.affine:
            x = x.mul(self.weight).add(self.bias)
        return x

    def denormalize(self, x):
        # Invert normalize() using the statistics stored on the last call.
        if self.affine:
            x = x.sub(self.bias).div(self.weight)
        return x.mul(self.std).add(self.sub)
|
|
|
|
class MovingAverage(nn.Module):
    """Centered moving average over the last dimension.

    Edge values are replicated before pooling so the output length equals the
    input length.
    """

    def __init__(
        self,
        kernel_size: int,
    ):
        super().__init__()
        # Split the (kernel_size - 1) total padding asymmetrically for even kernels.
        left = (kernel_size - 1) // 2
        right = kernel_size - 1 - left
        self.padding = torch.nn.ReplicationPad1d((left, right))
        self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=1)

    def forward(self, x: Tensor):
        # Replicate edges, then average with stride 1 (length-preserving).
        padded = self.padding(x)
        return self.avg(padded)
|
|
|
|
class SeriesDecomposition(nn.Module):
    """Decompose a series into (residual, trend) parts.

    The trend is a moving average of the input; the residual is whatever the
    smoothing removed, so ``residual + trend == x``.
    """

    def __init__(
        self,
        kernel_size: int,
    ):
        super().__init__()
        self.moving_avg = MovingAverage(kernel_size)

    def forward(self, x: Tensor):
        trend = self.moving_avg(x)
        return x - trend, trend
|
|
|
|
| class _ScaledDotProductAttention(nn.Module): |
| def __init__(self, d_model, n_heads, attn_dropout=0.0, res_attention=False): |
| super().__init__() |
| self.attn_dropout = nn.Dropout(attn_dropout) |
| self.res_attention = res_attention |
| head_dim = d_model // n_heads |
| self.scale = nn.Parameter(torch.tensor(head_dim**-0.5), requires_grad=False) |
|
|
| def forward(self, q: Tensor, k: Tensor, v: Tensor, prev: Optional[Tensor] = None): |
| attn_scores = torch.matmul(q, k) * self.scale |
|
|
| if prev is not None: |
| attn_scores = attn_scores + prev |
|
|
| attn_weights = F.softmax(attn_scores, dim=-1) |
| attn_weights = self.attn_dropout(attn_weights) |
|
|
| output = torch.matmul(attn_weights, v) |
|
|
| if self.res_attention: |
| return output, attn_weights, attn_scores |
| else: |
| return output, attn_weights |
|
|
|
|
class _MultiheadAttention(nn.Module):
    """Multi Head Attention Layer.

    Args:
        d_model: Model (embedding) dimension.
        n_heads: Number of attention heads.
        d_k: Per-head key/query dimension; defaults to ``d_model // n_heads``.
        d_v: Per-head value dimension; defaults to ``d_model // n_heads``.
        res_attention: If True, also return (and accept via ``prev``) raw
            attention scores so they can be carried across layers.
        attn_dropout: Dropout applied to the attention weights.
        proj_dropout: Dropout applied after the output projection.
        qkv_bias: Whether the Q/K/V projections use a bias term.
    """

    def __init__(
        self,
        d_model,
        n_heads,
        d_k=None,
        d_v=None,
        res_attention=False,
        attn_dropout=0.0,
        proj_dropout=0.0,
        qkv_bias=True,
    ):
        super().__init__()
        # Fix: d_k/d_v were previously overwritten unconditionally with
        # d_model // n_heads, silently ignoring caller-supplied values.
        if d_k is None:
            d_k = d_model // n_heads
        if d_v is None:
            d_v = d_model // n_heads

        self.n_heads, self.d_k, self.d_v = n_heads, d_k, d_v

        self.W_Q = nn.Linear(d_model, d_k * n_heads, bias=qkv_bias)
        self.W_K = nn.Linear(d_model, d_k * n_heads, bias=qkv_bias)
        self.W_V = nn.Linear(d_model, d_v * n_heads, bias=qkv_bias)

        self.res_attention = res_attention
        # Pass d_k * n_heads so the score scale is 1/sqrt(d_k) even for a
        # custom d_k (identical to before when d_k == d_model // n_heads).
        self.sdp_attn = _ScaledDotProductAttention(
            d_k * n_heads,
            n_heads,
            attn_dropout=attn_dropout,
            res_attention=self.res_attention,
        )

        self.to_out = nn.Sequential(
            nn.Linear(n_heads * d_v, d_model), nn.Dropout(proj_dropout)
        )

    def forward(
        self,
        Q: Tensor,
        K: Optional[Tensor] = None,
        V: Optional[Tensor] = None,
        prev: Optional[Tensor] = None,
    ):
        """Self/cross attention; K and V default to Q (self-attention)."""
        bs = Q.size(0)
        if K is None:
            K = Q
        if V is None:
            V = Q

        # Project and split heads:
        # q_s: [bs x n_heads x q_len x d_k]
        # k_s: [bs x n_heads x d_k x k_len] (pre-transposed for q @ k)
        # v_s: [bs x n_heads x k_len x d_v]
        q_s = self.W_Q(Q).view(bs, -1, self.n_heads, self.d_k).transpose(1, 2)
        k_s = self.W_K(K).view(bs, -1, self.n_heads, self.d_k).permute(0, 2, 3, 1)
        v_s = self.W_V(V).view(bs, -1, self.n_heads, self.d_v).transpose(1, 2)

        if self.res_attention:
            output, attn_weights, attn_scores = self.sdp_attn(q_s, k_s, v_s, prev=prev)
        else:
            output, attn_weights = self.sdp_attn(q_s, k_s, v_s)

        # Merge heads back to [bs x q_len x (n_heads * d_v)], then project.
        output = (
            output.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * self.d_v)
        )
        output = self.to_out(output)

        if self.res_attention:
            return output, attn_weights, attn_scores
        return output, attn_weights
|
|
|
|
class Flatten_Head(nn.Module):
    """Flatten the per-variable encoder output and project it to pred_dim.

    With ``individual=True`` each variable gets its own flatten+linear head;
    otherwise one head is shared across all variables.
    """

    def __init__(self, individual, n_vars, nf, pred_dim):
        super().__init__()

        # Allow pred_dim given as a (..., horizon) tuple/list: use the last entry.
        if isinstance(pred_dim, (tuple, list)):
            pred_dim = pred_dim[-1]
        self.individual = individual
        self.n = n_vars if individual else 1
        self.nf, self.pred_dim = nf, pred_dim

        if individual:
            self.layers = nn.ModuleList(
                nn.Sequential(nn.Flatten(start_dim=-2), nn.Linear(nf, pred_dim))
                for _ in range(self.n)
            )
        else:
            self.layer = nn.Sequential(
                nn.Flatten(start_dim=-2), nn.Linear(nf, pred_dim)
            )

    def forward(self, x: Tensor):
        """
        Args:
            x: [bs x nvars x d_model x n_patch]
            output: [bs x nvars x pred_dim]
        """
        if not self.individual:
            return self.layer(x)
        per_var = [head(x[:, i]) for i, head in enumerate(self.layers)]
        return torch.stack(per_var, dim=1)
|
|
|
|
class _TSTiEncoderLayer(nn.Module):
    """One transformer encoder layer: multi-head self-attention plus a
    position-wise feed-forward network, each with residual connection,
    dropout and (pre- or post-) normalization."""

    def __init__(
        self,
        q_len,
        d_model,
        n_heads,
        d_k=None,
        d_v=None,
        d_ff=256,
        store_attn=False,
        norm="BatchNorm",
        attn_dropout=0,
        dropout=0.0,
        bias=True,
        activation="gelu",
        res_attention=False,
        pre_norm=False,
    ):
        super().__init__()
        assert (
            not d_model % n_heads
        ), f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
        if d_k is None:
            d_k = d_model // n_heads
        if d_v is None:
            d_v = d_model // n_heads

        # --- self-attention sublayer ---
        self.res_attention = res_attention
        self.self_attn = _MultiheadAttention(
            d_model,
            n_heads,
            d_k,
            d_v,
            attn_dropout=attn_dropout,
            proj_dropout=dropout,
            res_attention=res_attention,
        )
        self.dropout_attn = nn.Dropout(dropout)
        self.norm_attn = self._make_norm(norm, d_model)

        # --- position-wise feed-forward sublayer ---
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff, bias=bias),
            get_act_fn(activation),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model, bias=bias),
        )
        self.dropout_ffn = nn.Dropout(dropout)
        self.norm_ffn = self._make_norm(norm, d_model)

        self.pre_norm = pre_norm
        self.store_attn = store_attn

    @staticmethod
    def _make_norm(norm, d_model):
        # BatchNorm1d normalizes over the channel dim, so transpose around it.
        if "batch" in norm.lower():
            return nn.Sequential(
                Transpose(1, 2), nn.BatchNorm1d(d_model), Transpose(1, 2)
            )
        return nn.LayerNorm(d_model)

    def forward(self, src: Tensor, prev: Optional[Tensor] = None):
        """
        Args:
            src: [bs x q_len x d_model]
            prev: optional raw attention scores carried from the previous
                layer (residual attention).
        """
        # Self-attention sublayer.
        if self.pre_norm:
            src = self.norm_attn(src)
        if self.res_attention:
            src2, attn, scores = self.self_attn(src, src, src, prev)
        else:
            src2, attn = self.self_attn(src, src, src)
        if self.store_attn:
            self.attn = attn
        src = src + self.dropout_attn(src2)  # residual connection
        if not self.pre_norm:
            src = self.norm_attn(src)

        # Feed-forward sublayer.
        if self.pre_norm:
            src = self.norm_ffn(src)
        src = src + self.dropout_ffn(self.ff(src))  # residual connection
        if not self.pre_norm:
            src = self.norm_ffn(src)

        if self.res_attention:
            return src, scores
        return src
|
|
|
|
class _TSTiEncoder(nn.Module):
    """Channel-independent patch encoder.

    Projects each patch to ``d_model``, adds a learned positional encoding,
    and applies a stack of transformer encoder layers. Variables are folded
    into the batch dimension so every channel is encoded independently.
    """

    def __init__(
        self,
        c_in,
        patch_num,
        patch_len,
        n_layers=3,
        d_model=128,
        n_heads=16,
        d_k=None,
        d_v=None,
        d_ff=256,
        norm="BatchNorm",
        attn_dropout=0.0,
        dropout=0.0,
        act="gelu",
        store_attn=False,
        res_attention=True,
        pre_norm=False,
    ):
        super().__init__()

        self.patch_num = patch_num
        self.patch_len = patch_len

        # Patch embedding: each length-patch_len patch -> d_model vector.
        q_len = patch_num
        self.W_P = nn.Linear(patch_len, d_model)
        self.seq_len = q_len

        # Learned positional encoding, one vector per patch position.
        W_pos = torch.empty((q_len, d_model))
        nn.init.uniform_(W_pos, -0.02, 0.02)
        self.W_pos = nn.Parameter(W_pos)

        self.dropout = nn.Dropout(dropout)

        # Transformer encoder stack.
        self.layers = nn.ModuleList(
            _TSTiEncoderLayer(
                q_len,
                d_model,
                n_heads=n_heads,
                d_k=d_k,
                d_v=d_v,
                d_ff=d_ff,
                norm=norm,
                attn_dropout=attn_dropout,
                dropout=dropout,
                activation=act,
                res_attention=res_attention,
                pre_norm=pre_norm,
                store_attn=store_attn,
            )
            for _ in range(n_layers)
        )
        self.res_attention = res_attention

    def forward(self, x: Tensor):
        """
        Args:
            x: [bs x nvars x patch_len x patch_num]

        Returns:
            [bs x nvars x d_model x patch_num]
        """
        n_vars = x.shape[1]

        # [bs x nvars x patch_num x patch_len] -> embed patches to d_model.
        x = self.W_P(x.permute(0, 1, 3, 2))

        # Fold variables into the batch dim (channel independence).
        bs_nvars = x.shape[0] * x.shape[1]
        x = torch.reshape(x, (bs_nvars, x.shape[2], x.shape[3]))
        x = self.dropout(x + self.W_pos)

        # Encoder stack; optionally thread raw scores between layers.
        if self.res_attention:
            scores = None
            for layer in self.layers:
                x, scores = layer(x, prev=scores)
        else:
            for layer in self.layers:
                x = layer(x)

        # Unfold batch back to [bs x nvars x d_model x patch_num].
        x = torch.reshape(x, (-1, n_vars, x.shape[-2], x.shape[-1]))
        return x.permute(0, 1, 3, 2)
|
|
|
|
class _PatchTST_backbone(nn.Module):
    """Single PatchTST backbone: RevIN -> patching -> TST encoder -> flatten head.

    Maps ``[bs x c_in x seq_len]`` to ``[bs x c_in x pred_dim]``, encoding
    each channel independently.
    """

    def __init__(
        self,
        c_in,
        seq_len,
        pred_dim,
        patch_len,
        stride,
        n_layers=3,
        d_model=128,
        n_heads=16,
        d_k=None,
        d_v=None,
        d_ff=256,
        norm="BatchNorm",
        attn_dropout=0.0,
        dropout=0.0,
        act="gelu",
        res_attention=True,
        pre_norm=False,
        store_attn=False,
        padding_patch=True,
        individual=False,
        revin=True,
        affine=True,
        subtract_last=False,
    ):

        super().__init__()

        # Reversible instance normalization: stats captured on the way in are
        # reused to denormalize the forecast on the way out.
        self.revin = revin
        self.revin_layer = RevIN(c_in, affine=affine, subtract_last=subtract_last)

        # Patching. The trailing "+ 1" accounts for the ReplicationPad1d below,
        # which prepends `stride` replicated samples before unfolding.
        # NOTE(review): `padding_patch` is stored but the padding layer is
        # always applied in forward() regardless of its value — confirm intent.
        self.patch_len = patch_len
        self.stride = stride
        self.padding_patch = padding_patch
        patch_num = int((seq_len - patch_len) / stride + 1) + 1
        self.patch_num = patch_num
        self.padding_patch_layer = nn.ReplicationPad1d((stride, 0))

        self.unfold = nn.Unfold(kernel_size=(1, patch_len), stride=stride)
        self.patch_len = patch_len

        # Transformer encoder over patches.
        # Fix: `norm` was previously not forwarded, so a caller's
        # norm="LayerNorm" was silently ignored (default is unchanged).
        self.backbone = _TSTiEncoder(
            c_in,
            patch_num=patch_num,
            patch_len=patch_len,
            n_layers=n_layers,
            d_model=d_model,
            n_heads=n_heads,
            d_k=d_k,
            d_v=d_v,
            d_ff=d_ff,
            norm=norm,
            attn_dropout=attn_dropout,
            dropout=dropout,
            act=act,
            res_attention=res_attention,
            pre_norm=pre_norm,
            store_attn=store_attn,
        )

        # Prediction head over the flattened (d_model x patch_num) features.
        self.head_nf = d_model * patch_num
        self.n_vars = c_in
        self.individual = individual
        self.head = Flatten_Head(self.individual, self.n_vars, self.head_nf, pred_dim)

    def forward(self, z: Tensor):
        """
        Args:
            z: [bs x c_in x seq_len]

        Returns:
            [bs x c_in x pred_dim]
        """
        # Instance-normalize; statistics are stored inside revin_layer.
        if self.revin:
            z = self.revin_layer(z, torch.tensor(True, dtype=torch.bool))

        # Prepend `stride` replicated values, then cut the series into
        # overlapping patches via Unfold: -> [bs x c_in x patch_len x patch_num]
        z = self.padding_patch_layer(z)
        b, c, s = z.size()
        z = z.reshape(-1, 1, 1, s)
        z = self.unfold(z)
        z = z.permute(0, 2, 1).reshape(b, c, -1, self.patch_len).permute(0, 1, 3, 2)

        z = self.backbone(z)  # -> [bs x c_in x d_model x patch_num]
        z = self.head(z)  # -> [bs x c_in x pred_dim]

        # Undo the instance normalization on the forecast.
        if self.revin:
            z = self.revin_layer(z, torch.tensor(False, dtype=torch.bool))
        return z
|
|
|
|
class PatchTST(nn.Module, PyTorchModelHubMixin):
    """PatchTST model, optionally with trend/residual series decomposition.

    With ``decomposition=True`` the input is split into trend and residual
    components (moving-average decomposition), each modeled by its own
    backbone, and the two forecasts are summed. Otherwise a single backbone
    is used. ``pred_dim`` defaults to ``seq_len``.
    """

    def __init__(
        self,
        c_in,
        c_out,
        seq_len,
        pred_dim=None,
        n_layers=2,
        n_heads=8,
        d_model=512,
        d_ff=2048,
        dropout=0.05,
        attn_dropout=0.0,
        patch_len=16,
        stride=8,
        padding_patch=True,
        revin=True,
        affine=False,
        individual=False,
        subtract_last=False,
        decomposition=False,
        kernel_size=25,
        activation="gelu",
        norm="BatchNorm",
        pre_norm=False,
        res_attention=True,
        store_attn=False,
        classification=False,
    ):

        super().__init__()

        if pred_dim is None:
            pred_dim = seq_len

        # Every backbone (trend/residual or single) uses the same configuration.
        # NOTE(review): `c_out` and `kernel_size` (without decomposition) are
        # accepted but unused here — confirm intent.
        backbone_kwargs = dict(
            c_in=c_in,
            seq_len=seq_len,
            pred_dim=pred_dim,
            patch_len=patch_len,
            stride=stride,
            n_layers=n_layers,
            d_model=d_model,
            n_heads=n_heads,
            d_ff=d_ff,
            norm=norm,
            attn_dropout=attn_dropout,
            dropout=dropout,
            act=activation,
            res_attention=res_attention,
            pre_norm=pre_norm,
            store_attn=store_attn,
            padding_patch=padding_patch,
            individual=individual,
            revin=revin,
            affine=affine,
            subtract_last=subtract_last,
        )

        self.decomposition = decomposition
        if self.decomposition:
            # Two backbones: one for the smoothed trend, one for the residual.
            self.decomp_module = SeriesDecomposition(kernel_size)
            self.model_trend = _PatchTST_backbone(**backbone_kwargs)
            self.model_res = _PatchTST_backbone(**backbone_kwargs)
            self.patch_num = self.model_trend.patch_num
        else:
            self.model = _PatchTST_backbone(**backbone_kwargs)
            self.patch_num = self.model.patch_num
        self.classification = classification

    def forward(self, x):
        """x: [bs x c_in x seq_len] -> [bs x c_in x pred_dim] (squeezed for
        classification)."""
        if self.decomposition:
            residual, trend = self.decomp_module(x)
            x = self.model_res(residual) + self.model_trend(trend)
        else:
            x = self.model(x)

        if self.classification:
            # Drop the variable dim for classification heads.
            x = x.squeeze(-2)
        return x
|
|