# PatchTST model implementation (Hugging Face Hub upload by enver1323,
# commit ec72f1e — "feat: update model").
from typing import Optional
import torch
from torch import nn, Tensor
import torch.nn.functional as F
from huggingface_hub import PyTorchModelHubMixin
class Transpose(nn.Module):
    """Module wrapper around ``Tensor.transpose`` so a dimension swap can be
    placed inside an ``nn.Sequential``.

    Args:
        *dims: the two dimensions handed to ``Tensor.transpose``.
        contiguous: when True, ``.contiguous()`` is called on the result.
    """

    def __init__(self, *dims, contiguous=False):
        super().__init__()
        self.dims = dims
        self.contiguous = contiguous

    def forward(self, x):
        out = x.transpose(*self.dims)
        return out.contiguous() if self.contiguous else out

    def __repr__(self):
        joined = ", ".join(str(d) for d in self.dims)
        if self.contiguous:
            return f"{self.__class__.__name__}(dims={joined}).contiguous()"
        return f"{self.__class__.__name__}({joined})"
# Activation classes that can be selected by (lower-cased) name in get_act_fn.
pytorch_acts = [
    nn.ELU,
    nn.LeakyReLU,
    nn.PReLU,
    nn.ReLU,
    nn.ReLU6,
    nn.SELU,
    nn.CELU,
    nn.GELU,
    nn.Sigmoid,
    nn.Softplus,
    nn.Tanh,
    nn.Softmax,
]
# Lower-cased lookup names ("elu", "relu", ...); order matches pytorch_acts.
pytorch_act_names = [a.__name__.lower() for a in pytorch_acts]


def get_act_fn(act, **act_kwargs):
    """Resolve an activation specification into an activation module.

    Args:
        act: one of
            - ``None`` (returns ``None``),
            - an ``nn.Module`` instance (returned unchanged),
            - a callable such as an activation class (instantiated with
              ``act_kwargs``),
            - a case-insensitive name from ``pytorch_act_names``.
        **act_kwargs: keyword arguments forwarded to the activation
            constructor when one is built.

    Returns:
        The resolved activation module/callable, or ``None``.

    Raises:
        ValueError: if ``act`` is a string not found in ``pytorch_act_names``.
    """
    if act is None:
        return None
    if isinstance(act, nn.Module):
        return act
    if callable(act):
        return act(**act_kwargs)
    try:
        idx = pytorch_act_names.index(act.lower())
    except ValueError:
        # list.index would raise "x not in list"; give a helpful message.
        raise ValueError(
            f"Unknown activation {act!r}; available: {pytorch_act_names}"
        ) from None
    return pytorch_acts[idx](**act_kwargs)
class RevIN(nn.Module):
    """Reversible instance normalization over the last (time) dimension.

    ``normalize`` caches the per-instance location/scale statistics on the
    module; ``denormalize`` reuses them, so a normalize call must precede the
    matching denormalize call.

    Args:
        c_in: number of input channels.
        affine: learn a per-channel scale (``weight``) and shift (``bias``).
        subtract_last: subtract the last time step instead of the mean.
        dim: stored but unused here — presumably the channel dim; TODO confirm.
        eps: added to the std for numerical stability.
    """

    def __init__(
        self,
        c_in: int,
        affine: bool = True,
        subtract_last: bool = False,
        dim: int = 2,
        eps: float = 1e-5,
    ):
        super().__init__()
        self.c_in = c_in
        self.affine = affine
        self.subtract_last = subtract_last
        self.dim = dim
        self.eps = eps
        if affine:
            # Per-channel learnable scale/shift, broadcast over batch and time.
            self.weight = nn.Parameter(torch.ones(1, c_in, 1))
            self.bias = nn.Parameter(torch.zeros(1, c_in, 1))

    def forward(self, x: Tensor, mode: Tensor):
        # Truthy mode -> normalize, falsy -> denormalize.
        return self.normalize(x) if mode else self.denormalize(x)

    def normalize(self, x):
        # Cache the statistics for the later denormalize call.
        if self.subtract_last:
            self.sub = x[..., -1].unsqueeze(-1).detach()
        else:
            self.sub = torch.mean(x, dim=-1, keepdim=True).detach()
        self.std = (
            torch.std(x, dim=-1, keepdim=True, unbiased=False).detach() + self.eps
        )
        x = x.sub(self.sub).div(self.std)
        if self.affine:
            x = x.mul(self.weight).add(self.bias)
        return x

    def denormalize(self, x):
        # Exact inverse of normalize, using the cached statistics.
        if self.affine:
            x = x.sub(self.bias).div(self.weight)
        return x.mul(self.std).add(self.sub)
