# SymTime / model.py
# Hugging Face Hub page residue preserved as comments so the module parses:
#   uploader: whenxuan ("whenxuan's picture", commit message "init")
#   commit:   233512e (verified)
#   page links: raw / history / blame
#   file size: 5.16 kB
from typing import Tuple
import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import functional as F
from einops import rearrange, repeat
from transformers.modeling_utils import PreTrainedModel
from configuration_symtime import SymTimeConfig
from layers import MultiHeadAttention, TSTEncoder, TSTEncoderLayer
class SymTimeModel(PreTrainedModel):
    """
    SymTime Model for Hugging Face.

    A thin ``PreTrainedModel`` wrapper around a patch-based Transformer
    time-series encoder (``TSTEncoder`` from ``layers.py``). The wrapper
    contributes configuration handling, custom weight initialization, and a
    ``forward`` that simply delegates to the encoder.

    Parameters
    ----------
    config: SymTimeConfig
        The configuration of the SymTime model.

    Attributes
    ----------
    config: SymTimeConfig
        The configuration of the SymTime model.
    encoder: TSTEncoder
        The encoder of the SymTime model.

    Methods
    -------
    forward(x: Tensor) -> Tuple[Tensor, Tensor]:
        Forward pass of the SymTime model.
    _init_weights(module: nn.Module) -> None:
        Initialize weights for the SymTime encoder stack.
    """
    def __init__(self, config: SymTimeConfig):
        super().__init__(config)
        self.config = config
        # Backbone: patch-projection + stacked Transformer layers (layers.py).
        self.encoder = TSTEncoder(
            patch_size=config.patch_size,
            num_layers=config.num_layers,
            hidden_size=config.d_model,
            num_heads=config.num_heads,
            d_ff=config.d_ff,
            norm=config.norm,
            # NOTE(review): attention dropout deliberately(?) reuses the same
            # rate as the residual dropout — there is no separate
            # config.attn_dropout read here; confirm against SymTimeConfig.
            attn_dropout=config.dropout,
            dropout=config.dropout,
            act=config.act,
            pre_norm=config.pre_norm,
        )
        # Initialize weights and apply final processing
        # (post_init() calls _init_weights recursively on every submodule).
        self.post_init()
    def _init_weights(self, module: nn.Module) -> None:
        """Initialize weights for the SymTime encoder stack.

        The model is built on top of Hugging Face `PreTrainedModel`, so this method
        is called recursively via `post_init()`. We keep the initialization aligned
        with the current backbone structure in `layers.py`:

        - `TSTEncoder.W_P`: patch projection linear layer
        - `TSTEncoder.cls_token`: learnable CLS token
        - `TSTEncoderLayer.self_attn`: Q/K/V and output projections
        - `TSTEncoderLayer.ff`: feed-forward linear layers
        - `LayerNorm` / `BatchNorm1d`: normalization layers

        NOTE(review): because the recursion visits every submodule, the linear
        layers inside `MultiHeadAttention`, `TSTEncoder` and `TSTEncoderLayer`
        are initialized twice — once by the generic ``nn.Linear`` branch when
        the child itself is visited, and again by the container-specific
        branch. The std is the same in both passes (factor / sqrt(fan_in)),
        so the second draw merely overwrites the first.
        """
        # Let the HF base class run its (default no-op) initialization first.
        super()._init_weights(module)
        factor = self.config.initializer_factor
        d_model = self.config.d_model
        num_heads = self.config.num_heads
        # Per-head key/value dimensions; assumes d_model is divisible by
        # num_heads (integer division would silently truncate otherwise).
        d_k = d_model // num_heads
        d_v = d_k
        if isinstance(module, nn.Linear):
            # Scaled-normal init: std = factor / sqrt(fan_in).
            nn.init.normal_(
                module.weight, mean=0.0, std=factor * (module.in_features**-0.5)
            )
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            # Identity-like start: unit scale, zero shift.
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)
        elif isinstance(module, nn.BatchNorm1d):
            # BatchNorm affine params are optional (affine=False → None).
            if module.weight is not None:
                nn.init.ones_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, TSTEncoder):
            # hasattr guards keep this robust if the backbone in layers.py
            # drops the CLS token or patch projection in a future revision.
            if hasattr(module, "cls_token") and module.cls_token is not None:
                nn.init.normal_(module.cls_token, mean=0.0, std=factor)
            if hasattr(module, "W_P") and isinstance(module.W_P, nn.Linear):
                nn.init.normal_(
                    module.W_P.weight,
                    mean=0.0,
                    std=factor * (module.W_P.in_features**-0.5),
                )
                if module.W_P.bias is not None:
                    nn.init.zeros_(module.W_P.bias)
        elif isinstance(module, MultiHeadAttention):
            # Q/K/V projections take d_model features in, hence the
            # d_model**-0.5 scaling.
            nn.init.normal_(module.W_Q.weight, mean=0.0, std=factor * (d_model**-0.5))
            nn.init.normal_(module.W_K.weight, mean=0.0, std=factor * (d_model**-0.5))
            nn.init.normal_(module.W_V.weight, mean=0.0, std=factor * (d_model**-0.5))
            if module.W_Q.bias is not None:
                nn.init.zeros_(module.W_Q.bias)
            if module.W_K.bias is not None:
                nn.init.zeros_(module.W_K.bias)
            if module.W_V.bias is not None:
                nn.init.zeros_(module.W_V.bias)
            # Output projection: first element of the to_out container is
            # assumed to be the Linear layer (layers.py layout — confirm if
            # MultiHeadAttention changes).
            out_proj = module.to_out[0]
            # Its fan-in is the concatenated head dimension num_heads * d_v.
            nn.init.normal_(
                out_proj.weight, mean=0.0, std=factor * ((num_heads * d_v) ** -0.5)
            )
            if out_proj.bias is not None:
                nn.init.zeros_(out_proj.bias)
        elif isinstance(module, TSTEncoderLayer):
            # Re-initialize only the Linear layers inside the feed-forward
            # container; activations/dropout submodules are skipped.
            for submodule in module.ff:
                if isinstance(submodule, nn.Linear):
                    nn.init.normal_(
                        submodule.weight,
                        mean=0.0,
                        std=factor * (submodule.in_features**-0.5),
                    )
                    if submodule.bias is not None:
                        nn.init.zeros_(submodule.bias)
    def forward(
        self, x: Tensor, return_cls_token: bool = True
    ) -> Tuple[Tensor, Tensor]:
        """Forward pass: delegate directly to the TSTEncoder backbone.

        Parameters
        ----------
        x: Tensor
            Input time series; expected layout is whatever ``TSTEncoder``
            consumes (presumably batched patches — confirm in layers.py).
        return_cls_token: bool, default True
            Passed through to the encoder; per the annotation the encoder
            returns a pair of tensors (e.g. sequence output and CLS token —
            TODO confirm against TSTEncoder.forward).
        """
        return self.encoder(x, return_cls_token=return_cls_token)