File size: 8,444 Bytes
62d1028
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
from typing import Tuple

import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import functional as F
from einops import rearrange, repeat
from transformers.modeling_utils import PreTrainedModel

from configuration_symtime import SymTimeConfig
from layers import MultiHeadAttention, TSTEncoder, TSTEncoderLayer


class SymTimeModel(PreTrainedModel):
    """SymTime model wrapped for the Hugging Face ``PreTrainedModel`` API.

    The model converts a batch of univariate time series into patch
    embeddings and encodes them with a transformer backbone (``TSTEncoder``).

    Parameters
    ----------
    config : SymTimeConfig
        The configuration of the SymTime model.

    Attributes
    ----------
    config : SymTimeConfig
        The configuration of the SymTime model.
    encoder : TSTEncoder
        The transformer encoder of the SymTime model.
    patch_size : int
        Length of each window taken from the raw series.
    stride : int
        Step between the start indices of consecutive windows.
    padding_patch_layer : nn.ReplicationPad1d
        Right-side replication padding of ``stride`` elements, applied when
        the tail of the series would otherwise be dropped by patching.

    Methods
    -------
    forward(x, return_cls_token=True)
        Forward pass of the SymTime model.
    patching(time_series)
        Split a raw series into patches for the encoder.
    _init_weights(module)
        Initialize weights for the SymTime encoder stack.
    """

    config_class = SymTimeConfig

    def __init__(self, config: SymTimeConfig):
        super().__init__(config)
        self.config = config

        self.patch_size = config.patch_size
        self.stride = config.stride

        # Pads `stride` replicated values on the right so that, when applied,
        # the final sliding window in `patching` covers the tail of the signal.
        self.padding_patch_layer = nn.ReplicationPad1d((0, self.stride))
        self.encoder = TSTEncoder(
            patch_size=config.patch_size,
            num_layers=config.num_layers,
            hidden_size=config.d_model,
            num_heads=config.num_heads,
            d_ff=config.d_ff,
            norm=config.norm,
            attn_dropout=config.dropout,
            dropout=config.dropout,
            act=config.act,
            pre_norm=config.pre_norm,
        )

        # Initialize weights and apply final processing. This HF hook calls
        # `_init_weights` recursively on every submodule.
        self.post_init()

    def _init_weights(self, module) -> None:
        """Initialize weights for the SymTime encoder stack.

        The model is built on top of Hugging Face ``PreTrainedModel``, so this
        method is called recursively via ``post_init()``. We keep the
        initialization aligned with the current backbone structure in
        ``layers.py``:

        - ``TSTEncoder.W_P``: patch projection linear layer
        - ``TSTEncoder.cls_token``: learnable CLS token
        - ``TSTEncoderLayer.self_attn``: Q/K/V and output projections
        - ``TSTEncoderLayer.ff``: feed-forward linear layers
        - ``LayerNorm`` / ``BatchNorm1d``: normalization layers
        """
        super()._init_weights(module)

        factor = self.config.initializer_factor
        d_model = self.config.d_model
        num_heads = self.config.num_heads
        d_k = d_model // num_heads
        d_v = d_k  # value head dim mirrors the key head dim

        if isinstance(module, nn.Linear):
            # Generic fan-in scaled normal init for linear layers.
            nn.init.normal_(
                module.weight, mean=0.0, std=factor * (module.in_features**-0.5)
            )
            if module.bias is not None:
                nn.init.zeros_(module.bias)

        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)

        elif isinstance(module, nn.BatchNorm1d):
            # BatchNorm may be created with affine=False; guard both params.
            if module.weight is not None:
                nn.init.ones_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

        elif isinstance(module, TSTEncoder):
            # CLS token and patch projection get their own treatment so the
            # scale stays consistent even if they are not plain nn.Linear.
            if hasattr(module, "cls_token") and module.cls_token is not None:
                nn.init.normal_(module.cls_token, mean=0.0, std=factor)
            if hasattr(module, "W_P") and isinstance(module.W_P, nn.Linear):
                nn.init.normal_(
                    module.W_P.weight,
                    mean=0.0,
                    std=factor * (module.W_P.in_features**-0.5),
                )
                if module.W_P.bias is not None:
                    nn.init.zeros_(module.W_P.bias)

        elif isinstance(module, MultiHeadAttention):
            # Q/K/V projections scale with the model width.
            nn.init.normal_(module.W_Q.weight, mean=0.0, std=factor * (d_model**-0.5))
            nn.init.normal_(module.W_K.weight, mean=0.0, std=factor * (d_model**-0.5))
            nn.init.normal_(module.W_V.weight, mean=0.0, std=factor * (d_model**-0.5))
            if module.W_Q.bias is not None:
                nn.init.zeros_(module.W_Q.bias)
            if module.W_K.bias is not None:
                nn.init.zeros_(module.W_K.bias)
            if module.W_V.bias is not None:
                nn.init.zeros_(module.W_V.bias)

            # Output projection scales with the concatenated head width.
            out_proj = module.to_out[0]
            nn.init.normal_(
                out_proj.weight, mean=0.0, std=factor * ((num_heads * d_v) ** -0.5)
            )
            if out_proj.bias is not None:
                nn.init.zeros_(out_proj.bias)

        elif isinstance(module, TSTEncoderLayer):
            # Feed-forward sub-block: init only its linear layers; dropout and
            # activation modules in the Sequential have no parameters.
            for submodule in module.ff:
                if isinstance(submodule, nn.Linear):
                    nn.init.normal_(
                        submodule.weight,
                        mean=0.0,
                        std=factor * (submodule.in_features**-0.5),
                    )
                    if submodule.bias is not None:
                        nn.init.zeros_(submodule.bias)

    def patching(self, time_series: torch.Tensor) -> torch.Tensor:
        """Split a raw 1D time series into overlapping or non-overlapping patches.

        The encoder does not operate directly on the full sequence. Instead,
        the input is converted into a sequence of local windows of length
        ``self.patch_size`` whose start indices are ``self.stride`` apart.
        This patch-based representation reduces the temporal resolution while
        preserving local patterns that are useful for attention layers.

        If the sliding windows would not cover the end of the sequence, the
        sequence is first extended on the right with ``stride`` replicated
        boundary values so that no tail data is discarded.

        Parameters
        ----------
        time_series : torch.Tensor
            Batched input with shape ``[batch_size, seq_length]``. The method
            assumes ``seq_length + stride >= patch_size`` so at least one
            patch exists after padding.

        Returns
        -------
        torch.Tensor
            Patch tensor of shape ``[batch_size, num_patches, patch_size]``.
        """
        _, seq_length = time_series.shape

        # `unfold(size=patch_size, step=stride)` silently drops the tail of
        # the sequence whenever (seq_length - patch_size) is not a multiple
        # of stride; pad with replicated boundary values in that case so the
        # last patch still reaches the end of the signal.
        # BUGFIX: the previous check (`seq_length % self.patch_size != 0`)
        # is only equivalent when stride == patch_size. With overlapping
        # patches it could both drop tail data (e.g. patch_size=16, stride=6,
        # seq_length=32) and pad when no padding was needed.
        if (seq_length - self.patch_size) % self.stride != 0:
            time_series = self.padding_patch_layer(time_series)

        # Sliding-window view along the last dimension; each slice is one
        # patch consumed by the transformer encoder.
        time_series = time_series.unfold(
            dimension=-1, size=self.patch_size, step=self.stride
        )

        return time_series

    def forward(
        self, x: Tensor, return_cls_token: bool = True
    ) -> Tuple[Tensor, Tensor]:
        """Run the full SymTime inference pipeline.

        The forward pass expects a 2D tensor of shape
        ``[batch_size, seq_length]`` containing a batch of univariate time
        series. The input is first converted into patches through
        :meth:`patching`, and the resulting patch sequence is then passed to
        the transformer encoder.

        Parameters
        ----------
        x : Tensor
            Batched input time series with shape ``[batch_size, seq_length]``.
        return_cls_token : bool, optional
            Forwarded to the encoder. If ``True``, the encoder also returns
            the learned CLS token output alongside the patch-level
            representations, which is useful when the downstream task needs a
            global sequence summary.

        Returns
        -------
        Tuple[Tensor, Tensor]
            The encoded patch sequence and, optionally, the CLS token output.
            NOTE(review): the exact return arity is determined by
            ``TSTEncoder.forward`` and presumably depends on
            ``return_cls_token`` — confirm against ``layers.py``.
        """
        # Validate that the input follows the expected batch-by-time layout.
        assert (
            x.dim() == 2
        ), "Input time series must be a 2D tensor with shape of [batch_size, seq_length]."

        # Convert the raw signal into a patch-based representation.
        time_series = self.patching(x)

        # The encoder handles patch projection, CLS token handling, and the
        # transformer stack; its output is returned unchanged.
        return self.encoder(time_series, return_cls_token=return_cls_token)