class MovingAverage(nn.Module):
    """Sliding-window mean along the last dimension.

    Replication padding is split (asymmetrically for even kernels) so the
    output length equals the input length.
    """

    def __init__(
        self,
        kernel_size: int,
    ):
        super().__init__()
        # left + right padding == kernel_size - 1 keeps the length unchanged.
        left = (kernel_size - 1) // 2
        right = kernel_size - 1 - left
        self.padding = torch.nn.ReplicationPad1d((left, right))
        self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=1)

    def forward(self, x: Tensor):
        padded = self.padding(x)
        return self.avg(padded)
class SeriesDecomposition(nn.Module):
    """Split a series into (residual, trend), trend being a moving average."""

    def __init__(
        self,
        kernel_size: int,  # the moving-average window size
    ):
        super().__init__()
        self.moving_avg = MovingAverage(kernel_size)

    def forward(self, x: Tensor):
        trend = self.moving_avg(x)
        seasonal = x - trend
        return seasonal, trend
class _ScaledDotProductAttention(nn.Module):
def __init__(self, d_model, n_heads, attn_dropout=0.0, res_attention=False):
super().__init__()
self.attn_dropout = nn.Dropout(attn_dropout)
self.res_attention = res_attention
head_dim = d_model // n_heads
self.scale = nn.Parameter(torch.tensor(head_dim**-0.5), requires_grad=False)
def forward(self, q: Tensor, k: Tensor, v: Tensor, prev: Optional[Tensor] = None):
attn_scores = torch.matmul(q, k) * self.scale
if prev is not None:
attn_scores = attn_scores + prev
attn_weights = F.softmax(attn_scores, dim=-1)
attn_weights = self.attn_dropout(attn_weights)
output = torch.matmul(attn_weights, v)
if self.res_attention:
return output, attn_weights, attn_scores
else:
return output, attn_weights
class _MultiheadAttention(nn.Module):
    """Multi-head attention with optional residual attention (``prev`` score
    pass-through between layers).

    ``forward(Q, K=None, V=None, prev=None)``: K and V default to Q
    (self-attention). Returns ``(output, attn_weights)`` or, when
    ``res_attention`` is set, ``(output, attn_weights, attn_scores)``.
    """

    def __init__(
        self,
        d_model,
        n_heads,
        d_k=None,
        d_v=None,
        res_attention=False,
        attn_dropout=0.0,
        proj_dropout=0.0,
        qkv_bias=True,
    ):
        "Multi Head Attention Layer"
        super().__init__()
        # Honor caller-supplied head dims; default to an even split of d_model.
        # (The previous code unconditionally overwrote d_k/d_v, silently
        # ignoring the arguments.)
        d_k = d_model // n_heads if d_k is None else d_k
        d_v = d_model // n_heads if d_v is None else d_v
        self.n_heads, self.d_k, self.d_v = n_heads, d_k, d_v

        # Fused per-head projections.
        self.W_Q = nn.Linear(d_model, d_k * n_heads, bias=qkv_bias)
        self.W_K = nn.Linear(d_model, d_k * n_heads, bias=qkv_bias)
        self.W_V = nn.Linear(d_model, d_v * n_heads, bias=qkv_bias)

        # Scaled Dot-Product Attention (multiple heads)
        self.res_attention = res_attention
        self.sdp_attn = _ScaledDotProductAttention(
            d_model,
            n_heads,
            attn_dropout=attn_dropout,
            res_attention=self.res_attention,
        )

        # Project concatenated heads back to d_model.
        self.to_out = nn.Sequential(
            nn.Linear(n_heads * d_v, d_model), nn.Dropout(proj_dropout)
        )

    def forward(
        self,
        Q: Tensor,
        K: Optional[Tensor] = None,
        V: Optional[Tensor] = None,
        prev: Optional[Tensor] = None,
    ):
        """Run multi-head attention; K/V default to Q (self-attention)."""
        bs = Q.size(0)
        if K is None:
            K = Q
        if V is None:
            V = Q

        # Linear projections, split into heads.
        q_s = (
            self.W_Q(Q).view(bs, -1, self.n_heads, self.d_k).transpose(1, 2)
        )  # q_s: [bs x n_heads x max_q_len x d_k]
        k_s = (
            self.W_K(K).view(bs, -1, self.n_heads, self.d_k).permute(0, 2, 3, 1)
        )  # k_s: [bs x n_heads x d_k x q_len] - pre-transposed for the matmul
        v_s = (
            self.W_V(V).view(bs, -1, self.n_heads, self.d_v).transpose(1, 2)
        )  # v_s: [bs x n_heads x q_len x d_v]

        # Apply Scaled Dot-Product Attention (multiple heads)
        if self.res_attention:
            output, attn_weights, attn_scores = self.sdp_attn(q_s, k_s, v_s, prev=prev)
        else:
            output, attn_weights = self.sdp_attn(q_s, k_s, v_s)
        # output: [bs x n_heads x q_len x d_v], attn: [bs x n_heads x q_len x q_len]

        # Merge heads back and project to the model dimension.
        output = (
            output.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * self.d_v)
        )  # output: [bs x q_len x n_heads * d_v]
        output = self.to_out(output)

        if self.res_attention:
            return output, attn_weights, attn_scores
        return output, attn_weights
