# SymTime / model.py
# Author: whenxuan — "add the patching for time series" (commit 62d1028, verified)
from typing import Tuple
import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import functional as F
from einops import rearrange, repeat
from transformers.modeling_utils import PreTrainedModel
from configuration_symtime import SymTimeConfig
from layers import MultiHeadAttention, TSTEncoder, TSTEncoderLayer
class SymTimeModel(PreTrainedModel):
    """SymTime model wrapper for the Hugging Face ``transformers`` ecosystem.

    The model converts a batch of univariate time series into a sequence of
    local patches (sliding windows of length ``patch_size`` advancing by
    ``stride``) and feeds the patch sequence through a ``TSTEncoder``
    transformer backbone.

    Parameters
    ----------
    config : SymTimeConfig
        The configuration of the SymTime model.

    Attributes
    ----------
    config : SymTimeConfig
        The configuration of the SymTime model.
    encoder : TSTEncoder
        The transformer encoder operating on the patch sequence.

    Methods
    -------
    forward(x: Tensor) -> Tuple[Tensor, Tensor]:
        Forward pass of the SymTime model.
    _init_weights(module: nn.Module) -> None:
        Initialize weights for the SymTime encoder stack.
    """

    config_class = SymTimeConfig

    def __init__(self, config: SymTimeConfig):
        super().__init__(config)
        self.config = config
        self.patch_size = config.patch_size
        self.stride = config.stride
        # Right-pads the series with `stride` replicated boundary values so a
        # final patch can still cover the tail of the signal when needed.
        self.padding_patch_layer = nn.ReplicationPad1d((0, self.stride))
        self.encoder = TSTEncoder(
            patch_size=config.patch_size,
            num_layers=config.num_layers,
            hidden_size=config.d_model,
            num_heads=config.num_heads,
            d_ff=config.d_ff,
            norm=config.norm,
            attn_dropout=config.dropout,
            dropout=config.dropout,
            act=config.act,
            pre_norm=config.pre_norm,
        )
        # Initialize weights and apply final processing (HF convention:
        # post_init() recursively calls _init_weights on every submodule).
        self.post_init()

    def _init_weights(self, module) -> None:
        """Initialize weights for the SymTime encoder stack.

        The model is built on top of Hugging Face `PreTrainedModel`, so this
        method is called recursively via `post_init()`. We keep the
        initialization aligned with the current backbone structure in
        `layers.py`:
        - `TSTEncoder.W_P`: patch projection linear layer
        - `TSTEncoder.cls_token`: learnable CLS token
        - `TSTEncoderLayer.self_attn`: Q/K/V and output projections
        - `TSTEncoderLayer.ff`: feed-forward linear layers
        - `LayerNorm` / `BatchNorm1d`: normalization layers

        Note: `Module.apply` visits children before parents, so the generic
        `nn.Linear` branch runs first and the container-specific branches
        (`MultiHeadAttention`, `TSTEncoderLayer`, `TSTEncoder`) deliberately
        re-initialize their own linears with scheme-specific scales.
        """
        super()._init_weights(module)
        factor = self.config.initializer_factor
        d_model = self.config.d_model
        num_heads = self.config.num_heads
        # Per-head value dimension, used to scale the attention output proj.
        d_v = d_model // num_heads
        if isinstance(module, nn.Linear):
            # Generic fan-in-scaled normal init for any linear layer.
            nn.init.normal_(
                module.weight, mean=0.0, std=factor * (module.in_features**-0.5)
            )
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)
        elif isinstance(module, nn.BatchNorm1d):
            # BatchNorm may be affine-less; guard each parameter separately.
            if module.weight is not None:
                nn.init.ones_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, TSTEncoder):
            if hasattr(module, "cls_token") and module.cls_token is not None:
                nn.init.normal_(module.cls_token, mean=0.0, std=factor)
            if hasattr(module, "W_P") and isinstance(module.W_P, nn.Linear):
                nn.init.normal_(
                    module.W_P.weight,
                    mean=0.0,
                    std=factor * (module.W_P.in_features**-0.5),
                )
                if module.W_P.bias is not None:
                    nn.init.zeros_(module.W_P.bias)
        elif isinstance(module, MultiHeadAttention):
            # Q/K/V projections scaled by the model width.
            nn.init.normal_(module.W_Q.weight, mean=0.0, std=factor * (d_model**-0.5))
            nn.init.normal_(module.W_K.weight, mean=0.0, std=factor * (d_model**-0.5))
            nn.init.normal_(module.W_V.weight, mean=0.0, std=factor * (d_model**-0.5))
            if module.W_Q.bias is not None:
                nn.init.zeros_(module.W_Q.bias)
            if module.W_K.bias is not None:
                nn.init.zeros_(module.W_K.bias)
            if module.W_V.bias is not None:
                nn.init.zeros_(module.W_V.bias)
            # Output projection scaled by the concatenated head width.
            out_proj = module.to_out[0]
            nn.init.normal_(
                out_proj.weight, mean=0.0, std=factor * ((num_heads * d_v) ** -0.5)
            )
            if out_proj.bias is not None:
                nn.init.zeros_(out_proj.bias)
        elif isinstance(module, TSTEncoderLayer):
            # Feed-forward sub-block: fan-in-scaled init for each linear.
            for submodule in module.ff:
                if isinstance(submodule, nn.Linear):
                    nn.init.normal_(
                        submodule.weight,
                        mean=0.0,
                        std=factor * (submodule.in_features**-0.5),
                    )
                    if submodule.bias is not None:
                        nn.init.zeros_(submodule.bias)

    def patching(self, time_series: torch.Tensor) -> torch.Tensor:
        """Split a raw 1D time series into overlapping or non-overlapping patches.

        The encoder does not operate directly on the full sequence. Instead, it
        first converts the input into a sequence of local windows, where each
        window has length ``self.patch_size`` and consecutive windows are
        shifted by ``self.stride``. This patch-based representation reduces the
        temporal resolution while preserving local patterns that are useful for
        attention layers.

        If the sliding window cannot fully cover the sequence, the sequence is
        right-padded with replicated boundary values so that no tail samples
        are discarded.

        Parameters
        ----------
        time_series : torch.Tensor
            Batched input of shape ``[batch_size, seq_length]``.

        Returns
        -------
        torch.Tensor
            Patch tensor of shape ``[batch_size, num_patches, patch_size]``.
        """
        seq_length = time_series.size(-1)
        # `unfold` emits windows starting at multiples of `stride`, so the tail
        # of the signal is fully covered iff (seq_length - patch_size) is a
        # multiple of stride. The earlier check (seq_length % patch_size != 0)
        # both missed uncovered tails (e.g. patch_size=16, stride=12,
        # seq_len=32 silently dropped the last 4 samples) and padded
        # unnecessarily in other cases. Padding by `stride` replicated values
        # is always enough to make the final window reach the last sample.
        # NOTE(review): assumes seq_length >= patch_size — confirm upstream.
        if (seq_length - self.patch_size) % self.stride != 0:
            time_series = self.padding_patch_layer(time_series)
        # Convert the (possibly padded) sequence into a patch tensor using a
        # sliding window along the time dimension; the result is consumed by
        # the transformer encoder.
        time_series = time_series.unfold(
            dimension=-1, size=self.patch_size, step=self.stride
        )
        return time_series

    def forward(
        self, x: Tensor, return_cls_token: bool = True
    ) -> Tuple[Tensor, Tensor]:
        """Run the full SymTime inference pipeline.

        The forward pass expects a 2D tensor of shape
        ``[batch_size, seq_length]`` containing a batch of univariate time
        series. The input is first converted into patch embeddings through
        :meth:`patching`, and the resulting patch sequence is then passed into
        the transformer encoder.

        Parameters
        ----------
        x : Tensor
            Batched input time series with shape ``[batch_size, seq_length]``.
        return_cls_token : bool, optional
            If ``True``, the encoder also returns the learned CLS token output
            alongside the patch-level representations. This is useful when the
            downstream task needs a global sequence summary.

        Returns
        -------
        Tuple[Tensor, Tensor]
            The encoded patch sequence and, optionally, the CLS token output.

        Raises
        ------
        ValueError
            If ``x`` is not a 2D tensor.
        """
        # Validate the batch-by-time layout explicitly: `assert` is stripped
        # under `python -O`, so a real exception is raised instead.
        if x.dim() != 2:
            raise ValueError(
                "Input time series must be a 2D tensor with shape of [batch_size, seq_length]."
            )
        # Convert the raw signal into a patch-based representation, then feed
        # the patch sequence into the transformer encoder.
        time_series = self.patching(x)
        return self.encoder(time_series, return_cls_token=return_cls_token)