| | from typing import Optional, Tuple, Dict, List, Union |
| | import copy |
| | import numpy as np |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from torch.nn.functional import interpolate |
| |
|
| | from transformers import PreTrainedModel |
| | from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions |
| | from .configuration_super_linear import SuperLinearConfig |
| |
|
| |
|
| |
|
"-------------------------------------------------------------------------------------------------------------------"
class Linear(nn.Module):
    """Expert that maps the lookback window to the horizon with one dense layer."""

    def __init__(self, input_len, output_len):
        super(Linear, self).__init__()
        # Attribute is deliberately named `Linear`: SparseMoE stacks
        # `expert.Linear.weight` / `expert.Linear.bias` for its batched fast path.
        self.Linear = nn.Linear(input_len, output_len)

    def forward(self, x):
        # Clone on entry and exit, mirroring the original defensive copies.
        out = self.Linear(x.clone())
        return out.clone()
| | |
class Naive(nn.Module):
    """Expert implementing the naive forecast: repeat the last observed value."""

    def __init__(self, input_len, output_len):
        super(Naive, self).__init__()
        # Only the horizon length is needed; input_len is accepted for a
        # uniform expert constructor signature.
        self.output_len = output_len

    def forward(self, x):
        # Take the final time step of each series and tile it across the horizon.
        last_value = x[:, -1]
        return last_value.unsqueeze(1).repeat(1, self.output_len)
| | |
class Mean(nn.Module):
    """Expert implementing the mean forecast: repeat the lookback average."""

    def __init__(self, input_len, output_len):
        super(Mean, self).__init__()
        # Only the horizon length is needed; input_len is accepted for a
        # uniform expert constructor signature.
        self.output_len = output_len

    def forward(self, x):
        # Average over time, then tile the scalar across the horizon.
        avg = x.mean(dim=1, keepdim=True)
        return avg.repeat(1, self.output_len)
