import torch
import torch.nn as nn
from fla.modules import GatedMLP

from src.data.containers import BatchTimeSeriesContainer
from src.data.scalers import MinMaxScaler, RobustScaler
from src.data.time_features import compute_batch_time_features
from src.models.blocks import GatedDeltaProductEncoder
from src.utils.utils import device


def create_scaler(scaler_type: str, epsilon: float = 1e-3):
    """Create scaler instance based on type."""
    if scaler_type == "custom_robust":
        return RobustScaler(epsilon=epsilon)
    elif scaler_type == "min_max":
        return MinMaxScaler(epsilon=epsilon)
    else:
        raise ValueError(f"Unknown scaler: {scaler_type}")

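# Usage sketch, mirroring how the scaler is driven later in this module
# (the `compute_statistics(values, mask)` / `scale(values, stats)` API is
# inferred from `_compute_scaling` and `_apply_scaling_and_masking` below):
#
#     scaler = create_scaler("custom_robust")
#     stats = scaler.compute_statistics(history_values, history_mask)
#     scaled = scaler.scale(history_values, stats)

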
def apply_channel_noise(values: torch.Tensor, noise_scale: float = 0.1):
    """Add noise to constant channels to prevent model instability.

    `values` has shape [batch, time, channels]; a channel counts as constant
    when every timestep equals its first value.
    """
    # [batch, channels] boolean mask: True where a channel never changes.
    is_constant = torch.all(values == values[:, 0:1, :], dim=1)
    # Broadcast the mask over time so only constant channels receive noise.
    noise = torch.randn_like(values) * noise_scale * is_constant.unsqueeze(1)
    return values + noise

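# Example: for values of shape [B=2, T=100, C=4] with channel 2 constant,
# is_constant is a [2, 4] boolean mask and the sampled noise is nonzero only
# at [:, :, 2], so the varying channels pass through bit-exact.

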
class TimeSeriesModel(nn.Module):
    """Time series forecasting model combining embedding, encoding, and prediction."""

    def __init__(
        self,

        embed_size: int = 128,
        num_encoder_layers: int = 2,

        scaler: str = "custom_robust",
        epsilon: float = 1e-3,
        scaler_clamp_value: float | None = None,
        handle_constants: bool = False,

        K_max: int = 6,
        time_feature_config: dict | None = None,
        encoding_dropout: float = 0.0,

        encoder_config: dict | None = None,

        loss_type: str = "huber",
        quantiles: list[float] | None = None,
        **kwargs,
    ):
        super().__init__()

        # Core hyperparameters.
        self.embed_size = embed_size
        self.num_encoder_layers = num_encoder_layers
        self.epsilon = epsilon
        self.scaler_clamp_value = scaler_clamp_value
        self.handle_constants = handle_constants
        self.encoding_dropout = encoding_dropout
        self.K_max = K_max
        self.time_feature_config = time_feature_config or {}
        self.encoder_config = encoder_config or {}

        # Loss configuration.
        self.loss_type = loss_type
        self.quantiles = quantiles
        if self.loss_type == "quantile" and self.quantiles is None:
            raise ValueError("Quantiles must be provided for quantile loss.")
        if self.quantiles:
            # Shape [1, 1, 1, Q] so it broadcasts against predictions of
            # shape [B, P, N, Q] in `_quantile_loss`.
            self.register_buffer("qt", torch.tensor(self.quantiles, device=device).view(1, 1, 1, -1))

        self._validate_configuration()

        # Build the model components.
        self.scaler = create_scaler(scaler, epsilon)
        self._init_embedding_layers()
        self._init_encoder_layers(self.encoder_config, num_encoder_layers)
        self._init_projection_layers()

    def _validate_configuration(self):
        """Validate essential model configuration parameters."""
        if "num_heads" not in self.encoder_config:
            raise ValueError("encoder_config must contain 'num_heads' parameter")

        if self.embed_size % self.encoder_config["num_heads"] != 0:
            raise ValueError(
                f"embed_size ({self.embed_size}) must be divisible by num_heads ({self.encoder_config['num_heads']})"
            )

    def _init_embedding_layers(self):
        """Initialize value and time feature embedding layers."""
        self.expand_values = nn.Linear(1, self.embed_size, bias=True)
        # Learned embedding substituted at NaN positions; scaled down so it
        # starts small relative to the value embeddings.
        self.nan_embedding = nn.Parameter(
            torch.randn(1, 1, 1, self.embed_size) / self.embed_size,
            requires_grad=True,
        )
        self.time_feature_projection = nn.Linear(self.K_max, self.embed_size)

    def _init_encoder_layers(self, encoder_config: dict, num_encoder_layers: int):
        """Initialize encoder layers."""
        self.num_encoder_layers = num_encoder_layers

        # Copy so the caller's config dict is not mutated.
        encoder_config = encoder_config.copy()
        encoder_config["token_embed_dim"] = self.embed_size
        self.encoder_layers = nn.ModuleList(
            [
                GatedDeltaProductEncoder(layer_idx=layer_idx, **encoder_config)
                for layer_idx in range(self.num_encoder_layers)
            ]
        )

    def _init_projection_layers(self):
        """Initialize the output head, gated MLP, and learned initial hidden states."""
        if self.loss_type == "quantile":
            output_dim = len(self.quantiles)
        else:
            output_dim = 1
        self.final_output_layer = nn.Linear(self.embed_size, output_dim)

        self.mlp = GatedMLP(
            hidden_size=self.embed_size,
            hidden_ratio=4,
            hidden_act="swish",
            fuse_swiglu=True,
        )

        head_k_dim = self.embed_size // self.encoder_config["num_heads"]

        # The value head may be wider than the key head by a factor expand_v.
        expand_v = self.encoder_config.get("expand_v", 1.0)
        head_v_dim = int(head_k_dim * expand_v)

        # One learnable initial recurrent state per encoder layer, shaped
        # [1, num_heads, head_k_dim, head_v_dim] and repeated across the
        # batch in `_generate_predictions`.
        num_initial_hidden_states = self.num_encoder_layers
        self.initial_hidden_state = nn.ParameterList(
            [
                nn.Parameter(
                    torch.randn(1, self.encoder_config["num_heads"], head_k_dim, head_v_dim) / head_k_dim,
                    requires_grad=True,
                )
                for _ in range(num_initial_hidden_states)
            ]
        )

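    # Example head-dimension arithmetic: with embed_size=128 and num_heads=4,
    # head_k_dim = 128 // 4 = 32; with expand_v=1.0, head_v_dim = int(32 * 1.0) = 32,
    # so each encoder layer owns a learnable [1, 4, 32, 32] initial state.
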
    def _preprocess_data(self, data_container: BatchTimeSeriesContainer):
        """Extract data shapes and optionally perturb constant channels."""
        history_values = data_container.history_values
        future_values = data_container.future_values
        history_mask = data_container.history_mask

        batch_size, history_length, num_channels = history_values.shape
        future_length = future_values.shape[1] if future_values is not None else 0

        if self.handle_constants:
            history_values = apply_channel_noise(history_values)

        return {
            "history_values": history_values,
            "future_values": future_values,
            "history_mask": history_mask,
            "num_channels": num_channels,
            "history_length": history_length,
            "future_length": future_length,
            "batch_size": batch_size,
        }

    def _compute_scaling(self, history_values: torch.Tensor, history_mask: torch.Tensor | None = None):
        """Compute scaling statistics from the history window."""
        scale_statistics = self.scaler.compute_statistics(history_values, history_mask)
        return scale_statistics

    def _apply_scaling_and_masking(self, values: torch.Tensor, scale_statistics: dict, mask: torch.Tensor | None = None):
        """Apply scaling and optional masking to values."""
        scaled_values = self.scaler.scale(values, scale_statistics)

        # Zero out padded timesteps so they do not leak into the encoder.
        if mask is not None:
            scaled_values = scaled_values * mask.unsqueeze(-1).float()

        if self.scaler_clamp_value is not None:
            scaled_values = torch.clamp(scaled_values, -self.scaler_clamp_value, self.scaler_clamp_value)

        return scaled_values

    def _get_positional_embeddings(
        self,
        time_features: torch.Tensor,
        num_channels: int,
        batch_size: int,
        drop_enc_allow: bool = False,
    ):
        """Generate positional embeddings from time features."""
        seq_len = time_features.shape[1]

        # Encoding dropout: with probability `encoding_dropout`, replace the
        # positional embeddings of the whole batch with zeros.
        if (torch.rand(1).item() < self.encoding_dropout) and drop_enc_allow:
            return torch.zeros(batch_size, seq_len, num_channels, self.embed_size, device=device, dtype=torch.float32)

        # [B, S, E] -> [B, S, C, E]: the same time features are shared by all channels.
        pos_embed = self.time_feature_projection(time_features)
        return pos_embed.unsqueeze(2).expand(-1, -1, num_channels, -1)

    def _compute_embeddings(
        self,
        scaled_history: torch.Tensor,
        history_pos_embed: torch.Tensor,
        history_mask: torch.Tensor | None = None,
    ):
        """Compute value embeddings and combine with positional embeddings."""
        # Replace NaNs with zeros before the linear projection, then overwrite
        # the affected positions with the learned NaN embedding.
        nan_mask = torch.isnan(scaled_history)
        history_for_embedding = torch.nan_to_num(scaled_history, nan=0.0)
        channel_embeddings = self.expand_values(history_for_embedding.unsqueeze(-1))
        channel_embeddings[nan_mask] = self.nan_embedding.to(channel_embeddings.dtype)
        channel_embeddings = channel_embeddings + history_pos_embed

        # Zero out padded timesteps across all channels and embedding dims.
        if history_mask is not None:
            mask_broadcast = history_mask.unsqueeze(-1).unsqueeze(-1).to(channel_embeddings.dtype)
            channel_embeddings = channel_embeddings * mask_broadcast

        # Flatten channels: [B, S, C, E] -> [B, S, C * E].
        batch_size, seq_len = scaled_history.shape[:2]
        all_channels_embedded = channel_embeddings.view(batch_size, seq_len, -1)

        return all_channels_embedded

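    # NaN-handling sketch: if scaled_history[0, 5, 2] is NaN, then
    # channel_embeddings[0, 5, 2, :] is replaced by the learned nan_embedding
    # vector rather than the projection of 0.0, so a missing value gets a
    # distinct representation instead of looking like a true zero.
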
    def _generate_predictions(
        self,
        embedded: torch.Tensor,
        target_pos_embed: torch.Tensor,
        prediction_length: int,
        num_channels: int,
        history_mask: torch.Tensor | None = None,
    ):
        """Generate predictions for all channels using vectorized operations.

        `history_mask` is accepted for interface symmetry but is not used here.
        """
        batch_size, seq_len, _ = embedded.shape
        # Unflatten channels: [B, S, C * E] -> [B, S, C, E].
        embedded = embedded.view(batch_size, seq_len, num_channels, self.embed_size)

        # Fold channels into the batch so every channel is encoded as its own
        # sequence: [B, S, C, E] -> [B * C, S, E].
        channel_embedded = (
            embedded.permute(0, 2, 1, 3).contiguous().view(batch_size * num_channels, seq_len, self.embed_size)
        )

        # Same folding for the target positional embeddings: [B * C, P, E].
        target_pos_embed = (
            target_pos_embed.permute(0, 2, 1, 3)
            .contiguous()
            .view(batch_size * num_channels, prediction_length, self.embed_size)
        )
        x = channel_embedded
        target_repr = target_pos_embed
        # Append the horizon positions so the encoder processes history and
        # forecast horizon in one pass: [B * C, S + P, E].
        x = torch.cat([x, target_repr], dim=1)
        if self.encoder_config.get("weaving", True):
            # Weaving: thread the recurrent state through the layer stack,
            # adding each layer's learned initial state on top.
            hidden_state = torch.zeros_like(self.initial_hidden_state[0].repeat(batch_size * num_channels, 1, 1, 1))
            for layer_idx, encoder_layer in enumerate(self.encoder_layers):
                x, hidden_state = encoder_layer(
                    x,
                    hidden_state + self.initial_hidden_state[layer_idx].repeat(batch_size * num_channels, 1, 1, 1),
                )
        else:
            # No weaving: each layer starts from its own learned initial state.
            for layer_idx, encoder_layer in enumerate(self.encoder_layers):
                initial_hidden_state = self.initial_hidden_state[layer_idx].repeat(batch_size * num_channels, 1, 1, 1)
                x, _ = encoder_layer(x, initial_hidden_state)

        # Keep only the horizon positions.
        prediction_embeddings = x[:, -prediction_length:, :]

        predictions = self.final_output_layer(self.mlp(prediction_embeddings))

        # [B * C, P, output_dim] -> [B, P, C, output_dim].
        output_dim = len(self.quantiles) if self.loss_type == "quantile" else 1
        predictions = predictions.view(batch_size, num_channels, prediction_length, output_dim)
        predictions = predictions.permute(0, 2, 1, 3)
        # Point forecasts drop the trailing singleton dimension.
        if self.loss_type != "quantile":
            predictions = predictions.squeeze(-1)
        return predictions

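    # Shape walk-through for B=2, C=3, S=48, P=12, E=128:
    # embedded [2, 48, 384] -> channel_embedded [6, 48, 128] -> after cat with
    # the target embeddings [6, 60, 128] -> encoder stack -> last 12 steps
    # [6, 12, 128] -> head -> [2, 12, 3] (point) or [2, 12, 3, Q] (quantile).
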
    def forward(self, data_container: BatchTimeSeriesContainer, drop_enc_allow: bool = False):
        """Main forward pass."""
        preprocessed = self._preprocess_data(data_container)

        # Calendar/time features for the history window and the forecast horizon.
        history_time_features, target_time_features = compute_batch_time_features(
            start=data_container.start,
            history_length=preprocessed["history_length"],
            future_length=preprocessed["future_length"],
            batch_size=preprocessed["batch_size"],
            frequency=data_container.frequency,
            K_max=self.K_max,
            time_feature_config=self.time_feature_config,
        )

        # Scaling statistics are computed from the history only.
        scale_statistics = self._compute_scaling(preprocessed["history_values"], preprocessed["history_mask"])

        history_scaled = self._apply_scaling_and_masking(
            preprocessed["history_values"],
            scale_statistics,
            preprocessed["history_mask"],
        )

        # Scale the targets with the same statistics (used for loss computation).
        future_scaled = None
        if preprocessed["future_values"] is not None:
            future_scaled = self.scaler.scale(preprocessed["future_values"], scale_statistics)

        history_pos_embed = self._get_positional_embeddings(
            history_time_features,
            preprocessed["num_channels"],
            preprocessed["batch_size"],
            drop_enc_allow,
        )
        target_pos_embed = self._get_positional_embeddings(
            target_time_features,
            preprocessed["num_channels"],
            preprocessed["batch_size"],
            drop_enc_allow,
        )

        history_embed = self._compute_embeddings(history_scaled, history_pos_embed, preprocessed["history_mask"])

        predictions = self._generate_predictions(
            history_embed,
            target_pos_embed,
            preprocessed["future_length"],
            preprocessed["num_channels"],
            preprocessed["history_mask"],
        )

        return {
            "result": predictions,
            "scale_statistics": scale_statistics,
            "future_scaled": future_scaled,
            "history_length": preprocessed["history_length"],
            "future_length": preprocessed["future_length"],
        }

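    # Usage sketch (assumes a populated BatchTimeSeriesContainer from
    # src.data.containers, not constructed here):
    #
    #     output = model(batch)  # dict with "result", "scale_statistics", ...
    #     loss = model.compute_loss(batch.future_values, output)
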
    def _quantile_loss(self, y_true: torch.Tensor, y_pred: torch.Tensor):
        """
        Compute the quantile (pinball) loss.

        y_true: [B, P, N]
        y_pred: [B, P, N, Q]
        """
        # Align targets with the quantile dimension: [B, P, N] -> [B, P, N, 1].
        y_true = y_true.unsqueeze(-1)

        errors = y_true - y_pred

        # Pinball loss: max(q * e, (q - 1) * e) for error e = y_true - y_pred.
        loss = torch.max((self.qt - 1) * errors, self.qt * errors)

        return loss.mean()

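    # Worked example at q = 0.9: an under-prediction with error e = +1.0 costs
    # max(0.9 * 1.0, -0.1 * 1.0) = 0.9, while an over-prediction with e = -1.0
    # costs max(-0.9, 0.1) = 0.1, so high quantiles penalize under-prediction
    # more heavily, as intended.
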
    def compute_loss(self, y_true: torch.Tensor, y_pred: dict):
        """Compute loss between predictions and scaled ground truth."""
        predictions = y_pred["result"]
        scale_statistics = y_pred["scale_statistics"]

        if y_true is None:
            return torch.tensor(0.0, device=predictions.device)

        # Compare in scaled space, using the statistics from the forward pass.
        future_scaled = self.scaler.scale(y_true, scale_statistics)

        if self.loss_type == "huber":
            if predictions.shape != future_scaled.shape:
                raise ValueError(
                    f"Shape mismatch for Huber loss: predictions {predictions.shape} "
                    f"vs future_scaled {future_scaled.shape}"
                )
            return nn.functional.huber_loss(predictions, future_scaled)
        elif self.loss_type == "quantile":
            return self._quantile_loss(future_scaled, predictions)
        else:
            raise ValueError(f"Unknown loss type: {self.loss_type}")
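

if __name__ == "__main__":
    # Minimal smoke test for the standalone helper above. Building the full
    # TimeSeriesModel needs a valid encoder_config for GatedDeltaProductEncoder
    # (defined in src.models.blocks), so it is not exercised here; running this
    # file also assumes the `src` package and `fla` are importable.
    x = torch.randn(2, 16, 3)
    x[:, :, 0] = 1.0  # make channel 0 constant
    noisy = apply_channel_noise(x)
    # Only the constant channel should have been perturbed.
    assert not torch.allclose(noisy[:, :, 0], x[:, :, 0])
    assert torch.allclose(noisy[:, :, 1:], x[:, :, 1:])
    print("apply_channel_noise smoke test passed")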