class Flatten_Head(nn.Module):
    """Prediction head: flatten the trailing (d_model, n_patch) dims and
    project to ``pred_dim``.

    With ``individual`` a separate linear layer is used per variable,
    otherwise one layer is shared across variables. When ``pred_dim`` is a
    tuple/list its last entry is used.
    """

    def __init__(self, individual, n_vars, nf, pred_dim):
        super().__init__()
        if isinstance(pred_dim, (tuple, list)):
            pred_dim = pred_dim[-1]
        self.individual = individual
        self.n = n_vars if individual else 1
        self.nf, self.pred_dim = nf, pred_dim

        def _make_layer():
            return nn.Sequential(nn.Flatten(start_dim=-2), nn.Linear(nf, pred_dim))

        if individual:
            self.layers = nn.ModuleList(_make_layer() for _ in range(self.n))
        else:
            self.layer = _make_layer()

    def forward(self, x: Tensor):
        """
        Args:
            x: [bs x nvars x d_model x n_patch]
            output: [bs x nvars x pred_dim]
        """
        if not self.individual:
            return self.layer(x)
        per_var = [layer(x[:, i]) for i, layer in enumerate(self.layers)]
        return torch.stack(per_var, dim=1)
class _TSTiEncoderLayer(nn.Module):
    """One transformer encoder layer: self-attention sublayer followed by a
    position-wise feed-forward sublayer, each with residual connection,
    dropout and BatchNorm/LayerNorm, in pre- or post-norm arrangement."""

    def __init__(
        self,
        q_len,
        d_model,
        n_heads,
        d_k=None,
        d_v=None,
        d_ff=256,
        store_attn=False,
        norm="BatchNorm",
        attn_dropout=0,
        dropout=0.0,
        bias=True,
        activation="gelu",
        res_attention=False,
        pre_norm=False,
    ):
        super().__init__()
        assert (
            not d_model % n_heads
        ), f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
        if d_k is None:
            d_k = d_model // n_heads
        if d_v is None:
            d_v = d_model // n_heads

        def _make_norm():
            # BatchNorm1d normalizes over [bs x d_model x q_len]; wrap with
            # transposes so the layer sees channels in the right position.
            if "batch" in norm.lower():
                return nn.Sequential(
                    Transpose(1, 2), nn.BatchNorm1d(d_model), Transpose(1, 2)
                )
            return nn.LayerNorm(d_model)

        # Multi-Head attention sublayer
        self.res_attention = res_attention
        self.self_attn = _MultiheadAttention(
            d_model,
            n_heads,
            d_k,
            d_v,
            attn_dropout=attn_dropout,
            proj_dropout=dropout,
            res_attention=res_attention,
        )
        self.dropout_attn = nn.Dropout(dropout)
        self.norm_attn = _make_norm()

        # Position-wise feed-forward sublayer
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff, bias=bias),
            get_act_fn(activation),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model, bias=bias),
        )
        self.dropout_ffn = nn.Dropout(dropout)
        self.norm_ffn = _make_norm()

        self.pre_norm = pre_norm
        self.store_attn = store_attn

    def forward(self, src: Tensor, prev: Optional[Tensor] = None):
        """
        Args:
            src: [bs x q_len x d_model]
            prev: previous layer's attention scores (residual attention), or None.
        """
        # --- Multi-head attention sublayer ---
        if self.pre_norm:
            src = self.norm_attn(src)
        if self.res_attention:
            src2, attn, scores = self.self_attn(src, src, src, prev)
        else:
            src2, attn = self.self_attn(src, src, src)
        if self.store_attn:
            self.attn = attn  # stash for inspection
        src = src + self.dropout_attn(src2)  # residual connection with dropout
        if not self.pre_norm:
            src = self.norm_attn(src)

        # --- Feed-forward sublayer ---
        if self.pre_norm:
            src = self.norm_ffn(src)
        src2 = self.ff(src)
        src = src + self.dropout_ffn(src2)  # residual connection with dropout
        if not self.pre_norm:
            src = self.norm_ffn(src)

        if self.res_attention:
            return src, scores
        return src
