| from typing import Optional, Tuple, Dict, List, Union |
| import copy |
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.nn.functional import interpolate |
|
|
| from transformers import PreTrainedModel |
| from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions |
| from .configuration_super_linear import SuperLinearConfig |
|
|
|
|
| "-------------------------------------------------------------------------------------------------------------------" |
class RevIN(nn.Module):
    """Reversible instance normalization.

    ``forward(x, 'norm')`` records per-instance statistics of ``x`` and returns the
    normalized tensor.  ``forward(y, 'denorm')`` maps ``y`` back to the original
    scale using the statistics recorded by the most recent ``'norm'`` call.
    Statistics are reduced over every dimension except the first and the last.
    The ``subtract_last`` path indexes ``x[:, -1:, :]`` and so expects a 3-D input.
    """

    def __init__(self, num_features: int, eps=1e-5, affine=True, norm_type=None, subtract_last=False):
        """
        :param num_features: the number of features or channels (size of the affine
            parameters; only read when ``affine`` is True)
        :param eps: a value added for numerical stability
        :param affine: if True, RevIN has learnable affine parameters
        :param norm_type: ``None`` (or any other value) scales by the standard deviation,
            ``"l1"`` by the mean absolute deviation, ``"l2"`` by the root-mean-square
            deviation from ``self.mean``
        :param subtract_last: if True, the last time step (instead of the mean) is
            subtracted before scaling
        """
        super(RevIN, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.affine = affine
        self.subtract_last = subtract_last
        self.norm_type = norm_type
        if self.affine:
            self._init_params()

    def forward(self, x, mode: str):
        """Normalize (``mode='norm'``) or de-normalize (``mode='denorm'``) ``x``.

        :raises NotImplementedError: for any other ``mode``
        """
        if mode == 'norm':
            self._get_statistics(x)
            x = self._normalize(x)
        elif mode == 'denorm':
            x = self._denormalize(x)
        else:
            raise NotImplementedError(f"Unknown RevIN mode: {mode!r}")
        return x

    def _init_params(self):
        # Per-feature affine transform, initialised to the identity.
        self.affine_weight = nn.Parameter(torch.ones(self.num_features))
        self.affine_bias = nn.Parameter(torch.zeros(self.num_features))

    def _get_statistics(self, x):
        """Store the centring value(s) and the scale statistic selected by ``norm_type``."""
        dim2reduce = tuple(range(1, x.ndim - 1))

        if self.subtract_last:
            self.last = x[:, -1:, :].detach()
            self.mean = torch.mean(x[:, :-1, :], dim=dim2reduce, keepdim=True).detach()
        else:
            self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach()

        if self.norm_type == "l1":
            # eps keeps the scale strictly positive for constant series (avoids 0/0).
            scale = torch.mean(torch.abs(x - self.mean), dim=dim2reduce, keepdim=True) + self.eps
        elif self.norm_type == "l2":
            scale = torch.sqrt(torch.mean((x - self.mean) ** 2, dim=dim2reduce, keepdim=True) + self.eps)
        else:
            scale = torch.sqrt(torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps)
        self.stdev = scale.detach()

    def _normalize(self, x):
        if self.subtract_last:
            x = x - self.last
        else:
            x = x - self.mean
        # self.stdev already holds the statistic chosen by norm_type, so divide
        # exactly once (the l1/l2 modes previously divided a second time).
        x = x / self.stdev

        if self.affine:
            x = x * self.affine_weight
            x = x + self.affine_bias
        return x

    def _denormalize(self, x):
        if self.affine:
            x = x - self.affine_bias
            # eps**2 guards against division by a zero learned weight.
            x = x / (self.affine_weight + self.eps * self.eps)

        # Exact inverse of _normalize: a single multiplication by the scale.
        x = x * self.stdev
        if self.subtract_last:
            x = x + self.last
        else:
            x = x + self.mean
        return x


"-------------------------------------------------------------------------------------------------------------------"
class Linear(nn.Module):
    """Simple linear layer expert.

    Applies one ``nn.Linear(input_len, output_len)`` to the last dimension of the
    input, mapping ``input_len`` lookback steps to ``output_len`` forecast steps.
    """

    def __init__(self, input_len, output_len):
        super(Linear, self).__init__()
        self.Linear = nn.Linear(input_len, output_len)

    def forward(self, x):
        # nn.Linear is out-of-place and returns a fresh tensor, so the defensive
        # clones of the input and output were redundant copies.
        return self.Linear(x)

class Naive(nn.Module):
    """Naive forecasting expert: repeats the last observed value ``output_len`` times."""

    def __init__(self, input_len, output_len):
        super(Naive, self).__init__()
        # input_len is accepted only so every expert shares one constructor signature.
        self.output_len = output_len

    def forward(self, x):
        # Keep the time axis as a singleton column, then tile it across the horizon.
        last_value = x[:, -1:]
        return last_value.repeat(1, self.output_len)

class Mean(nn.Module):
    """Mean forecasting expert: repeats the lookback mean ``output_len`` times."""

    def __init__(self, input_len, output_len):
        super(Mean, self).__init__()
        # input_len is accepted only so every expert shares one constructor signature.
        self.output_len = output_len

    def forward(self, x):
        # Average over dim 1 but keep it as a singleton column, then tile it.
        avg = torch.mean(x, dim=1, keepdim=True)
        return avg.repeat(1, self.output_len)

class RLinear(nn.Module):
    """Reversible Instance Normalization Linear layer expert.

    The lookback window is normalized per instance with ``RevIN``, projected from
    ``input_len`` to ``output_len`` steps by one linear layer, and mapped back to
    the original scale with the statistics of that same window.
    """

    def __init__(self, input_len, output_len):
        super(RLinear, self).__init__()
        self.Linear = nn.Linear(input_len, output_len)
        # affine=False, so RevIN never reads num_features.
        self.revin_layer = RevIN(num_features=None, affine=False, norm_type=None, subtract_last=False)

    def forward(self, x):
        """Forecast ``output_len`` steps.

        Accepts ``[batch, input_len]`` (returns ``[batch, output_len]``) or
        ``[batch, input_len, channels]`` (returns ``[batch, output_len, channels]``).
        """
        was_2d = x.dim() == 2
        if was_2d:
            x = x.unsqueeze(-1)
        # RevIN and nn.Linear are out-of-place, so no defensive clones are needed.
        x = self.revin_layer(x, 'norm')
        # Move time to the last axis for nn.Linear, then restore the layout.
        x = self.Linear(x.permute(0, 2, 1)).permute(0, 2, 1)
        x = self.revin_layer(x, 'denorm')
        if was_2d:
            x = x.squeeze(-1)
        return x


"-------------------------------------------------------------------------------------------------------------------"
class SparseMoE(nn.Module):
    """
    Sparse Mixture of Experts (MoE) module that routes inputs to the most relevant experts.

    A gating network scores every expert for each input row. The outputs of the
    top-k experts are combined with softmax weights over their gate scores. Every
    expert is still evaluated; only the top-k outputs enter the combination.

    Args:
        configs: Configuration object; reads ``noisy_gating_std``, ``top_k_experts``,
            ``moe_temp``, ``use_fft``, ``fft_len``, ``moe_norm`` and (when
            ``use_fft`` is false) ``train_seq_len``.
        experts: Iterable of expert modules (at least one).

    Raises:
        ValueError: if ``experts`` is None or empty.
    """
    def __init__(self, configs, experts=None):
        super(SparseMoE, self).__init__()
        # Materialize once: supports one-shot iterables (generators) and makes len() valid.
        expert_list = list(experts) if experts is not None else []
        if not expert_list:
            raise ValueError("SparseMoE requires at least one expert.")

        self.noise_std = configs.noisy_gating_std
        self.experts = nn.ModuleList(expert_list)
        self.num_experts = len(self.experts)
        # top-k can never select more experts than exist.
        self.k = min(configs.top_k_experts, self.num_experts)

        self.moe_temp = configs.moe_temp
        self.use_fft = configs.use_fft
        self.fft_len = configs.fft_len
        self.moe_norm = configs.moe_norm

        # The gate sees either the half-spectrum periodogram or the raw lookback window.
        if self.use_fft:
            self.gating_network = nn.Linear(self.fft_len // 2, self.num_experts, bias=True)
        else:
            self.gating_network = nn.Linear(configs.train_seq_len, self.num_experts, bias=True)

        if self.moe_norm:
            self.batch_norm = nn.BatchNorm1d(self.num_experts)

    def get_periodogram(self, inputs, n=10000):
        """
        Calculate the normalized periodogram (power spectrum) of the input series along dim 1.

        Args:
            inputs: Tensor of shape [batch_size, sequence_length] or
                [batch_size, sequence_length, features]
            n: Number of points in the FFT (the series is zero-padded or truncated to n)

        Returns:
            The first n // 2 power bins, normalized to sum to 1 along dim 1.
            All-zero spectra (e.g. constant inputs) stay all zero.
        """
        # Remove the mean so the DC bin does not dominate the spectrum.
        x_0 = inputs - torch.mean(inputs, dim=1, keepdim=True)

        dft = torch.fft.fft(x_0, dim=1, n=n) / np.sqrt(n)
        dft = dft[:, :n // 2]
        power = torch.abs(dft) ** 2

        power_sum = torch.sum(power, dim=1, keepdim=True)
        # Avoid 0/0 for flat inputs; their spectrum stays zero.
        power_sum = power_sum.masked_fill(power_sum == 0, 1)
        return power / power_sum

    def forward(self, x, get_prob=False, get_prob_only=False):
        """
        Forward pass through the Mixture of Experts.

        Args:
            x: Input tensor of shape [batch_size, sequence_length]
            get_prob: Whether to also return expert selection probabilities
            get_prob_only: Return only the probabilities (no expert is evaluated)

        Returns:
            - Output tensor combining the selected experts
            - (Optional) Softmax over all gate scores if get_prob is True
        """
        gate_input = self.get_periodogram(x, n=self.fft_len) if self.use_fft else x

        gate_outputs = self.gating_network(gate_input)
        if self.moe_norm:
            gate_outputs = self.batch_norm(gate_outputs)

        # Temperature is applied only at inference time.
        if not self.training:
            gate_outputs = gate_outputs / self.moe_temp

        if get_prob_only:
            return F.softmax(gate_outputs, dim=1)

        if self.training:
            # Noisy top-k gating for exploration; randn_like already matches the device.
            noise = torch.randn_like(gate_outputs) * self.noise_std
            topk_values, topk_indices = torch.topk(gate_outputs + noise, self.k, dim=1)
        else:
            topk_values, topk_indices = torch.topk(gate_outputs, self.k, dim=1)

        topk_gates = F.softmax(topk_values, dim=1)

        # [batch, num_experts, horizon]
        expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=1)

        # Select the top-k experts' outputs per row: [batch, k, horizon].
        gather_index = topk_indices.unsqueeze(-1).expand(-1, -1, expert_outputs.size(2))
        sparse_expert_outputs = torch.gather(expert_outputs, 1, gather_index)

        output = torch.sum(topk_gates.unsqueeze(2) * sparse_expert_outputs, dim=1)

        if get_prob:
            return output, F.softmax(gate_outputs, dim=1)
        return output




class Model(nn.Module):
    """
    Main model class that employs a Mixture of Experts for time series forecasting.

    The experts are ``Naive``, ``Mean`` and ``RLinear`` modules combined by a
    ``SparseMoE``. Horizons longer than ``train_pred_len`` are produced
    auto-regressively. Histories shorter than ``train_seq_len`` are upsampled.
    Longer histories are either resampled (``lookback_resampling``) or cropped to
    the most recent ``train_seq_len`` samples.

    Args:
        configs: Configuration object containing model parameters
    """
    def __init__(self, configs):
        super(Model, self).__init__()

        # Private copy: add_experts() rebuilds the MoE from it later.
        self.configs = copy.deepcopy(configs)

        self.train_pred_len = configs.train_pred_len
        self.train_seq_len = configs.train_seq_len
        self.layer_type = configs.layer_type

        self.long_horizon_scaling = configs.long_horizon_scaling
        self.lookback_resampling = configs.lookback_resampling
        lookback_scale_str = configs.scale_list
        if isinstance(lookback_scale_str, str):
            # Comma-separated string, e.g. "2,4,6".
            self.scale_list = [float(x.strip()) for x in lookback_scale_str.split(',')]
        else:
            self.scale_list = lookback_scale_str
        self.threshold = configs.threshold
        self.freq_bound = configs.freq_bound
        self.penalty_scale = configs.penalty_scale
        self.fft_len = configs.fft_len

        # Underscore-separated expert names; `not` covers both "" and None
        # (None previously crashed on .split()).
        freq_experts_str = configs.freq_experts
        self.freq_experts = freq_experts_str.split('_') if freq_experts_str else None

        self.top_k_experts = configs.top_k_experts
        self.freeze_experts = configs.freeze_experts

        if self.freq_experts is None:
            raise ValueError("Please specify experts in the configuration.")

        # "naive" / "mean" (case-insensitive) select the baseline experts; any other
        # name creates an RLinear expert. Duplicate names collapse into one entry.
        self.experts = {}
        for expert_freq in self.freq_experts:
            expert_kind = expert_freq.lower()
            if expert_kind == "naive":
                self.experts[expert_freq] = Naive(self.train_seq_len, self.train_pred_len)
            elif expert_kind == "mean":
                self.experts[expert_freq] = Mean(self.train_seq_len, self.train_pred_len)
            else:
                self.experts[expert_freq] = RLinear(self.train_seq_len, self.train_pred_len)
        self.n_experts = len(self.experts)

        # Additional RLinear "complementary" experts.
        comp_moe = configs.comp_moe
        if comp_moe > 0:
            if comp_moe == 1:
                print("Creating complementary expert")
                self.experts["comp"] = RLinear(self.train_seq_len, self.train_pred_len)
            else:
                for i in range(comp_moe):
                    print(f"Creating complementary expert {i}")
                    self.experts["comp_" + str(i)] = RLinear(self.train_seq_len, self.train_pred_len)

        self.moe = SparseMoE(configs, experts=self.experts.values())

        print("Experts:", self.experts.keys())

    def add_experts(self, experts: Dict[str, nn.Module]) -> nn.Module:
        """
        Add new experts to the model.

        The MoE layer is rebuilt from ``self.configs``, which creates a new,
        freshly initialised gating network.

        Args:
            experts: Dictionary of expert instances to add (existing names are skipped)

        Returns:
            Updated MoE layer
        """
        for name, expert in experts.items():
            if name not in self.experts:
                self.experts[name] = expert
                print(f"Added expert: {name}")
            else:
                print(f"Expert {name} already exists. Skipping addition.")

        self.moe = SparseMoE(self.configs, experts=self.experts.values())
        return self.moe

    def apply_long_horizon_scaling(self, ar_out: torch.Tensor, ar_x: torch.Tensor) -> torch.Tensor:
        """
        Damp auto-regressive chunks whose spread exceeds that of their input window.

        Only active when ``long_horizon_scaling`` is set and the module is in eval
        mode. Rows whose std exceeds the input window's std are rescaled around
        their own mean to the input std. ``ar_out`` is modified in place and
        returned.

        Args:
            ar_out: Auto-regressive output tensor of shape [batch_size * features, pred_len]
            ar_x: Input sequence tensor of shape [batch_size * features, seq_len]

        Returns:
            Scaled auto-regressive output tensor
        """
        if not self.long_horizon_scaling or self.training:
            return ar_out

        std_new = torch.std(ar_out, dim=1, keepdim=True)
        mean_new = torch.mean(ar_out, dim=1, keepdim=True)
        std_old = torch.std(ar_x, dim=1, keepdim=True)

        inds = torch.where(std_new / std_old > 1)[0]
        if len(inds) > 0:
            ar_out_centered = ar_out[inds] - mean_new[inds]
            scaling = std_old[inds] / (std_new[inds] + 1e-8)
            ar_out[inds] = ar_out_centered * scaling + mean_new[inds]

        return ar_out

    def lookback_resample_search(self, x, scale_list=(2, 4, 6), min_lookback=512):
        """
        Choose, per series, a downsampling scale for a long lookback window.

        The scale-1 baseline is the most recent ``train_seq_len`` samples. For each
        candidate scale whose downsampled length is at least ``min_lookback``, the
        entropy of the gate's expert distribution is computed. It is penalised by
        the fraction of spectral energy above the new Nyquist frequency, weighted
        by ``penalty_scale``. A series switches to a scale when that penalised
        entropy is lower than its best so far and the scale does not exceed the
        series' maximum scale. That maximum is derived from ``threshold`` and
        ``freq_bound``.

        Args:
            x: Input tensor of shape [batch_size, features, sequence_length]
            scale_list: Candidate downsampling scales to evaluate
            min_lookback: Minimum sequence length required after resampling

        Returns:
            Tuple of (resampled_input, final_scales) where:
            - resampled_input: Tensor of shape [batch_size, features, <= train_seq_len]
            - final_scales: Scale factor chosen for each of the batch_size * features series
        """
        B, V, L = x.shape
        lookback = self.train_seq_len

        x_flat = x.reshape(B * V, L)
        x_0 = x_flat[:, -lookback:]
        output_x = x_0.clone()

        # Normalized periodogram of the full (un-cropped) history.
        x_fft_init = self.moe.get_periodogram(x_flat, n=self.fft_len)

        # First frequency bin at which the cumulative energy exceeds 1 - threshold.
        right_cumsum = torch.cumsum(x_fft_init, dim=-1)
        mask = right_cumsum > 1 - self.threshold
        j_threshold = mask.float().argmax(dim=-1)

        freqs = np.array([np.linspace(0, 0.5, self.fft_len // 2)])
        threshold_freqs = np.take_along_axis(freqs, j_threshold.unsqueeze(-1).detach().cpu().numpy(), axis=1)

        # A zero threshold frequency maps to a maximum scale of 1 (no downsampling).
        threshold_freqs[threshold_freqs == 0] = self.freq_bound
        max_scale_factor = (self.freq_bound / threshold_freqs).astype(int).flatten()

        if self.threshold == 0:
            max_scale_factor = np.inf * np.ones(B * V, dtype=int)

        # Fraction of spectral energy that would be lost above each scale's Nyquist frequency.
        freq_bins = torch.linspace(0, 0.5, self.fft_len // 2, device=x_fft_init.device)
        total_energy = torch.sum(x_fft_init, dim=-1)
        energy_loss_penalties = {}
        for scale in scale_list:
            if scale <= 1:
                continue
            lost_freq_mask = freq_bins > 0.5 / scale
            lost_energy = torch.sum(x_fft_init[:, lost_freq_mask], dim=-1)
            energy_loss_penalties[scale] = lost_energy / (total_energy + 1e-10)

        # Baseline: entropy of the gate's distribution on the un-resampled window.
        prob = self.moe(x_0, get_prob_only=True)
        best_scores = -torch.sum(prob * torch.log(prob + 1e-10), dim=-1)
        final_scales = torch.ones(B * V, device=x.device)

        for scale in scale_list:
            x_interp = torch.nn.functional.interpolate(
                x, scale_factor=1 / scale, mode='linear', align_corners=True
            )
            if x_interp.shape[2] < min_lookback:
                continue

            x_interp_reshaped = x_interp.reshape(B * V, x_interp.shape[-1])[:, -lookback:]
            prob = self.moe(x_interp_reshaped, get_prob_only=True)
            scores = -torch.sum(prob * torch.log(prob + 1e-10), dim=-1)

            if scale in energy_loss_penalties:
                scores = scores + energy_loss_penalties[scale] * self.penalty_scale

            # Improve the score AND respect the spectral limit; evaluated on the scores' device.
            allowed = torch.tensor(max_scale_factor >= scale, device=scores.device)
            idx = torch.where((scores < best_scores) & allowed)[0]

            if len(idx) > 0:
                output_x[idx] = x_interp_reshaped[idx]
                best_scores[idx] = scores[idx]
                final_scales[idx] = scale

        return output_x.reshape(B, V, output_x.shape[-1]), final_scales

    def lookback_resample_reverse(self, y, final_scales, inf_pred_len=None):
        """
        Reverse the resampling operation on the output.

        Rows whose scale is > 1 are linearly upsampled by that scale
        (``align_corners=True``). All rows are then truncated to ``inf_pred_len``.
        ``y`` itself is not modified.

        Args:
            y: Output tensor from model of shape [batch_size, features, pred_len]
            final_scales: Scale factors used during input resampling (one per row)
            inf_pred_len: Target prediction length; defaults to pred_len. It must
                not exceed pred_len.

        Returns:
            Upsampled output tensor of shape [batch_size, features, inf_pred_len]
        """
        B, V, L = y.shape
        if inf_pred_len is None:
            inf_pred_len = L
        y_reshaped = y.reshape(B * V, L)
        # Clone so writes below do not leak into the caller's tensor through a view.
        y_out = y_reshaped[:, :inf_pred_len].clone()

        for scale in torch.unique(final_scales):
            scale_val = scale.item()
            if scale_val <= 1:
                continue
            idx = torch.where(final_scales == scale)[0]
            if len(idx) == 0:
                continue
            y_interp = torch.nn.functional.interpolate(
                y_reshaped[idx].unsqueeze(1), scale_factor=scale_val, mode='linear', align_corners=True
            )
            y_out[idx] = y_interp.reshape(len(idx), y_interp.shape[-1])[:, :inf_pred_len]

        return y_out.reshape(B, V, inf_pred_len)

    def forward(self, x_in: torch.Tensor, get_prob: bool = False, pred_len: Optional[int] = None) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Forward pass through the model.

        Args:
            x_in: Encoder input tensor of shape [batch_size, sequence_length] or [batch_size, features, sequence_length]
            get_prob: Whether to return expert selection probabilities
            pred_len: Override for prediction length (defaults to train_pred_len)

        Returns:
            - Prediction tensor with exactly ``pred_len`` steps on the last axis
              ([batch_size, pred_len] for 2-D input)
            - (Optional) Expert selection probabilities of the first MoE step if
              get_prob is True ([batch_size, features, n_experts], or
              [batch_size, n_experts] for 2-D input)
        """
        if pred_len is None:
            pred_len = self.train_pred_len

        x = x_in.unsqueeze(1) if x_in.dim() == 2 else x_in
        B, V, L = x.shape

        short_lookback = False
        orig_pred_len = pred_len

        if L < self.train_seq_len:
            # Upsample a too-short history by an integer factor, forecast a
            # correspondingly longer horizon, and downsample the result at the end.
            scale_factor = int(np.ceil(self.train_seq_len / L))
            pred_len = pred_len * scale_factor
            x = interpolate(x, scale_factor=scale_factor, mode='linear')
            x = x[:, :, -self.train_seq_len:]
            L = self.train_seq_len
            short_lookback = True

        final_scales = None
        if L > self.train_seq_len:
            if self.lookback_resampling:
                x, final_scales = self.lookback_resample_search(
                    x, self.scale_list, self.train_seq_len
                )
            else:
                # Gate and experts take exactly train_seq_len inputs; keep the most
                # recent window (a longer window previously crashed with a shape error).
                x = x[:, :, -self.train_seq_len:]
            L = x.shape[-1]

        x = x.reshape(B * V, L)
        expert_probs = None
        if get_prob:
            out, expert_probs = self.moe(x, get_prob=True)
        else:
            out = self.moe(x)

        if self.train_pred_len < pred_len:
            # Auto-regressive roll-out. The first chunk is already in `outputs`, so
            # iterate from train_pred_len; starting at 0 always produced one surplus chunk.
            outputs = [out]
            ar_x = torch.cat([x, out], dim=1)[:, -self.train_seq_len:]
            for _ in range(self.train_pred_len, pred_len, self.train_pred_len):
                ar_out = self.moe(ar_x)
                ar_out = self.apply_long_horizon_scaling(ar_out, ar_x)
                outputs.append(ar_out)
                ar_x = torch.cat([ar_x, ar_out], dim=1)[:, -self.train_seq_len:]
            out = torch.cat(outputs, dim=1)[:, :pred_len]

        out = out.reshape(B, V, out.shape[-1])

        if self.lookback_resampling and final_scales is not None and not short_lookback:
            out = self.lookback_resample_reverse(out, final_scales, orig_pred_len)

        if short_lookback:
            out = interpolate(out, scale_factor=1 / scale_factor, mode='linear')

        # Every path returns exactly the requested horizon (the plain path used to
        # return train_pred_len steps when pred_len < train_pred_len).
        out = out[:, :, :orig_pred_len]

        if x_in.dim() == 2:
            out = out.squeeze(1)

        if get_prob:
            expert_probs = expert_probs.reshape(B, V, expert_probs.shape[-1])
            if x_in.dim() == 2:
                # Drop the singleton variate axis, mirroring `out` (squeeze(-1) hit the expert axis).
                expert_probs = expert_probs.squeeze(1)
            return out, expert_probs

        return out

    def map_to_cycle(self, freq: str) -> int:
        """
        Parse the cycle length from a frequency string of the form ``"<prefix>/<int>"``.

        Only the part after the first ``/`` is used, e.g. ``"x/24"`` returns ``24``.

        Args:
            freq: Frequency string containing a ``/`` followed by an integer

        Returns:
            The integer after the first ``/``

        Raises:
            IndexError: if ``freq`` contains no ``/``
            ValueError: if the part after ``/`` is not an integer
        """
        cycle = int(freq.split("/")[1])
        return cycle


"-------------------------------------------------------------------------------------------------------------------"
class SuperLinearForCausalLM(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around the SuperLinear ``Model`` backbone.

    The forecast is returned in ``logits``. When ``get_prob`` is set, the expert
    selection probabilities are returned in ``attentions``.
    """
    config_class = SuperLinearConfig

    def __init__(self, config: SuperLinearConfig):
        super().__init__(config)
        # Turn the config dict into an object with attribute access, which is the
        # interface Model reads its settings from.
        cfg_namespace = type("Cfg", (), config.to_dict())()
        self.args = cfg_namespace
        self.backbone = Model(cfg_namespace)
        self.post_init()

    def forward(self,
                inputs_embeds: torch.Tensor = None,
                pred_len: Optional[int] = None,
                get_prob: bool = False,
                **kwargs) -> CausalLMOutputWithCrossAttentions:
        """Run the backbone on ``inputs_embeds`` (the raw series); extra kwargs are ignored.

        Raises:
            ValueError: if ``inputs_embeds`` is None.
        """
        if inputs_embeds is None:
            raise ValueError("inputs_embeds must be provided")

        probs = None
        if not get_prob:
            preds = self.backbone(inputs_embeds, pred_len=pred_len, get_prob=False)
        else:
            preds, probs = self.backbone(inputs_embeds, pred_len=pred_len, get_prob=True)

        return CausalLMOutputWithCrossAttentions(
            logits=preds,
            hidden_states=None,
            attentions=probs,
        )