from transformers import PreTrainedModel
from configuration_LightGTS import LightGTSConfig
from ts_generation_mixin import TSGenerationMixin
import torch
from torch import nn
from torch import Tensor
from typing import Callable, Optional
import math
import torch.nn.functional as F
import numpy as np


class LightGTSPreTrainedModel(PreTrainedModel):
    """HuggingFace integration shell shared by the LightGTS task heads:
    binds the config class and declares sharding/attention capability flags."""

    config_class = LightGTSConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["TSTEncoderLayer"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_sdpa = False
    _supports_cache_class = True

    def _init_weights(self, module):
        """Initialize Linear/Embedding weights from N(0, initializer_range);
        zero biases and the padding row of embeddings."""
        std = self.config.initializer_range
        if isinstance(module, torch.nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, torch.nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()


class LightGTSForPrediction(LightGTSPreTrainedModel, TSGenerationMixin):
    """Forecasting wrapper: patchifies a raw multivariate series and runs the
    LightGTS backbone.

    forward() input:  ``input`` of shape [batch, seq_len, n_vars]
    forward() output: ``{"prediction": [batch, target_dim, n_vars], "loss": Tensor | None}``
    """

    def __init__(self, config: LightGTSConfig):
        super().__init__(config)
        self.config = config
        # FIX: forward() referenced self.loss_fn, but it was never defined, so
        # any call that supplied `labels` raised AttributeError. MSE is the
        # conventional regression objective for forecasting heads.
        self.loss_fn = nn.MSELoss(reduction="mean")
        self.model = LightGTS(c_in=config.c_in,
                              target_dim=config.target_dim,
                              patch_len=config.patch_len,
                              stride=config.stride,
                              num_patch=config.num_patch,
                              e_layers=config.e_layers,
                              d_layers=config.d_layers,
                              n_heads=config.n_heads,
                              d_model=config.d_model,
                              shared_embedding=True,
                              d_ff=config.d_ff,
                              dropout=config.dropout,
                              attn_dropout=config.attn_dropout,
                              head_dropout=config.head_dropout,
                              act='relu',
                              head_type=config.head_type,
                              res_attention=False,
                              learn_pe=False)

    def forward(self, input, labels=None, patch_len=None, stride=None, target_dim=None):
        """Patchify `input`, run the backbone, and optionally compute a loss.

        Note: the parameter name `input` shadows the builtin but is kept for
        caller compatibility.
        """
        # FIX: only override the config when the caller actually supplies a
        # value. The previous unconditional assignment clobbered the config
        # with None when an argument was omitted (then the patch arithmetic
        # below raised TypeError). This also matches LightGTSForFinetune.
        if patch_len is not None:
            self.config.patch_len = patch_len
        if stride is not None:
            self.config.stride = stride
        if target_dim is not None:
            self.config.target_dim = target_dim

        # Split the series into patches.
        batch_size, seq_len, n_vars = input.shape
        num_patch = (max(seq_len, self.config.patch_len) - self.config.patch_len) // self.config.stride + 1
        self.config.num_patch = num_patch
        # NOTE(review): view-based patching assumes stride == patch_len and
        # seq_len == num_patch * patch_len — confirm against the data pipeline.
        outputs = input.view(batch_size, num_patch, self.config.patch_len, n_vars)
        outputs = outputs.transpose(2, 3)  # [bs x num_patch x n_vars x patch_len]

        outputs = self.model(outputs,
                             target_dim=self.config.target_dim,
                             patch_len=self.config.patch_len,
                             stride=self.config.stride)

        loss = None
        if labels is not None:
            if outputs.shape != labels.shape:
                outputs = outputs.view(labels.shape)
            loss = self.loss_fn(outputs, labels)
        return {"prediction": outputs, "loss": loss}
class LightGTSForFinetune(LightGTSPreTrainedModel, TSGenerationMixin):
    """Fine-tuning wrapper around the LightGTS backbone.

    Identical patchify-and-forward flow to LightGTSForPrediction, but config
    overrides were already guarded against None here.

    forward() input:  ``input`` of shape [batch, seq_len, n_vars]
    forward() output: ``{"prediction": ..., "loss": Tensor | None}``
    """

    def __init__(self, config: LightGTSConfig):
        super().__init__(config)
        self.config = config
        # FIX: forward() referenced self.loss_fn, but it was never defined, so
        # any call that supplied `labels` raised AttributeError.
        self.loss_fn = nn.MSELoss(reduction="mean")
        self.model = LightGTS(c_in=config.c_in,
                              target_dim=config.target_dim,
                              patch_len=config.patch_len,
                              stride=config.stride,
                              num_patch=config.num_patch,
                              e_layers=config.e_layers,
                              d_layers=config.d_layers,
                              n_heads=config.n_heads,
                              d_model=config.d_model,
                              shared_embedding=True,
                              d_ff=config.d_ff,
                              dropout=config.dropout,
                              attn_dropout=config.attn_dropout,
                              head_dropout=config.head_dropout,
                              act='relu',
                              head_type=config.head_type,
                              res_attention=False,
                              learn_pe=False)

    def forward(self, input, labels=None, patch_len=None, stride=None, target_dim=None):
        """Patchify `input`, run the backbone, and optionally compute a loss."""
        if patch_len is not None:
            self.config.patch_len = patch_len
        if stride is not None:
            self.config.stride = stride
        if target_dim is not None:
            self.config.target_dim = target_dim

        # Split the series into patches.
        batch_size, seq_len, n_vars = input.shape
        num_patch = (max(seq_len, self.config.patch_len) - self.config.patch_len) // self.config.stride + 1
        self.config.num_patch = num_patch
        # NOTE(review): view-based patching assumes stride == patch_len and
        # seq_len == num_patch * patch_len — confirm against the data pipeline.
        outputs = input.view(batch_size, num_patch, self.config.patch_len, n_vars)
        outputs = outputs.transpose(2, 3)  # [bs x num_patch x n_vars x patch_len]

        outputs = self.model(outputs,
                             target_dim=self.config.target_dim,
                             patch_len=self.config.patch_len,
                             stride=self.config.stride)

        loss = None
        if labels is not None:
            if outputs.shape != labels.shape:
                outputs = outputs.view(labels.shape)
            loss = self.loss_fn(outputs, labels)
        return {"prediction": outputs, "loss": loss}


class LightGTS(nn.Module):
    """Patch-based encoder/decoder backbone.

    Output dimension:
        [bs x target_dim x nvars]              for prediction
        [bs x target_dim]                      for regression
        [bs x target_dim]                      for classification
        [bs x num_patch x n_vars x patch_len]  for pretrain
    """

    def __init__(self, c_in: int, target_dim: int, patch_len: int, stride: int, num_patch: int,
                 mask_mode: str = 'patch', mask_nums: int = 3,
                 e_layers: int = 3, d_layers: int = 3, d_model=128, n_heads=16,
                 shared_embedding=True, d_ff: int = 256, norm: str = 'BatchNorm',
                 attn_dropout: float = 0.4, dropout: float = 0., act: str = "gelu",
                 res_attention: bool = True, pre_norm: bool = False, store_attn: bool = False,
                 pe: str = 'sincos', learn_pe: bool = False, head_dropout=0,
                 head_type="prediction", individual=False,
                 y_range: Optional[tuple] = None, verbose: bool = False, **kwargs):
        super().__init__()
        assert head_type in ['pretrain', 'prediction', 'regression', 'classification'], \
            'head type should be either pretrain, prediction, or regression'

        # Basic bookkeeping.
        self.num_patch = num_patch
        self.target_dim = target_dim
        # Number of patches the decoder must emit to cover the horizon.
        self.out_patch_num = math.ceil(target_dim / patch_len)
        # Canonical patch length of the stored embedding weights; they are
        # resampled to the runtime patch length in forward().
        self.target_patch_len = 48

        # Embeddings: patch projection + learnable [CLS]-style token.
        self.embedding = nn.Linear(self.target_patch_len, d_model)
        self.cls_embedding = nn.Parameter(torch.randn(1, 1, 1, d_model), requires_grad=True)

        # Encoder over the patch sequence (per variable).
        self.encoder = TSTEncoder(d_model, n_heads,
                                  d_ff=d_ff, norm=norm,
                                  attn_dropout=attn_dropout, dropout=dropout,
                                  pre_norm=pre_norm, activation=act,
                                  res_attention=res_attention, n_layers=e_layers,
                                  store_attn=store_attn)

        # Decoder that cross-attends the encoded patches.
        self.decoder = Decoder(d_layers, patch_len=patch_len, d_model=d_model,
                               n_heads=n_heads, d_ff=d_ff,
                               attn_dropout=attn_dropout, dropout=dropout)

        # Head.
        self.n_vars = c_in
        self.head_type = head_type
        self.mask_mode = mask_mode
        self.mask_nums = mask_nums
        self.d_model = d_model
        self.patch_len = patch_len
        if head_type == "pretrain":
            self.head = PretrainHead(d_model, patch_len, head_dropout)
        elif head_type == "prediction":
            self.head = decoder_PredictHead(d_model, self.patch_len, self.target_patch_len, head_dropout)

    def get_dynamic_weights(self, n_preds, decay_rate=0.5):
        """Exponentially decaying weights for the replicated decoder tokens.

        Args:
            n_preds: number of predictions to generate weights for.
            decay_rate: base of the exponential decay; lower decays faster.

        Returns:
            1-D tensor ``decay_rate ** [0, 1, ..., n_preds-1]``.
        """
        return decay_rate ** torch.arange(n_preds)

    def decoder_predict(self, bs, n_vars, dec_cross):
        """Build decoder input from the last encoded patch and run the decoder.

        dec_cross: [bs x n_vars x num_patch x d_model]
        Returns:   [bs x n_vars x d_model x out_patch_num] (transposed for the head).
        """
        # Seed every future position with the last encoded token, scaled by a
        # decaying weight so later positions start closer to zero.
        dec_in = dec_cross[:, :, -1, :].unsqueeze(2).expand(-1, -1, self.out_patch_num, -1)
        weights = self.get_dynamic_weights(self.out_patch_num).to(dec_in.device)
        dec_in = dec_in * weights.unsqueeze(0).unsqueeze(0).unsqueeze(-1)
        decoder_output = self.decoder(dec_in, dec_cross)
        decoder_output = decoder_output.transpose(2, 3)
        return decoder_output

    def forward(self, z, target_dim=None, patch_len=None, stride=None):
        """
        z: [bs x num_patch x n_vars x patch_len]
        Returns [bs x target_dim x n_vars] for the prediction head.
        """
        if target_dim is not None:
            self.target_dim = target_dim
        if patch_len is not None:
            self.patch_len = patch_len
        if stride is not None:
            self.stride = stride
        self.out_patch_num = math.ceil(self.target_dim / self.patch_len)

        bs, num_patch, n_vars, patch_len = z.shape

        # Tokenize: resample the canonical embedding weights to the runtime
        # patch length, then project patches into d_model.
        # NOTE(review): the fresh Linear uses z's last dim while the weights are
        # resampled to self.patch_len — these must agree; confirm callers.
        cls_tokens = self.cls_embedding.expand(bs, n_vars, -1, -1)
        embedding = nn.Linear(patch_len, self.d_model, bias=False)
        embedding.weight.data = resample_patchemb(old=self.embedding.weight.data,
                                                  new_patch_len=self.patch_len)
        z = embedding(z).permute(0, 2, 1, 3)   # [bs x n_vars x num_patch x d_model]
        z = torch.cat((cls_tokens, z), dim=2)  # [bs x n_vars x (1+num_patch) x d_model]

        # Encoder runs over the flattened (bs * n_vars) batch.
        z = torch.reshape(z, (-1, 1 + num_patch, self.d_model))
        z = self.encoder(z)
        z = torch.reshape(z, (-1, n_vars, 1 + num_patch, self.d_model))

        # Decoder, prediction head, then trim to the requested horizon.
        z = self.decoder_predict(bs, n_vars, z)
        z = self.head(z, self.patch_len)
        z = z[:, :self.target_dim, :]
        return z
class TSTEncoder(nn.Module):
    """Stack of TSTEncoderLayer blocks, optionally threading residual
    attention scores from layer to layer."""

    def __init__(self, d_model, n_heads, d_ff=None, norm='BatchNorm',
                 attn_dropout=0., dropout=0., activation='gelu',
                 res_attention=False, n_layers=1, pre_norm=False, store_attn=False):
        super().__init__()
        self.layers = nn.ModuleList(
            TSTEncoderLayer(d_model, n_heads=n_heads, d_ff=d_ff, norm=norm,
                            attn_dropout=attn_dropout, dropout=dropout,
                            activation=activation, res_attention=res_attention,
                            pre_norm=pre_norm, store_attn=store_attn)
            for _ in range(n_layers)
        )
        self.res_attention = res_attention

    def forward(self, src: Tensor):
        """src: [bs x q_len x d_model] -> same shape."""
        out = src
        if self.res_attention:
            # Pass pre-softmax scores forward as the residual-attention prior.
            scores = None
            for layer in self.layers:
                out, scores = layer(out, prev=scores)
        else:
            for layer in self.layers:
                out = layer(out)
        return out


class TSTEncoderLayer(nn.Module):
    """Transformer encoder layer: multi-head self-attention + position-wise
    feed-forward, each with residual connection and (pre- or post-) norm."""

    def __init__(self, d_model, n_heads, d_ff=256, store_attn=False,
                 norm='LayerNorm', attn_dropout=0, dropout=0., bias=True,
                 activation="gelu", res_attention=False, pre_norm=False):
        super().__init__()
        assert not d_model % n_heads, \
            f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
        d_k = d_model // n_heads
        d_v = d_model // n_heads

        # Multi-head attention sublayer.
        self.res_attention = res_attention
        self.self_attn = MultiheadAttention(d_model, n_heads, d_k, d_v,
                                            attn_dropout=attn_dropout,
                                            proj_dropout=dropout,
                                            res_attention=res_attention)
        self.dropout_attn = nn.Dropout(dropout)

        def _make_norm():
            # BatchNorm operates over the feature dim, hence the transposes.
            if "batch" in norm.lower():
                return nn.Sequential(Transpose(1, 2), nn.BatchNorm1d(d_model), Transpose(1, 2))
            return nn.LayerNorm(d_model)

        self.norm_attn = _make_norm()

        # Position-wise feed-forward sublayer.
        self.ff = nn.Sequential(nn.Linear(d_model, d_ff, bias=bias),
                                get_activation_fn(activation),
                                nn.Dropout(dropout),
                                nn.Linear(d_ff, d_model, bias=bias))
        self.dropout_ffn = nn.Dropout(dropout)
        self.norm_ffn = _make_norm()

        self.pre_norm = pre_norm
        self.store_attn = store_attn

    def forward(self, src: Tensor, prev: Optional[Tensor] = None):
        """src: [bs x q_len x d_model] -> same shape (plus scores when
        res_attention is enabled)."""
        # --- self-attention sublayer ---
        if self.pre_norm:
            src = self.norm_attn(src)
        if self.res_attention:
            src2, attn, scores = self.self_attn(src, src, src, prev)
        else:
            src2, attn = self.self_attn(src, src, src)
        if self.store_attn:
            self.attn = attn
        src = src + self.dropout_attn(src2)  # residual + dropout
        if not self.pre_norm:
            src = self.norm_attn(src)

        # --- feed-forward sublayer ---
        if self.pre_norm:
            src = self.norm_ffn(src)
        src2 = self.ff(src)
        src = src + self.dropout_ffn(src2)   # residual + dropout
        if not self.pre_norm:
            src = self.norm_ffn(src)

        if self.res_attention:
            return src, scores
        return src
class Decoder(nn.Module):
    """Stack of DecoderLayer blocks sharing one cross-attention memory."""

    def __init__(self, d_layers, patch_len, d_model, n_heads, d_ff=None,
                 attn_dropout=0.2, dropout=0.1):
        super(Decoder, self).__init__()
        self.decoder_layers = nn.ModuleList()
        for _ in range(d_layers):
            self.decoder_layers.append(
                DecoderLayer(patch_len, d_model, n_heads, d_ff, attn_dropout, dropout))

    def forward(self, x, cross):
        """Run every layer in sequence; `cross` is the shared encoder memory."""
        out = x
        for layer in self.decoder_layers:
            out = layer(out, cross)
        return out


class DecoderLayer(nn.Module):
    """Decoder block: causal self-attention, cross-attention over the encoder
    output, then a 1x1-conv MLP — each followed by norm and residual."""

    def __init__(self, patch_len, d_model, n_heads, d_ff=None,
                 attn_dropout=0.2, dropout=0.5, norm="BatchNorm"):
        super(DecoderLayer, self).__init__()
        self.self_attention = MultiheadAttention(d_model, n_heads,
                                                 res_attention=False,
                                                 attn_dropout=attn_dropout)
        # rope_type=True selects the decoder RoPE variant that offsets query
        # positions past the memory length.
        self.cross_attention = MultiheadAttention(d_model, n_heads,
                                                  attn_dropout=attn_dropout,
                                                  rope_type=True)

        def _make_norm():
            if 'batch' in norm.lower():
                return nn.Sequential(Transpose(1, 2), nn.BatchNorm1d(d_model), Transpose(1, 2))
            return nn.LayerNorm(d_model)

        self.norm1 = _make_norm()
        self.norm2 = _make_norm()
        self.norm3 = _make_norm()
        self.dropout = nn.Dropout(dropout)
        self.MLP1 = CMlp(in_features=d_model, hidden_features=d_ff,
                         out_features=d_model, drop=dropout)

    def forward(self, x, cross):
        """x:     [batch x n_vars x num_patch x d_model]
        cross:    same leading dims, arbitrary memory length
        returns:  same shape as x."""
        batch, n_vars, num_patch, d_model = x.shape
        # Fold variables into the batch so attention runs per variable.
        x = x.reshape(batch * n_vars, num_patch, d_model)
        cross = cross.reshape(batch * n_vars, -1, d_model)

        mask = causal_attention_mask(num_patch).to(x.device)
        x_attn, _ = self.self_attention(x, attn_mask=mask)
        x_attn = self.norm1(x_attn) + x

        x_cross, _ = self.cross_attention(x_attn, cross, cross)
        x_cross = self.dropout(self.norm2(x_cross)) + x_attn

        x_ff = self.MLP1(x_cross)
        x_ff = self.norm3(x_ff) + x_cross
        return x_ff.reshape(batch, n_vars, num_patch, d_model)
def causal_attention_mask(seq_length):
    """Build an additive causal attention mask.

    FIX(doc): the previous docstring claimed a 1/0 visibility mask; the code
    actually returns an *additive* float mask: 0 on and below the diagonal
    (position j <= i is visible) and -inf strictly above it, to be added to
    pre-softmax attention scores.

    Args:
        seq_length: sequence length.

    Returns:
        torch.Tensor of shape (seq_length, seq_length) with 0 / -inf entries.
    """
    mask = torch.triu(torch.ones(seq_length, seq_length) * float('-inf'), diagonal=1)
    return mask


class CMlp(nn.Module):
    """Position-wise MLP implemented with 1x1 Conv1d over the feature dim."""

    def __init__(self, in_features, hidden_features=None, out_features=None,
                 act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Conv1d(in_features, hidden_features, 1)
        self.act = act_layer()
        self.fc2 = nn.Conv1d(hidden_features, out_features, 1)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """x: [bs x seq x features] -> same shape (features last)."""
        x = x.permute(0, 2, 1)  # Conv1d expects channels second
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        x = x.permute(0, 2, 1)
        return x


class Transpose(nn.Module):
    """nn.Module wrapper around Tensor.transpose so it can sit in Sequential."""

    def __init__(self, *dims, contiguous=False):
        super().__init__()
        self.dims, self.contiguous = dims, contiguous

    def forward(self, x):
        if self.contiguous:
            return x.transpose(*self.dims).contiguous()
        else:
            return x.transpose(*self.dims)


class MultiheadAttention(nn.Module):
    def __init__(self, d_model, n_heads, d_k=None, d_v=None, res_attention=False,
                 attn_dropout=0., proj_dropout=0., qkv_bias=True, lsa=False, rope_type=False):
        """Multi Head Attention Layer.

        Input shape:
            Q:    [batch_size (bs) x max_q_len x d_model]
            K, V: [batch_size (bs) x q_len x d_model]
            mask: [q_len x q_len]
        """
        super().__init__()
        d_k = d_model // n_heads if d_k is None else d_k
        d_v = d_model // n_heads if d_v is None else d_v
        self.n_heads, self.d_k, self.d_v = n_heads, d_k, d_v

        self.W_Q = nn.Linear(d_model, d_k * n_heads, bias=qkv_bias)
        self.W_K = nn.Linear(d_model, d_k * n_heads, bias=qkv_bias)
        self.W_V = nn.Linear(d_model, d_v * n_heads, bias=qkv_bias)

        # Scaled dot-product attention (all heads at once).
        self.res_attention = res_attention
        self.sdp_attn = ScaledDotProductAttention(d_model, n_heads,
                                                  attn_dropout=attn_dropout,
                                                  res_attention=self.res_attention,
                                                  lsa=lsa, rope_type=rope_type)

        # Output projection.
        self.to_out = nn.Sequential(nn.Linear(n_heads * d_v, d_model),
                                    nn.Dropout(proj_dropout))

    def forward(self, Q: Tensor, K: Optional[Tensor] = None, V: Optional[Tensor] = None,
                prev: Optional[Tensor] = None,
                key_padding_mask: Optional[Tensor] = None,
                attn_mask: Optional[Tensor] = None):
        """Self-attention when K/V are omitted (both default to Q)."""
        bs = Q.size(0)
        if K is None:
            K = Q
        if V is None:
            V = Q

        # Project and split into heads.
        q_s = self.W_Q(Q).view(bs, -1, self.n_heads, self.d_k).transpose(1, 2)   # [bs x n_heads x max_q_len x d_k]
        k_s = self.W_K(K).view(bs, -1, self.n_heads, self.d_k).permute(0, 2, 3, 1)  # [bs x n_heads x d_k x q_len]
        v_s = self.W_V(V).view(bs, -1, self.n_heads, self.d_v).transpose(1, 2)   # [bs x n_heads x q_len x d_v]

        if self.res_attention:
            output, attn_weights, attn_scores = self.sdp_attn(
                q_s, k_s, v_s, prev=prev,
                key_padding_mask=key_padding_mask, attn_mask=attn_mask)
        else:
            output, attn_weights = self.sdp_attn(
                q_s, k_s, v_s,
                key_padding_mask=key_padding_mask, attn_mask=attn_mask)

        # Merge heads back and project.
        output = output.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * self.d_v)
        output = self.to_out(output)

        if self.res_attention:
            return output, attn_weights, attn_scores
        else:
            return output, attn_weights
class ScaledDotProductAttention(nn.Module):
    r"""Scaled Dot-Product Attention module (Attention is all you need by
    Vaswani et al., 2017) with optional residual attention from previous layer
    (Realformer: Transformer likes residual attention by He et al, 2020) and
    locality self attention (Vision Transformer for Small-Size Datasets by
    Lee et al, 2021)."""

    def __init__(self, d_model, n_heads, attn_dropout=0., res_attention=False,
                 lsa=False, rope_type=False):
        super().__init__()
        self.attn_dropout = nn.Dropout(attn_dropout)
        self.res_attention = res_attention
        head_dim = d_model // n_heads
        # Learnable scale only under locality self-attention (lsa).
        self.scale = nn.Parameter(torch.tensor(head_dim ** -0.5), requires_grad=lsa)
        self.lsa = lsa
        self.rope_type = rope_type

    def forward(self, q: Tensor, k: Tensor, v: Tensor,
                prev: Optional[Tensor] = None,
                key_padding_mask: Optional[Tensor] = None,
                attn_mask: Optional[Tensor] = None):
        '''
        Input shape:
            q                : [bs x n_heads x max_q_len x d_k]
            k                : [bs x n_heads x d_k x seq_len]
            v                : [bs x n_heads x seq_len x d_v]
            prev             : [bs x n_heads x q_len x seq_len]
            key_padding_mask : [bs x seq_len]
            attn_mask        : [1 x seq_len x seq_len]
        Output shape:
            output : [bs x n_heads x q_len x d_v]
            attn   : [bs x n_heads x q_len x seq_len]
            scores : [bs x n_heads x q_len x seq_len]
        '''
        # Apply rotary position embeddings; k arrives transposed, so undo and
        # redo the permute around the RoPE call.
        if self.rope_type:
            q, k = RoPE_decoder(q, k.permute(0, 1, 3, 2))
        else:
            q, k = RoPE(q, k.permute(0, 1, 3, 2))
        k = k.permute(0, 1, 3, 2)

        # Similarity of every query/key pair, scaled.
        scores = torch.matmul(q, k) * self.scale

        # Residual attention: add the previous layer's pre-softmax scores.
        if prev is not None:
            scores = scores + prev

        # Additive or boolean attention mask (q_len == seq_len assumed).
        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                scores.masked_fill_(attn_mask, -np.inf)
            else:
                scores += attn_mask

        # Mask out padded key positions.
        if key_padding_mask is not None:
            scores.masked_fill_(key_padding_mask.unsqueeze(1).unsqueeze(2), -np.inf)

        weights = self.attn_dropout(F.softmax(scores, dim=-1))
        output = torch.matmul(weights, v)

        if self.res_attention:
            return output, weights, scores
        return output, weights
def RoPE(q, k):
    """Apply rotary position embeddings to q and k.

    q, k: (bs, head, max_len, output_dim); both use the same positions.
    """
    bs, n_heads, max_len, dim = q.shape[0], q.shape[1], q.shape[2], q.shape[-1]
    pos_emb = sinusoidal_position_embedding(bs, n_heads, max_len, dim, q.device, factor=1)

    # Per the RoPE formula, adjacent cos/sin pairs are identical, so duplicate
    # each entry: (1,2,3) -> (1,1,2,2,3,3).
    cos_pos = pos_emb[..., 1::2].repeat_interleave(2, dim=-1)  # odd columns = cos
    sin_pos = pos_emb[..., ::2].repeat_interleave(2, dim=-1)   # even columns = sin

    # Rotate: interleave (-x2, x1) pairs, then combine with cos/sin.
    q_rot = torch.stack([-q[..., 1::2], q[..., ::2]], dim=-1).reshape(q.shape)
    q = q * cos_pos + q_rot * sin_pos

    k_rot = torch.stack([-k[..., 1::2], k[..., ::2]], dim=-1).reshape(k.shape)
    k = k * cos_pos + k_rot * sin_pos
    return q, k


def RoPE_decoder(q, k):
    """Rotary embeddings for cross-attention: queries are placed *after* the
    keys on the position axis, so one table of length q_len + k_len is built
    and split between them.

    q: (bs, head, q_len, output_dim); k: (bs, head, k_len, output_dim).
    """
    bs, n_heads = q.shape[0], q.shape[1]
    q_len, k_len, dim = q.shape[2], k.shape[2], q.shape[-1]
    pos_emb = sinusoidal_position_embedding(bs, n_heads, k_len + q_len, dim, q.device, factor=1)

    # Duplicate adjacent cos/sin entries as in RoPE.
    cos_pos = pos_emb[..., 1::2].repeat_interleave(2, dim=-1)
    sin_pos = pos_emb[..., ::2].repeat_interleave(2, dim=-1)

    q_rot = torch.stack([-q[..., 1::2], q[..., ::2]], dim=-1).reshape(q.shape)
    q = q * cos_pos[:, :, -q_len:, :] + q_rot * sin_pos[:, :, -q_len:, :]

    k_rot = torch.stack([-k[..., 1::2], k[..., ::2]], dim=-1).reshape(k.shape)
    k = k * cos_pos[:, :, :k_len, :] + k_rot * sin_pos[:, :, :k_len, :]
    return q, k


def sinusoidal_position_embedding(batch_size, nums_head, max_len, output_dim, device, factor=1.0):
    """Sin/cos position table of shape (bs, head, max_len, output_dim);
    even feature indices hold sin, odd hold cos. factor > 1 generates a finer
    grid and subsamples back to max_len positions."""
    position = torch.arange(0, max_len * factor, 1 / factor, dtype=torch.float).unsqueeze(-1)
    ids = torch.arange(0, output_dim // 2, dtype=torch.float)  # i in [0, d/2)
    theta = torch.pow(10000, -2 * ids / output_dim)

    angles = position * theta                                   # (len*factor, d/2)
    table = torch.stack([torch.sin(angles), torch.cos(angles)], dim=-1)
    table = table.repeat((batch_size, nums_head, *([1] * len(table.shape))))
    table = torch.reshape(table, (batch_size, nums_head, -1, output_dim))
    table = table.to(device)

    # With factor > 1, pick max_len interpolated positions out of the finer grid.
    if factor > 1.0:
        idx = torch.linspace(0, table.shape[2] - 1, max_len).long()
        table = table[:, :, idx, :]
    return table


class PretrainHead(nn.Module):
    """Linear reconstruction head used during pretraining."""

    def __init__(self, d_model, patch_len, dropout):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.linear = nn.Linear(d_model, patch_len)

    def forward(self, x):
        """
        x:      [bs x nvars x d_model x num_patch]
        output: [bs x num_patch x nvars x patch_len]
        """
        x = x.transpose(2, 3)                 # [bs x nvars x num_patch x d_model]
        x = self.linear(self.dropout(x))      # [bs x nvars x num_patch x patch_len]
        return x.permute(0, 2, 1, 3)          # [bs x num_patch x nvars x patch_len]


class decoder_PredictHead(nn.Module):
    """Prediction head whose output projection is resampled at forward time to
    the requested patch length (weights are stored at target_patch_len)."""

    def __init__(self, d_model, patch_len, target_patch_len, dropout):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.linear = nn.Linear(d_model, target_patch_len)
        self.d_model = d_model

    def forward(self, x, patch_len):
        """
        x:      [bs x nvars x d_model x num_patch]
        output: [bs x (num_patch*patch_len) x nvars]
        """
        # Resample the stored projection to the runtime patch length.
        proj = nn.Linear(self.d_model, patch_len, bias=False)
        proj.weight.data = resample_patchemb(old=self.linear.weight.data.T,
                                             new_patch_len=patch_len).T
        x = x.transpose(2, 3)                 # [bs x nvars x num_patch x d_model]
        x = proj(self.dropout(x))             # [bs x nvars x num_patch x patch_len]
        x = x.permute(0, 2, 3, 1)             # [bs x num_patch x patch_len x nvars]
        return x.reshape(x.shape[0], -1, x.shape[3])
def resample_patchemb(old: torch.Tensor, new_patch_len: int):
    """Resample patch-embedding weights to a new patch length via linear
    interpolation and a pseudo-inverse projection, scaled by sqrt(ratio).

    old: (d_model, patch_size) weight matrix.
    Returns a (d_model, new_patch_len) matrix; `old` itself when the length
    already matches.
    """
    assert old.dim() == 2, "输入张量应为2D (d_model, patch_size)"
    if old.size(1) == new_patch_len:
        return old

    old = old.T
    old_len = old.size(0)
    factor = new_patch_len / old_len

    # Helper: 1-D linear resize along the last dim.
    def resize(x_tensor, new_shape):
        return F.interpolate(x_tensor.unsqueeze(0), size=new_shape, mode='linear').squeeze(0)

    # Resize matrix built from the identity basis, then its pseudo-inverse
    # maps old-length kernels onto the new length.
    basis = torch.eye(old_len, dtype=torch.float32, device=old.device)
    resize_mat = resize(basis, new_patch_len).T
    resize_mat_pinv = torch.linalg.pinv(resize_mat.T)

    resampled = resize_mat_pinv @ old * math.sqrt(factor)
    return resampled.T


def get_activation_fn(activation):
    """Map an activation spec (callable class or 'relu'/'gelu' string, case
    insensitive) to an activation module instance; raise ValueError otherwise."""
    if callable(activation):
        return activation()
    if activation.lower() == "relu":
        return nn.ReLU()
    if activation.lower() == "gelu":
        return nn.GELU()
    raise ValueError(f'{activation} is not available. You can use "relu", "gelu", or a callable')