class _TSTiEncoder(nn.Module):  # i means channel-independent
    """Channel-independent patch encoder.

    Projects each patch to d_model, adds a learned positional encoding, and
    runs a stack of encoder layers with the channel dimension folded into the
    batch dimension (so every channel is encoded independently).
    """

    def __init__(
        self,
        c_in,
        patch_num,
        patch_len,
        n_layers=3,
        d_model=128,
        n_heads=16,
        d_k=None,
        d_v=None,
        d_ff=256,
        norm="BatchNorm",
        attn_dropout=0.0,
        dropout=0.0,
        act="gelu",
        store_attn=False,
        res_attention=True,
        pre_norm=False,
    ):
        super().__init__()
        self.patch_num = patch_num
        self.patch_len = patch_len
        q_len = patch_num
        self.seq_len = q_len

        # Eq 1: projection of patch vectors onto a d_model-dim vector space.
        self.W_P = nn.Linear(patch_len, d_model)

        # Learned positional encoding, uniform init in [-0.02, 0.02].
        W_pos = torch.empty((q_len, d_model))
        nn.init.uniform_(W_pos, -0.02, 0.02)
        self.W_pos = nn.Parameter(W_pos)

        # Residual dropout applied after adding the positional encoding.
        self.dropout = nn.Dropout(dropout)

        # Encoder stack.
        self.layers = nn.ModuleList(
            _TSTiEncoderLayer(
                q_len,
                d_model,
                n_heads=n_heads,
                d_k=d_k,
                d_v=d_v,
                d_ff=d_ff,
                norm=norm,
                attn_dropout=attn_dropout,
                dropout=dropout,
                activation=act,
                res_attention=res_attention,
                pre_norm=pre_norm,
                store_attn=store_attn,
            )
            for _ in range(n_layers)
        )
        self.res_attention = res_attention

    def forward(self, x: Tensor):
        """
        Args:
            x: [bs x nvars x patch_len x patch_num]
        Returns:
            [bs x nvars x d_model x patch_num]
        """
        n_vars = x.shape[1]

        # [bs x nvars x patch_num x patch_len] -> project patches to d_model.
        x = self.W_P(x.permute(0, 1, 3, 2))
        # Fold channels into the batch: [bs * nvars x patch_num x d_model].
        x = x.reshape(x.shape[0] * x.shape[1], x.shape[2], x.shape[3])
        x = self.dropout(x + self.W_pos)

        # Encoder (scores threaded through layers when res_attention is on).
        if self.res_attention:
            scores = None
            for layer in self.layers:
                x, scores = layer(x, prev=scores)
        else:
            for layer in self.layers:
                x = layer(x)

        # Unfold channels: [bs x nvars x patch_num x d_model] -> swap last two.
        x = x.reshape(-1, n_vars, x.shape[-2], x.shape[-1])
        return x.permute(0, 1, 3, 2)