| |
|
"-------------------------------------------------------------------------------------------------------------------"
class SparseMoE(nn.Module):
    """
    Sparse Mixture of Experts (MoE) module that routes inputs to the most relevant experts.

    This implementation uses a gating network to determine which experts should process each input.
    Only the top-k experts are used for each input, creating a sparse computation pattern:
    all expert outputs are computed (linear experts via one batched einsum, the rest one by
    one), then only the top-k are gathered and mixed with softmax-normalized gate weights.

    Args:
        configs: Configuration object containing MoE parameters
        experts: Collection of expert modules (neural networks)
    """
    def __init__(self, configs, experts=None):
        super(SparseMoE, self).__init__()
        # Std-dev of the Gaussian noise added to gate logits during training
        # (noisy top-k gating).
        self.noise_std = configs.noisy_gating_std
        self.experts = nn.ModuleList(experts)
        self.num_experts = len(experts)
        self.k = configs.top_k_experts
        # Clamp k so we never request more experts than exist.
        if self.k > self.num_experts:
            self.k = self.num_experts
        # Temperature divided into the gate logits at eval time only (see forward()).
        self.moe_temp = configs.moe_temp
        self.use_fft = configs.use_fft
        self.fft_len = configs.fft_len
        self.moe_norm = configs.moe_norm

        # Lazily-built stacked weight/bias tensors for the batched linear-expert path.
        # NOTE(review): cached on first use and never invalidated; if expert weights
        # change afterwards (e.g. further training steps) the cache goes stale --
        # confirm experts are frozen when this path is used, or reset these fields
        # after each optimizer step.
        self.stacked_weights = None
        self.stacked_biases = None

        # Experts whose class name appears here take the batched fast path; they
        # must expose an nn.Linear attribute named `Linear`.
        self.linear_expert_types = ['Linear']
        self.linear_experts = []     # indices into self.experts (batched path)
        self.nonlinear_experts = []  # indices into self.experts (per-expert calls)

        for idx, expert in enumerate(self.experts):
            expert_type = type(expert).__name__
            if expert_type in self.linear_expert_types:
                self.linear_experts.append(idx)
            else:
                self.nonlinear_experts.append(idx)
        self.num_linear_experts = len(self.linear_experts)
        self.num_nonlinear_experts = len(self.nonlinear_experts)

        # The gate scores either the one-sided periodogram (fft_len//2 bins) or
        # the raw lookback window of length train_seq_len.
        if self.use_fft:
            self.gating_network = nn.Linear(self.fft_len//2, self.num_experts, bias=True)
        else:
            self.gating_network = nn.Linear(configs.train_seq_len, self.num_experts, bias=True)

        if self.moe_norm:
            self.batch_norm = nn.BatchNorm1d(self.num_experts)

    def _get_stacked_expert_params(self):
        """Get batched parameters for linear experts.

        Returns:
            Tuple (weights, biases) of shapes
            [num_linear_experts, pred_len, seq_len] and [num_linear_experts, pred_len].
        """
        if self.stacked_weights is None:
            # Stack each linear expert's nn.Linear parameters along a new expert dim.
            weights = torch.stack([self.experts[i].Linear.weight for i in self.linear_experts], dim=0)

            biases = torch.stack([self.experts[i].Linear.bias for i in self.linear_experts], dim=0)

            self.stacked_weights = weights
            self.stacked_biases = biases
        return self.stacked_weights, self.stacked_biases

    def get_periodogram(self, inputs, n=10000):
        """
        Calculate the periodogram (power spectral density) of input time series.

        The periodogram is used as a frequency-domain representation of the signal
        to help the gating network identify periodic patterns.

        Args:
            inputs: Input time series tensor of shape [batch_size, sequence_length]
                or [batch_size, sequence_length, features] (dim=1 is treated as time)
            n: Number of points in FFT computation (input zero-padded/truncated to n)

        Returns:
            Periodogram over the n//2 non-negative frequency bins, normalized so
            each row sums to 1 (all-zero rows stay zero)
        """
        # Remove the per-series mean so the DC bin does not dominate the spectrum.
        x_0 = inputs - torch.mean(inputs, dim=1, keepdim=True)

        # n-point FFT along time; the 1/sqrt(n) scaling keeps magnitudes comparable across n.
        dft = torch.fft.fft(x_0, dim=1, n=n) / np.sqrt(n)
        # Real input => conjugate-symmetric spectrum; keep only the first half.
        dft = dft[:, :n//2]
        I = torch.abs(dft) ** 2

        # Normalize to a probability-like distribution; guard against division by
        # zero for constant (all-zero after detrending) series.
        I_sum = torch.sum(I, dim=1, keepdim=True)
        I_sum[I_sum == 0] = 1
        I = I / I_sum

        return I

    def forward(self, x, get_prob=False, get_prob_only=False):
        """
        Forward pass through the Mixture of Experts.

        Args:
            x: Input tensor of shape [batch_size, sequence_length]
            get_prob: Whether to return expert selection probabilities
            get_prob_only: Whether to return only probabilities without computation

        Returns:
            - Output tensor from the selected experts
            - (Optional) Expert selection probabilities if get_prob is True
        """
        # Gate features: periodogram (frequency domain) or the raw sequence.
        if self.use_fft:
            x_0 = self.get_periodogram(x, n=self.fft_len)
        else:
            x_0 = x

        gate_outputs = self.gating_network(x_0)

        if self.moe_norm:
            gate_outputs = self.batch_norm(gate_outputs)

        # Temperature is applied only at eval time; training relies on the
        # additive exploration noise below instead.
        if not self.training:
            gate_outputs = gate_outputs / self.moe_temp

        if get_prob_only:
            expert_probs = F.softmax(gate_outputs, dim=1)
            return expert_probs

        # Noisy top-k gating: perturb logits during training so selection
        # explores beyond the current argmax experts.
        if self.training:
            noise = torch.randn_like(gate_outputs).to(x.device) * self.noise_std
            noisy_gate_outputs = gate_outputs + noise
            topk_values, topk_indices = torch.topk(noisy_gate_outputs, self.k, dim=1)
        else:
            topk_values, topk_indices = torch.topk(gate_outputs, self.k, dim=1)

        # Renormalize the surviving k logits into mixture weights.
        topk_gates = F.softmax(topk_values, dim=1)

        batch_size = x.size(0)

        # Instance normalization: experts see zero-mean/unit-std input; the
        # statistics are restored on the way out.
        x_mean, x_std = torch.mean(x, dim=1, keepdim=True), torch.std(x, dim=1, keepdim=True)
        x_norm = (x - x_mean) / (x_std + 1e-5)

        # Probe one sample through the first expert to discover the output length.
        pred_len = self.experts[0](x_norm[:1]).shape[-1]
        expert_outputs = torch.zeros(batch_size, self.num_experts, pred_len, device=x.device)

        # Fast path: every linear expert evaluated at once with a single einsum.
        # Note: compute is dense here; sparsity is applied afterwards via gather.
        if self.num_linear_experts > 0:
            all_weights, all_biases = self._get_stacked_expert_params()

            # [experts, pred, dim] x [batch, dim] -> [batch, experts, pred]
            linear_expert_outputs = torch.einsum('epd,bd->bep', all_weights, x_norm)

            linear_expert_outputs = linear_expert_outputs + all_biases.unsqueeze(0)

            for i, expert_idx in enumerate(self.linear_experts):
                expert_outputs[:, expert_idx, :] = linear_expert_outputs[:, i, :]

        # Slow path: remaining (non-Linear) experts are called one by one.
        for expert_idx in self.nonlinear_experts:
            expert_outputs[:, expert_idx, :] = self.experts[expert_idx](x_norm)

        # Select only each sample's top-k expert outputs.
        topk_indices_expanded = topk_indices.unsqueeze(-1).expand(-1, -1, expert_outputs.size(2))
        sparse_expert_outputs = torch.gather(expert_outputs, 1, topk_indices_expanded)

        # Convex combination of the selected experts.
        output = torch.sum(topk_gates.unsqueeze(2) * sparse_expert_outputs, dim=1)

        # Undo the instance normalization.
        output = output * (x_std + 1e-5) + x_mean

        if get_prob:
            expert_probs = F.softmax(gate_outputs, dim=1)
            return output, expert_probs

        return output
| |
|
| |
|
class Model(nn.Module):
    """
    Main model class that employs a Mixture of Experts for time series forecasting.

    This model can work with various types of linear layers as experts and supports
    both standard prediction and auto-regressive prediction for longer horizons,
    plus optional lookback resampling for inputs longer than the training window.

    Args:
        configs: Configuration object containing model parameters
    """
    def __init__(self, configs):
        super(Model, self).__init__()

        # Deep-copy so add_experts() can later rebuild the MoE from an
        # unmodified config snapshot.
        self.configs = copy.deepcopy(configs)

        # Native lookback / horizon the experts were built for.
        self.train_pred_len = configs.train_pred_len
        self.train_seq_len = configs.train_seq_len
        self.layer_type = configs.layer_type

        # Long-horizon and resampling behavior.
        self.long_horizon_scaling = configs.long_horizon_scaling
        self.lookback_resampling = configs.lookback_resampling
        lookback_scale_str = configs.scale_list
        # scale_list may arrive as a comma-separated string or an actual list.
        if isinstance(lookback_scale_str, str):
            self.scale_list = [float(x.strip()) for x in lookback_scale_str.split(',')]
        else:
            self.scale_list = lookback_scale_str
        self.threshold = configs.threshold
        self.freq_bound = configs.freq_bound
        self.penalty_scale = configs.penalty_scale
        self.fft_len = configs.fft_len

        # Expert spec: underscore-separated tokens (e.g. "24_168_naive_mean");
        # empty string means "no experts configured".
        freq_experts_str = configs.freq_experts
        if freq_experts_str == "":
            self.freq_experts = None
        else:
            self.freq_experts = freq_experts_str.split('_')

        self.top_k_experts = configs.top_k_experts
        self.freeze_experts = configs.freeze_experts

        # Build the expert bank: "naive"/"mean" tokens get the heuristic
        # experts, anything else gets a trainable Linear expert.
        self.experts = {}
        if self.freq_experts is not None:
            for expert_freq in self.freq_experts:
                if expert_freq.lower() == "naive":
                    self.experts[expert_freq] = Naive(self.train_seq_len, self.train_pred_len)
                elif expert_freq.lower() == "mean":
                    self.experts[expert_freq] = Mean(self.train_seq_len, self.train_pred_len)
                else:
                    self.experts[expert_freq] = Linear(self.train_seq_len, self.train_pred_len)
            self.n_experts = len(self.experts)
        else:
            raise ValueError("Please specify experts in the configuration.")

        # Optional "complementary" Linear experts appended on top.
        comp_moe = configs.comp_moe
        if comp_moe > 0:
            if comp_moe == 1:
                print("Creating complementary expert")
                self.experts["comp"] = Linear(self.train_seq_len, self.train_pred_len)
            else:
                for i in range(comp_moe):
                    print(f"Creating complementary expert {i}")
                    self.experts["comp_"+str(i)] = Linear(self.train_seq_len, self.train_pred_len)

        self.moe = SparseMoE(configs, experts=self.experts.values())

        print("Experts:", self.experts.keys())

    def add_experts(self, experts: Dict[str, nn.Module]) -> nn.Module:
        """
        Add new experts to the model.

        Args:
            experts: Dictionary of expert instances to add

        Returns:
            Updated MoE layer

        Note:
            Rebuilds the SparseMoE from the stored config, so any previously
            trained gating-network weights are re-initialized.
        """
        for name, expert in experts.items():
            if name not in self.experts:
                self.experts[name] = expert
                print(f"Added expert: {name}")
            else:
                print(f"Expert {name} already exists. Skipping addition.")

        self.moe = SparseMoE(self.configs, experts=self.experts.values())
        return self.moe

    def apply_long_horizon_scaling(self, ar_out: torch.Tensor, ar_x: torch.Tensor) -> torch.Tensor:
        """
        Apply scaling to auto-regressive outputs to maintain statistical properties during long horizon prediction.

        This function identifies cases where the variance of the new predictions exceeds the variance
        of the input sequence and applies scaling to maintain consistent statistical properties.
        Only active at eval time and when long_horizon_scaling is enabled.

        Args:
            ar_out: Auto-regressive output tensor of shape [batch_size * features, pred_len]
            ar_x: Input sequence tensor of shape [batch_size * features, seq_len]

        Returns:
            Scaled auto-regressive output tensor (modified in place for affected rows)
        """
        if not (self.long_horizon_scaling and not self.training):
            return ar_out

        std_new = torch.std(ar_out, dim=1, keepdim=True)
        mean_new = torch.mean(ar_out, dim=1, keepdim=True)
        std_old = torch.std(ar_x, dim=1, keepdim=True)

        # Rows whose prediction variance exceeds the input variance.
        inds = torch.where(std_new / std_old > 1)[0]

        if len(inds) > 0:
            # Center, shrink to the input's std, then restore the mean.
            ar_out_centered = ar_out[inds] - mean_new[inds]

            scaling = std_old[inds] / (std_new[inds] + 1e-8)

            ar_out_adjusted = ar_out_centered * scaling + mean_new[inds]
            ar_out[inds] = ar_out_adjusted

        return ar_out

    # NOTE(review): mutable default argument for scale_list; harmless here only
    # because the list is never mutated inside the function.
    def lookback_resample_search(self, x, scale_list=[2,4,6], min_lookback=512):
        """
        Search for optimal resampling scale based on lookback analysis of expert selection.

        For each candidate downsampling scale, the gate's expert-probability entropy is
        computed (lower = more confident routing) plus a penalty for spectral energy lost
        above the post-downsampling Nyquist frequency; the best-scoring scale is kept
        per series.

        Args:
            x: Input tensor of shape [batch_size, features, sequence_length]
            scale_list: List of potential downsampling scales to evaluate
            min_lookback: Minimum sequence length required after resampling

        Returns:
            Tuple of (resampled_input, final_scales) where:
              - resampled_input: Optimally resampled input tensor
              - final_scales: Scale factors used for each sample
        """
        B, V, L = x.shape

        # Baseline: the untouched last train_seq_len steps of each flattened series.
        lookback = self.train_seq_len
        x_0 = x.reshape(B*V, L)[:, -lookback:]
        output_x = x_0.clone()[:, -lookback:]

        x_reshape = x.reshape(B*V, L)
        x_fft_init = self.moe.get_periodogram(x_reshape, n=self.fft_len)

        # First frequency bin where the cumulative spectrum exceeds 1 - threshold.
        right_cumsum = torch.cumsum(x_fft_init, dim=-1)
        mask = right_cumsum > 1-self.threshold
        j_threshold = mask.float().argmax(dim=-1)

        # Map that bin index to a normalized frequency in [0, 0.5].
        freqs = np.array([np.linspace(0, 0.5, self.fft_len//2)])
        threshhold_freqs = np.take_along_axis(freqs, j_threshold.unsqueeze(-1).detach().cpu().numpy(), axis=1)

        # Avoid division by zero; a zero threshold frequency imposes no limit.
        threshhold_freqs[threshhold_freqs == 0] = self.freq_bound
        max_scale_factor = (self.freq_bound/ threshhold_freqs).astype(int).flatten()

        # threshold == 0 disables the frequency-based cap entirely.
        if self.threshold==0:
            max_scale_factor = np.inf * np.ones(B*V, dtype=int)

        # Precompute, per scale, the fraction of spectral energy above the
        # post-downsampling Nyquist frequency (energy that would be lost).
        energy_loss_penalties = {}
        total_energy = torch.sum(x_fft_init, dim=-1)

        for scale in scale_list:
            if scale <= 1:
                continue

            nyquist_after_downsample = 0.5 / scale

            freq_bins = torch.linspace(0, 0.5, self.fft_len//2, device=x_fft_init.device)
            lost_freq_mask = freq_bins > nyquist_after_downsample

            lost_energy = torch.sum(x_fft_init[:, lost_freq_mask], dim=-1)

            energy_loss_fraction = lost_energy / (total_energy + 1e-10)
            energy_loss_penalties[scale] = energy_loss_fraction

        # Score = entropy of the gate's expert distribution (lower is better).
        prob = self.moe(x_0, get_prob_only=True)
        best_scores = -torch.sum(prob * torch.log(prob + 1e-10), dim=-1)
        final_scales = torch.ones(B*V, device=x.device)

        for scale in scale_list:
            x_interp = torch.nn.functional.interpolate(
                x, scale_factor=1/scale, mode='linear', align_corners=True
            )

            # Only consider scales that still leave enough lookback.
            if x_interp.shape[2] >= min_lookback:
                x_interp_reshaped = x_interp.reshape(B*V, x_interp.shape[-1])
                x_interp_reshaped = x_interp_reshaped[:, -lookback:]
                prob = self.moe(x_interp_reshaped, get_prob_only=True)

                scores = -torch.sum(prob * torch.log(prob + 1e-10), dim=-1)

                # Penalize scales that discard spectral energy.
                if scale in energy_loss_penalties:
                    energy_penalty = energy_loss_penalties[scale]
                    scores = scores + energy_penalty*self.penalty_scale

                # Adopt this scale where it both improves the score and stays
                # within the frequency-derived scale cap.
                idx = np.where((scores < best_scores).cpu() & torch.tensor(max_scale_factor >= scale))[0]

                if len(idx) > 0:
                    output_x[idx] = x_interp_reshaped[idx]
                    best_scores[idx] = scores[idx]
                    final_scales[idx] = scale

        return output_x.reshape(B, V, output_x.shape[-1]), final_scales

    def lookback_resample_reverse(self, y, final_scales, inf_pred_len=None):
        """
        Reverse the resampling operation on the output.

        This function upsamples the model outputs back to the original scale
        based on the resampling factors used during input processing.

        Args:
            y: Output tensor from model of shape [batch_size, features, pred_len]
            final_scales: Scale factors used during input resampling
            inf_pred_len: Target prediction length

        Returns:
            Upsampled output tensor of shape [batch_size, features, inf_pred_len]
        """
        B, V, L = y.shape
        y_reshaped = y.view(B*V, L)
        # y_out is a view into y_reshaped; the in-place writes below update both.
        y_out = y_reshaped[:, :inf_pred_len]

        # Upsample each group of rows that shared a downsampling scale.
        unique_scales = torch.unique(final_scales)
        for scale in unique_scales:
            scale_val = scale.item()
            if scale_val > 1:
                idx = torch.where(final_scales == scale)[0]

                if len(idx) > 0:
                    y_interp = torch.nn.functional.interpolate(
                        y_reshaped[idx].unsqueeze(1), scale_factor=scale_val, mode='linear', align_corners=True
                    )
                    y_out[idx] = y_interp.reshape(len(idx), y_interp.shape[-1])[:, :inf_pred_len]
        return y_out.reshape(B, V, inf_pred_len)

    def forward(self, x_in: torch.Tensor, get_prob: bool = False, pred_len: Optional[int] = None) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Forward pass through the model.

        Args:
            x_in: Encoder input tensor of shape [batch_size, sequence_length] or [batch_size, features, sequence_length]
            get_prob: Whether to return expert selection probabilities
            pred_len: Override for prediction length

        Returns:
            - Prediction tensor
            - (Optional) Expert selection probabilities if get_prob is True
        """
        if pred_len is None:
            pred_len = self.train_pred_len

        x = x_in
        # Promote 2D input to [B, 1, L] so the rest of the code is shape-uniform.
        if x_in.dim() == 2:
            x = x.unsqueeze(1)

        B, V, L = x.shape

        short_lookback = False
        orig_pred_len = pred_len

        # Lookback shorter than the training window: upsample to reach
        # train_seq_len, and stretch pred_len by the same factor so the final
        # downsample recovers the requested horizon.
        if L < self.train_seq_len:
            scale_factor = self.train_seq_len / L
            scale_factor = int(np.ceil(scale_factor))

            pred_len = pred_len * scale_factor
            x = interpolate(x, scale_factor=scale_factor, mode='linear')

            x = x[:, :, -self.train_seq_len:]
            L = self.train_seq_len

            short_lookback = True

        # Lookback longer than the training window: optionally search for the
        # best downsampling scale per series.
        final_scales = None

        if self.lookback_resampling and L > self.train_seq_len:
            x_resampled, final_scales = self.lookback_resample_search(
                x, self.scale_list, self.train_seq_len
            )

            x = x_resampled
            L = x.shape[-1]

        # Flatten (batch, features) into one axis: the MoE is univariate.
        x = x.reshape(B * V, L)
        expert_probs = None

        if get_prob:
            out, expert_probs = self.moe(x, get_prob=True)
        else:
            out = self.moe(x)

        # Auto-regressive rollout for horizons beyond the native train_pred_len:
        # feed predictions back in and keep generating chunks.
        # NOTE(review): the loop produces ceil(pred_len/train_pred_len) extra
        # chunks on top of the initial one, i.e. one more chunk than needed when
        # pred_len is a multiple of train_pred_len; the surplus is discarded by
        # the [:pred_len] truncation -- confirm whether the extra moe call is
        # intentional.
        if self.train_pred_len < pred_len:
            outputs = [out]
            ar_x = torch.cat([x, out], dim=1)[:, -self.train_seq_len:]
            for i in range(0, pred_len, self.train_pred_len):
                ar_out = self.moe(ar_x)
                ar_out = self.apply_long_horizon_scaling(ar_out, ar_x)
                outputs.append(ar_out)
                ar_x = torch.cat([ar_x, ar_out], dim=1)[:, -self.train_seq_len:]
            out = torch.cat(outputs, dim=1)[:, :pred_len]

        # Restore the (batch, features, horizon) layout.
        out = out.reshape(B, V, out.shape[-1])

        # Undo lookback downsampling on the predictions.
        if self.lookback_resampling and final_scales is not None and not short_lookback:
            out = self.lookback_resample_reverse(out, final_scales, orig_pred_len)

        # Undo the short-lookback upsampling.
        if short_lookback:
            out = interpolate(out, scale_factor=1/scale_factor, mode='linear')
            out = out[:, :, :orig_pred_len]

        # 2D input gets a 2D output.
        if x_in.dim() == 2:
            out = out.squeeze(1)

        if get_prob:
            expert_probs = expert_probs.reshape(B, V, expert_probs.shape[-1])
            # NOTE(review): squeeze(-1) removes the *expert* dimension only when
            # num_experts == 1; squeeze(1) (the singleton feature dim) may have
            # been intended for 2D inputs -- confirm.
            if x_in.dim() == 2:
                expert_probs = expert_probs.squeeze(-1)
            return out, expert_probs

        return out

    def map_to_cycle(self, freq: str) -> int:
        """
        Map a frequency string to its cycle length.

        Args:
            freq: String of the form "<prefix>/<cycle>" (e.g. "1/24"); the
                integer after the "/" is the cycle length. (The code requires
                the "/" form; bare strings like "h" would raise.)

        Returns:
            Integer representing the number of periods in the cycle

        Raises:
            IndexError/ValueError: If freq does not contain "/<int>".
        """
        cycle = int(freq.split("/")[1])
        return cycle
| |
|
"-------------------------------------------------------------------------------------------------------------------"
class SuperLinearForCausalLM(PreTrainedModel):
    """HuggingFace wrapper exposing the SuperLinear MoE forecaster as a causal LM."""

    config_class = SuperLinearConfig

    def __init__(self, config: SuperLinearConfig):
        super().__init__(config)
        # The backbone reads config values via attribute access, so materialize
        # the config dict into a throwaway namespace-style object.
        cfg_dict = config.to_dict()
        backbone_cfg = type("Cfg", (), cfg_dict)()
        self.args = backbone_cfg
        self.backbone = Model(backbone_cfg)
        self.post_init()

    def forward(self,
                inputs_embeds: torch.Tensor = None,
                pred_len: Optional[int] = None,
                get_prob: bool = False,
                **kwargs) -> CausalLMOutputWithCrossAttentions:
        """Run the backbone forecaster.

        Args:
            inputs_embeds: Input series tensor (required).
            pred_len: Optional prediction-length override.
            get_prob: If True, also collect expert-selection probabilities.

        Returns:
            CausalLMOutputWithCrossAttentions with the predictions in ``logits``
            and (when requested) expert probabilities in ``attentions``.
        """
        if inputs_embeds is None:
            raise ValueError("inputs_embeds must be provided")

        probs = None
        if get_prob:
            preds, probs = self.backbone(inputs_embeds, pred_len=pred_len, get_prob=True)
        else:
            preds = self.backbone(inputs_embeds, pred_len=pred_len, get_prob=False)

        return CausalLMOutputWithCrossAttentions(
            logits=preds,
            hidden_states=None,
            attentions=probs
        )
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|