File size: 8,444 Bytes
62d1028 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 | from typing import Tuple
import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import functional as F
from einops import rearrange, repeat
from transformers.modeling_utils import PreTrainedModel
from configuration_symtime import SymTimeConfig
from layers import MultiHeadAttention, TSTEncoder, TSTEncoderLayer
class SymTimeModel(PreTrainedModel):
    """
    SymTime time-series encoder packaged as a Hugging Face ``PreTrainedModel``.

    A batch of univariate series is cut into fixed-length patches by
    :meth:`patching`, and the resulting patch sequence is handed to a
    ``TSTEncoder``.

    Parameters
    ----------
    config: SymTimeConfig
        The configuration of the SymTime model.

    Attributes
    ----------
    config: SymTimeConfig
        The configuration of the SymTime model.
    patch_size: int
        Length of each patch (``config.patch_size``).
    stride: int
        Step between the starts of consecutive patches (``config.stride``).
    padding_patch_layer: nn.ReplicationPad1d
        Right-side replication padding of ``stride`` samples.
    encoder: TSTEncoder
        The encoder of the SymTime model.

    Methods
    -------
    patching(time_series: Tensor) -> Tensor:
        Slice a batch of series into patches.
    forward(x: Tensor, return_cls_token: bool = True) -> Tuple[Tensor, Tensor]:
        Forward pass of the SymTime model.
    _init_weights(module: nn.Module) -> None:
        Initialize weights for the SymTime encoder stack.
    """

    config_class = SymTimeConfig

    def __init__(self, config: SymTimeConfig):
        super().__init__(config)
        self.config = config
        self.patch_size = config.patch_size
        self.stride = config.stride

        # Pads only the right edge, repeating the last sample ``stride`` times.
        self.padding_patch_layer = nn.ReplicationPad1d((0, self.stride))

        self.encoder = TSTEncoder(
            patch_size=config.patch_size,
            num_layers=config.num_layers,
            hidden_size=config.d_model,
            num_heads=config.num_heads,
            d_ff=config.d_ff,
            norm=config.norm,
            attn_dropout=config.dropout,
            dropout=config.dropout,
            act=config.act,
            pre_norm=config.pre_norm,
        )

        # Initialize weights and apply final processing
        self.post_init()

    @staticmethod
    def _init_linear(layer: nn.Linear, std: float) -> None:
        """Draw ``layer.weight`` from ``N(0, std**2)`` and zero the bias, if any."""
        nn.init.normal_(layer.weight, mean=0.0, std=std)
        if layer.bias is not None:
            nn.init.zeros_(layer.bias)

    def _init_weights(self, module) -> None:
        """Initialize weights for the SymTime encoder stack.

        The method first defers to ``PreTrainedModel._init_weights`` and then
        applies the SymTime scheme, scaled by ``config.initializer_factor``:

        - ``nn.Linear``: weight ~ N(0, (factor / sqrt(in_features))**2), bias 0
        - ``nn.LayerNorm`` / ``nn.BatchNorm1d``: weight 1, bias 0 (when present)
        - ``TSTEncoder``: ``cls_token`` ~ N(0, factor**2); ``W_P`` as a linear layer
        - ``MultiHeadAttention``: ``W_Q`` / ``W_K`` / ``W_V`` with fan-in
          ``d_model``; ``to_out[0]`` with fan-in ``num_heads * (d_model // num_heads)``
        - ``TSTEncoderLayer``: every ``nn.Linear`` found in ``ff``

        Parameters
        ----------
        module : nn.Module
            The module whose parameters should be initialized.
        """
        super()._init_weights(module)

        factor = self.config.initializer_factor

        if isinstance(module, nn.Linear):
            self._init_linear(module, factor * (module.in_features**-0.5))
        elif isinstance(module, (nn.LayerNorm, nn.BatchNorm1d)):
            # Affine parameters are optional for both norm types, so guard
            # each one instead of assuming it exists.
            if module.weight is not None:
                nn.init.ones_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, TSTEncoder):
            if getattr(module, "cls_token", None) is not None:
                nn.init.normal_(module.cls_token, mean=0.0, std=factor)
            patch_proj = getattr(module, "W_P", None)
            if isinstance(patch_proj, nn.Linear):
                self._init_linear(
                    patch_proj, factor * (patch_proj.in_features**-0.5)
                )
        elif isinstance(module, MultiHeadAttention):
            d_model = self.config.d_model
            num_heads = self.config.num_heads
            qkv_std = factor * (d_model**-0.5)
            for projection in (module.W_Q, module.W_K, module.W_V):
                self._init_linear(projection, qkv_std)
            # The per-head value width is taken to be d_model // num_heads.
            d_v = d_model // num_heads
            self._init_linear(
                module.to_out[0], factor * ((num_heads * d_v) ** -0.5)
            )
        elif isinstance(module, TSTEncoderLayer):
            # NOTE(review): these layers are nn.Linear and therefore also match
            # the first branch when visited on their own; this branch re-draws
            # them with the same formula. Confirm whether that is intended.
            for submodule in module.ff:
                if isinstance(submodule, nn.Linear):
                    self._init_linear(
                        submodule, factor * (submodule.in_features**-0.5)
                    )

    def patching(self, time_series: torch.Tensor) -> torch.Tensor:
        """Split a batch of 1D time series into patches.

        Each patch has length ``self.patch_size`` and consecutive patches start
        ``self.stride`` samples apart. When the sequence length is not a
        multiple of ``self.patch_size``, the series is first extended on the
        right by ``self.stride`` copies of its last value.

        Parameters
        ----------
        time_series : torch.Tensor
            Tensor of shape ``[batch_size, seq_length]``.

        Returns
        -------
        torch.Tensor
            Tensor of shape ``[batch_size, num_patches, patch_size]``.
        """
        # The two-way unpack doubles as a check that the input is 2D.
        _, seq_length = time_series.shape

        # NOTE(review): the padding decision looks at ``patch_size`` while the
        # window step is ``stride``; when the two differ, trailing samples that
        # do not fill a whole window are still dropped by ``unfold``. Left
        # unchanged so existing checkpoints see the same patches.
        if seq_length % self.patch_size != 0:
            time_series = self.padding_patch_layer(time_series)

        # Sliding window over the time axis; the window contents land in a new
        # trailing dimension.
        return time_series.unfold(
            dimension=-1, size=self.patch_size, step=self.stride
        )

    def forward(
        self, x: Tensor, return_cls_token: bool = True
    ) -> Tuple[Tensor, Tensor]:
        """Patch the input series and run it through the encoder.

        Parameters
        ----------
        x : Tensor
            Batched input time series with shape ``[batch_size, seq_length]``.
        return_cls_token : bool, optional
            Forwarded unchanged to the encoder as ``return_cls_token``.

        Returns
        -------
        Tuple[Tensor, Tensor]
            Whatever ``self.encoder`` returns for the patch sequence.
            NOTE(review): presumably (patch representations, CLS output) when
            ``return_cls_token`` is true - confirm against ``TSTEncoder``.

        Raises
        ------
        ValueError
            If ``x`` is not a 2D tensor.
        """
        # Explicit raise rather than ``assert`` so the check survives ``python -O``.
        if x.dim() != 2:
            raise ValueError(
                "Input time series must be a 2D tensor with shape of [batch_size, seq_length]."
            )

        patches = self.patching(x)
        return self.encoder(patches, return_cls_token=return_cls_token)
|