class _PatchTST_backbone(nn.Module):
    """PatchTST backbone: (optional) RevIN -> patch extraction -> channel-
    independent TST encoder -> flatten head -> (optional) RevIN inverse.

    Fix: ``padding_patch`` was stored but ignored — the extra end padding
    (and its extra patch) was always applied. It is now honored; the default
    ``padding_patch=True`` preserves the previous behavior exactly.
    """

    def __init__(
        self,
        c_in,
        seq_len,
        pred_dim,
        patch_len,
        stride,
        n_layers=3,
        d_model=128,
        n_heads=16,
        d_k=None,
        d_v=None,
        d_ff=256,
        norm="BatchNorm",
        attn_dropout=0.0,
        dropout=0.0,
        act="gelu",
        res_attention=True,
        pre_norm=False,
        store_attn=False,
        padding_patch=True,
        individual=False,
        revin=True,
        affine=False,
        subtract_last=False,
    ):
        super().__init__()
        # Reversible instance normalization (applied only when revin is True).
        self.revin = revin
        self.revin_layer = RevIN(c_in, affine=affine, subtract_last=subtract_last)

        # Patching
        self.patch_len = patch_len
        self.stride = stride
        self.padding_patch = padding_patch
        patch_num = int((seq_len - patch_len) / stride + 1)
        if padding_patch:
            # Replication-padding by one stride adds one extra patch.
            patch_num += 1
        self.patch_num = patch_num
        self.padding_patch_layer = nn.ReplicationPad1d((stride, 0))
        self.unfold = nn.Unfold(kernel_size=(1, patch_len), stride=stride)

        # Channel-independent transformer encoder over patches.
        self.backbone = _TSTiEncoder(
            c_in,
            patch_num=patch_num,
            patch_len=patch_len,
            n_layers=n_layers,
            d_model=d_model,
            n_heads=n_heads,
            d_k=d_k,
            d_v=d_v,
            d_ff=d_ff,
            attn_dropout=attn_dropout,
            dropout=dropout,
            act=act,
            res_attention=res_attention,
            pre_norm=pre_norm,
            store_attn=store_attn,
        )

        # Head: flatten (d_model x patch_num) per variable and project.
        self.head_nf = d_model * patch_num
        self.n_vars = c_in
        self.individual = individual
        self.head = Flatten_Head(self.individual, self.n_vars, self.head_nf, pred_dim)

    def forward(self, z: Tensor):
        """
        Args:
            z: [bs x c_in x seq_len]
        Returns:
            [bs x c_in x pred_dim]
        """
        if self.revin:
            z = self.revin_layer(z, torch.tensor(True, dtype=torch.bool))
        if self.padding_patch:
            z = self.padding_patch_layer(z)
        b, c, s = z.size()
        # Extract sliding patches via Unfold: treat each channel row as a
        # 1 x s "image", then regroup to [bs x c_in x patch_len x patch_num].
        z = z.reshape(-1, 1, 1, s)
        z = self.unfold(z)
        z = z.permute(0, 2, 1).reshape(b, c, -1, self.patch_len).permute(0, 1, 3, 2)
        z = self.backbone(z)  # [bs x c_in x d_model x patch_num]
        z = self.head(z)  # [bs x c_in x pred_dim]
        if self.revin:
            z = self.revin_layer(z, torch.tensor(False, dtype=torch.bool))
        return z
class PatchTST(nn.Module, PyTorchModelHubMixin):
    """PatchTST time-series model, optionally with trend/residual series
    decomposition (two parallel backbones whose outputs are summed).

    ``pred_dim`` defaults to ``seq_len``. With ``classification`` the
    second-to-last output dimension is squeezed.
    NOTE(review): ``c_out`` is accepted but never used in this module —
    kept for interface compatibility; confirm against checkpoint configs.
    """

    def __init__(
        self,
        c_in,
        c_out,
        seq_len,
        pred_dim=None,
        n_layers=2,
        n_heads=8,
        d_model=512,
        d_ff=2048,
        dropout=0.05,
        attn_dropout=0.0,
        patch_len=16,
        stride=8,
        padding_patch=True,
        revin=True,
        affine=False,
        individual=False,
        subtract_last=False,
        decomposition=False,
        kernel_size=25,
        activation="gelu",
        norm="BatchNorm",
        pre_norm=False,
        res_attention=True,
        store_attn=False,
        classification=False,
    ):
        super().__init__()
        if pred_dim is None:
            pred_dim = seq_len
        # Single source of truth for backbone construction — previously this
        # keyword list was written out three times, inviting drift.
        backbone_kwargs = dict(
            c_in=c_in,
            seq_len=seq_len,
            pred_dim=pred_dim,
            patch_len=patch_len,
            stride=stride,
            n_layers=n_layers,
            d_model=d_model,
            n_heads=n_heads,
            d_ff=d_ff,
            norm=norm,
            attn_dropout=attn_dropout,
            dropout=dropout,
            act=activation,
            res_attention=res_attention,
            pre_norm=pre_norm,
            store_attn=store_attn,
            padding_patch=padding_patch,
            individual=individual,
            revin=revin,
            affine=affine,
            subtract_last=subtract_last,
        )
        self.decomposition = decomposition
        if decomposition:
            # Two backbones: one for the trend, one for the residual part.
            self.decomp_module = SeriesDecomposition(kernel_size)
            self.model_trend = _PatchTST_backbone(**backbone_kwargs)
            self.model_res = _PatchTST_backbone(**backbone_kwargs)
            self.patch_num = self.model_trend.patch_num
        else:
            self.model = _PatchTST_backbone(**backbone_kwargs)
            self.patch_num = self.model.patch_num
        self.classification = classification

    def forward(self, x):
        """Forward pass; x: [bs x c_in x seq_len]."""
        if self.decomposition:
            res_init, trend_init = self.decomp_module(x)
            x = self.model_res(res_init) + self.model_trend(trend_init)
        else:
            x = self.model(x)
        if self.classification:
            x = x.squeeze(-2)
        return x