| from typing import Tuple
|
|
|
| import torch
|
| import torch.nn as nn
|
| from torch import Tensor
|
| from torch.nn import functional as F
|
| from einops import rearrange, repeat
|
| from transformers.modeling_utils import PreTrainedModel
|
|
|
| from configuration_symtime import SymTimeConfig
|
| from layers import MultiHeadAttention, TSTEncoder, TSTEncoderLayer
|
|
|
|
|
| class SymTimeModel(PreTrainedModel):
|
| """
|
| SymTime Model for Huggingface.
|
|
|
| Parameters
|
| ----------
|
| config: SymTimeConfig
|
| The configuration of the SymTime model.
|
|
|
| Attributes
|
| ----------
|
| config: SymTimeConfig
|
| The configuration of the SymTime model.
|
| encoder: TSTEncoder
|
| The encoder of the SymTime model.
|
|
|
| Methods
|
| -------
|
| forward(x: Tensor) -> Tuple[Tensor, Tensor]:
|
| Forward pass of the SymTime model.
|
|
|
| _init_weights(module: nn.Module) -> None:
|
| Initialize weights for the SymTime encoder stack.
|
| """
|
|
|
| def __init__(self, config: SymTimeConfig):
|
| super().__init__(config)
|
| self.config = config
|
| self.encoder = TSTEncoder(
|
| patch_size=config.patch_size,
|
| num_layers=config.num_layers,
|
| hidden_size=config.d_model,
|
| num_heads=config.num_heads,
|
| d_ff=config.d_ff,
|
| norm=config.norm,
|
| attn_dropout=config.dropout,
|
| dropout=config.dropout,
|
| act=config.act,
|
| pre_norm=config.pre_norm,
|
| )
|
|
|
|
|
| self.post_init()
|
|
|
| def _init_weights(self, module) -> None:
|
| """Initialize weights for the SymTime encoder stack.
|
|
|
| The model is built on top of Hugging Face `PreTrainedModel`, so this method
|
| is called recursively via `post_init()`. We keep the initialization aligned
|
| with the current backbone structure in `layers.py`:
|
|
|
| - `TSTEncoder.W_P`: patch projection linear layer
|
| - `TSTEncoder.cls_token`: learnable CLS token
|
| - `TSTEncoderLayer.self_attn`: Q/K/V and output projections
|
| - `TSTEncoderLayer.ff`: feed-forward linear layers
|
| - `LayerNorm` / `BatchNorm1d`: normalization layers
|
| """
|
| super()._init_weights(module)
|
|
|
| factor = self.config.initializer_factor
|
| d_model = self.config.d_model
|
| num_heads = self.config.num_heads
|
| d_k = d_model // num_heads
|
| d_v = d_k
|
|
|
| if isinstance(module, nn.Linear):
|
| nn.init.normal_(
|
| module.weight, mean=0.0, std=factor * (module.in_features**-0.5)
|
| )
|
| if module.bias is not None:
|
| nn.init.zeros_(module.bias)
|
|
|
| elif isinstance(module, nn.LayerNorm):
|
| nn.init.ones_(module.weight)
|
| nn.init.zeros_(module.bias)
|
|
|
| elif isinstance(module, nn.BatchNorm1d):
|
| if module.weight is not None:
|
| nn.init.ones_(module.weight)
|
| if module.bias is not None:
|
| nn.init.zeros_(module.bias)
|
|
|
| elif isinstance(module, TSTEncoder):
|
| if hasattr(module, "cls_token") and module.cls_token is not None:
|
| nn.init.normal_(module.cls_token, mean=0.0, std=factor)
|
| if hasattr(module, "W_P") and isinstance(module.W_P, nn.Linear):
|
| nn.init.normal_(
|
| module.W_P.weight,
|
| mean=0.0,
|
| std=factor * (module.W_P.in_features**-0.5),
|
| )
|
| if module.W_P.bias is not None:
|
| nn.init.zeros_(module.W_P.bias)
|
|
|
| elif isinstance(module, MultiHeadAttention):
|
| nn.init.normal_(module.W_Q.weight, mean=0.0, std=factor * (d_model**-0.5))
|
| nn.init.normal_(module.W_K.weight, mean=0.0, std=factor * (d_model**-0.5))
|
| nn.init.normal_(module.W_V.weight, mean=0.0, std=factor * (d_model**-0.5))
|
| if module.W_Q.bias is not None:
|
| nn.init.zeros_(module.W_Q.bias)
|
| if module.W_K.bias is not None:
|
| nn.init.zeros_(module.W_K.bias)
|
| if module.W_V.bias is not None:
|
| nn.init.zeros_(module.W_V.bias)
|
|
|
| out_proj = module.to_out[0]
|
| nn.init.normal_(
|
| out_proj.weight, mean=0.0, std=factor * ((num_heads * d_v) ** -0.5)
|
| )
|
| if out_proj.bias is not None:
|
| nn.init.zeros_(out_proj.bias)
|
|
|
| elif isinstance(module, TSTEncoderLayer):
|
| for submodule in module.ff:
|
| if isinstance(submodule, nn.Linear):
|
| nn.init.normal_(
|
| submodule.weight,
|
| mean=0.0,
|
| std=factor * (submodule.in_features**-0.5),
|
| )
|
| if submodule.bias is not None:
|
| nn.init.zeros_(submodule.bias)
|
|
|
| def forward(
|
| self, x: Tensor, return_cls_token: bool = True
|
| ) -> Tuple[Tensor, Tensor]:
|
| return self.encoder(x, return_cls_token=return_cls_token)
|
|
|