LightGTS / modeling_LightGTS.py
pchen182224's picture
Upload 9 files
c882c3e verified
from transformers import PreTrainedModel
from configuration_LightGTS import LightGTSConfig
from ts_generation_mixin import TSGenerationMixin
import torch
from torch import nn
from torch import Tensor
from typing import Callable, Optional
import math
import torch.nn.functional as F
import numpy as np
class LightGTSPreTrainedModel(PreTrainedModel):
    """Common base that plugs LightGTS models into ``PreTrainedModel``.

    The class attributes are capability flags / naming hints consumed by the
    ``transformers`` base class; ``_init_weights`` defines how ``nn.Linear``
    and ``nn.Embedding`` sub-modules are initialised.
    """

    config_class = LightGTSConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["TSTEncoderLayer"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_sdpa = False
    _supports_cache_class = True

    def _init_weights(self, module):
        """Initialise ``module`` in place.

        Linear and embedding weights are drawn from a normal distribution
        with mean 0 and std ``config.initializer_range``.  Linear biases are
        zeroed, as is the embedding row at ``padding_idx`` when one is set.
        Any other module type is left untouched.
        """
        init_std = self.config.initializer_range
        is_linear = isinstance(module, torch.nn.Linear)
        is_embedding = (not is_linear) and isinstance(module, torch.nn.Embedding)
        if not (is_linear or is_embedding):
            return

        module.weight.data.normal_(mean=0.0, std=init_std)
        if is_linear:
            if module.bias is not None:
                module.bias.data.zero_()
        elif module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
class LightGTSForPrediction(LightGTSPreTrainedModel, TSGenerationMixin):
    """LightGTS wrapper used for forecasting.

    Wraps a :class:`LightGTS` backbone built from ``config`` and takes care of
    cutting the raw series into patches before calling it.
    """

    def __init__(self, config: LightGTSConfig):
        super().__init__(config)
        self.config = config
        self.model = LightGTS(
            c_in=config.c_in,
            target_dim=config.target_dim,
            patch_len=config.patch_len,
            stride=config.stride,
            num_patch=config.num_patch,
            e_layers=config.e_layers,
            d_layers=config.d_layers,
            n_heads=config.n_heads,
            d_model=config.d_model,
            shared_embedding=True,
            d_ff=config.d_ff,
            dropout=config.dropout,
            attn_dropout=config.attn_dropout,
            head_dropout=config.head_dropout,
            act='relu',
            head_type=config.head_type,
            res_attention=False,
            learn_pe=False,
        )

    def forward(self, input, labels=None, patch_len=None, stride=None, target_dim=None):
        """Patch ``input`` and run the backbone.

        Args:
            input: tensor unpacked as ``[batch_size, seq_len, n_vars]``.
            labels: optional target tensor; when given, the prediction is
                viewed to ``labels.shape`` if needed and a loss is computed.
            patch_len, stride, target_dim: optional per-call overrides.  A
                value of ``None`` keeps whatever is currently stored in
                ``self.config``.

        Returns:
            dict with keys ``"prediction"`` and ``"loss"`` (``None`` when no
            labels are supplied).
        """
        # Only overwrite the stored configuration when an override is
        # actually supplied; writing the default ``None`` into the config
        # would break the patch arithmetic below.
        if patch_len is not None:
            self.config.patch_len = patch_len
        if stride is not None:
            self.config.stride = stride
        if target_dim is not None:
            self.config.target_dim = target_dim

        # Cut the series into patches.
        batch_size, seq_len, n_vars = input.shape
        num_patch = (max(seq_len, self.config.patch_len) - self.config.patch_len) // self.config.stride + 1
        self.config.num_patch = num_patch
        outputs = input.view(batch_size, num_patch, self.config.patch_len, n_vars)
        outputs = outputs.transpose(2, 3)  # [bs x num_patch x n_vars x patch_len]

        outputs = self.model(
            outputs,
            target_dim=self.config.target_dim,
            patch_len=self.config.patch_len,
            stride=self.config.stride,
        )

        loss = None
        if labels is not None:
            if outputs.shape != labels.shape:
                outputs = outputs.view(labels.shape)
            # NOTE(review): ``loss_fn`` is not defined in this class; it is
            # presumably provided by a base class or mixin - confirm.
            loss = self.loss_fn(outputs, labels)
        return {"prediction": outputs, "loss": loss}
class LightGTSForFinetune(LightGTSPreTrainedModel, TSGenerationMixin):
    """LightGTS wrapper used for fine-tuning.

    Builds a :class:`LightGTS` backbone from ``config`` and patches the input
    series before handing it to the backbone.
    """

    def __init__(self, config: LightGTSConfig):
        super().__init__(config)
        self.config = config
        backbone_kwargs = dict(
            c_in=config.c_in,
            target_dim=config.target_dim,
            patch_len=config.patch_len,
            stride=config.stride,
            num_patch=config.num_patch,
            e_layers=config.e_layers,
            d_layers=config.d_layers,
            n_heads=config.n_heads,
            d_model=config.d_model,
            shared_embedding=True,
            d_ff=config.d_ff,
            dropout=config.dropout,
            attn_dropout=config.attn_dropout,
            head_dropout=config.head_dropout,
            act='relu',
            head_type=config.head_type,
            res_attention=False,
            learn_pe=False,
        )
        self.model = LightGTS(**backbone_kwargs)

    def forward(self, input, labels=None, patch_len=None, stride=None, target_dim=None):
        """Patch ``input`` ([batch_size, seq_len, n_vars]) and run the backbone.

        ``patch_len``, ``stride`` and ``target_dim`` override the values kept
        in ``self.config`` only when they are not ``None``.  Returns a dict
        with ``"prediction"`` and ``"loss"`` (``None`` without ``labels``).
        """
        cfg = self.config
        overrides = (("patch_len", patch_len), ("stride", stride), ("target_dim", target_dim))
        for field, value in overrides:
            if value is not None:
                setattr(cfg, field, value)

        # Cut the series into patches.
        batch_size, seq_len, n_vars = input.shape
        num_patch = (max(seq_len, cfg.patch_len) - cfg.patch_len) // cfg.stride + 1
        cfg.num_patch = num_patch
        patches = input.view(batch_size, num_patch, cfg.patch_len, n_vars).transpose(2, 3)

        outputs = self.model(patches, target_dim=cfg.target_dim, patch_len=cfg.patch_len, stride=cfg.stride)

        if labels is None:
            return {"prediction": outputs, "loss": None}

        if outputs.shape != labels.shape:
            outputs = outputs.view(labels.shape)
        loss = self.loss_fn(outputs, labels)
        return {"prediction": outputs, "loss": loss}
class LightGTS(nn.Module):
    """
    Patch-based encoder/decoder backbone.

    Patches are embedded with a linear layer whose weights are resampled to
    the current patch length, a learnable CLS token is prepended, every
    variable is encoded independently, and a decoder produces the output
    patches that the head maps back to the time domain.

    Output dimension:
         [bs x target_dim x nvars] for prediction
         [bs x target_dim] for regression
         [bs x target_dim] for classification
         [bs x num_patch x n_vars x patch_len] for pretrain
    """

    def __init__(self, c_in:int, target_dim:int, patch_len:int, stride:int, num_patch:int, mask_mode:str = 'patch',mask_nums:int = 3,
                 e_layers:int=3, d_layers:int=3, d_model=128, n_heads=16, shared_embedding=True, d_ff:int=256,
                 norm:str='BatchNorm', attn_dropout:float=0.4, dropout:float=0., act:str="gelu",
                 res_attention:bool=True, pre_norm:bool=False, store_attn:bool=False,
                 pe:str='sincos', learn_pe:bool=False, head_dropout = 0,
                 head_type = "prediction", individual = False,
                 y_range:Optional[tuple]=None, verbose:bool=False, **kwargs):
        super().__init__()
        assert head_type in ['pretrain', 'prediction', 'regression', 'classification'], 'head type should be either pretrain, prediction, or regression'

        # Basic
        self.num_patch = num_patch
        self.target_dim = target_dim
        # Keep the constructor's stride so the attribute exists even before
        # ``forward`` is called with an explicit ``stride``.
        self.stride = stride
        self.out_patch_num = math.ceil(target_dim / patch_len)
        # Patch length the embedding / head weights are stored at; they are
        # resampled to the runtime patch length in ``forward``.
        self.target_patch_len = 48

        # Embedding
        self.embedding = nn.Linear(self.target_patch_len, d_model)
        self.cls_embedding = nn.Parameter(torch.randn(1, 1, 1, d_model), requires_grad=True)

        # Encoder
        self.encoder = TSTEncoder(d_model, n_heads, d_ff=d_ff, norm=norm, attn_dropout=attn_dropout, dropout=dropout,
                                  pre_norm=pre_norm, activation=act, res_attention=res_attention, n_layers=e_layers,
                                  store_attn=store_attn)

        # Decoder
        self.decoder = Decoder(d_layers, patch_len=patch_len, d_model=d_model, n_heads=n_heads, d_ff=d_ff,
                               attn_dropout=attn_dropout, dropout=dropout)

        # Head
        self.n_vars = c_in
        self.head_type = head_type
        self.mask_mode = mask_mode
        self.mask_nums = mask_nums
        self.d_model = d_model
        self.patch_len = patch_len

        # NOTE(review): no head is created for 'regression' / 'classification',
        # and ``forward`` calls the head with two arguments while
        # ``PretrainHead.forward`` accepts one - only the 'prediction' path
        # looks exercised here; confirm before using the other head types.
        if head_type == "pretrain":
            self.head = PretrainHead(d_model, patch_len, head_dropout)
        elif head_type == "prediction":
            self.head = decoder_PredictHead(d_model, self.patch_len, self.target_patch_len, head_dropout)

    def get_dynamic_weights(self, n_preds, decay_rate=0.5):
        """
        Generate dynamic weights for the replicated tokens using an exponential decay scheme.

        Args:
        - n_preds (int): Number of predictions to generate weights for.
        - decay_rate (float): The base of the exponential decay. Lower values decay faster (default: 0.5).

        Returns:
        - torch.Tensor: 1-D tensor ``[decay_rate**0, ..., decay_rate**(n_preds-1)]``.
        """
        # Exponential decay weights
        weights = decay_rate ** torch.arange(n_preds)
        return weights

    def decoder_predict(self, bs, n_vars, dec_cross):
        """
        Run the decoder on top of the encoder output.

        dec_cross: tensor [bs x n_vars x num_patch x d_model]

        The decoder input is the last encoder token of every series, repeated
        ``out_patch_num`` times and scaled by exponentially decaying weights.
        The decoder output is returned with its last two axes swapped.
        """
        dec_in = dec_cross[:, :, -1, :].unsqueeze(2).expand(-1, -1, self.out_patch_num, -1)
        weights = self.get_dynamic_weights(self.out_patch_num).to(dec_in.device)
        dec_in = dec_in * weights.unsqueeze(0).unsqueeze(0).unsqueeze(-1)

        decoder_output = self.decoder(dec_in, dec_cross)
        decoder_output = decoder_output.transpose(2, 3)
        return decoder_output

    def forward(self, z, target_dim=None, patch_len=None, stride=None):
        """
        z: tensor [bs x num_patch x n_vars x patch_len]

        ``target_dim``, ``patch_len`` and ``stride`` replace the stored values
        when they are not ``None``.
        """
        if target_dim is not None:
            self.target_dim = target_dim
        if patch_len is not None:
            self.patch_len = patch_len
        if stride is not None:
            self.stride = stride
        self.out_patch_num = math.ceil(self.target_dim / self.patch_len)

        bs, num_patch, n_vars, patch_len = z.shape

        # tokenizer: a bias-free linear layer whose weights are the stored
        # embedding weights resampled to the current patch length
        cls_tokens = self.cls_embedding.expand(bs, n_vars, -1, -1)
        embedding = nn.Linear(patch_len, self.d_model, bias=False)
        embedding.weight.data = resample_patchemb(old=self.embedding.weight.data, new_patch_len=self.patch_len)
        z = embedding(z).permute(0, 2, 1, 3)                       # [bs x n_vars x num_patch x d_model]
        z = torch.cat((cls_tokens, z), dim=2)                      # [bs x n_vars x (1 + num_patch) x d_model]

        # encoder: every variable is treated as an independent sequence
        z = torch.reshape(z, (-1, 1 + num_patch, self.d_model))    # [bs*n_vars x (1 + num_patch) x d_model]
        z = self.encoder(z)
        z = torch.reshape(z, (-1, n_vars, 1 + num_patch, self.d_model))

        # decoder
        z = self.decoder_predict(bs, n_vars, z)

        # predict, then crop to the requested horizon
        z = self.head(z, self.patch_len)
        z = z[:, :self.target_dim, :]
        return z
class TSTEncoder(nn.Module):
    """Stack of ``n_layers`` identically configured :class:`TSTEncoderLayer`."""

    def __init__(self, d_model, n_heads, d_ff=None,
                 norm='BatchNorm', attn_dropout=0., dropout=0., activation='gelu',
                 res_attention=False, n_layers=1, pre_norm=False, store_attn=False):
        super().__init__()
        layer_kwargs = dict(n_heads=n_heads, d_ff=d_ff, norm=norm,
                            attn_dropout=attn_dropout, dropout=dropout,
                            activation=activation, res_attention=res_attention,
                            pre_norm=pre_norm, store_attn=store_attn)
        stack = []
        for _ in range(n_layers):
            stack.append(TSTEncoderLayer(d_model, **layer_kwargs))
        self.layers = nn.ModuleList(stack)
        self.res_attention = res_attention

    def forward(self, src:Tensor):
        """
        src: tensor [bs x q_len x d_model]

        With residual attention enabled, each layer also receives the
        attention scores returned by the previous layer.
        """
        hidden = src
        if not self.res_attention:
            for layer in self.layers:
                hidden = layer(hidden)
            return hidden

        prev_scores = None
        for layer in self.layers:
            hidden, prev_scores = layer(hidden, prev=prev_scores)
        return hidden
class TSTEncoderLayer(nn.Module):
    """Transformer encoder layer: self-attention and a position-wise feed-forward
    block, each followed by a residual add and a normalisation layer."""

    def __init__(self, d_model, n_heads, d_ff=256, store_attn=False,
                 norm='LayerNorm', attn_dropout=0, dropout=0., bias=True,
                 activation="gelu", res_attention=False, pre_norm=False):
        super().__init__()
        assert d_model % n_heads == 0, f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
        head_dim = d_model // n_heads
        use_batch_norm = "batch" in norm.lower()

        def build_norm():
            # BatchNorm1d normalises dim 1, so the feature axis is moved there and back.
            if use_batch_norm:
                return nn.Sequential(Transpose(1,2), nn.BatchNorm1d(d_model), Transpose(1,2))
            return nn.LayerNorm(d_model)

        # Multi-head attention sublayer
        self.res_attention = res_attention
        self.self_attn = MultiheadAttention(d_model, n_heads, head_dim, head_dim,
                                            attn_dropout=attn_dropout, proj_dropout=dropout,
                                            res_attention=res_attention)
        self.dropout_attn = nn.Dropout(dropout)
        self.norm_attn = build_norm()

        # Position-wise feed-forward sublayer
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff, bias=bias),
            get_activation_fn(activation),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model, bias=bias),
        )
        self.dropout_ffn = nn.Dropout(dropout)
        self.norm_ffn = build_norm()

        self.pre_norm = pre_norm
        self.store_attn = store_attn

    def forward(self, src:Tensor, prev:Optional[Tensor]=None):
        """
        src: tensor [bs x q_len x d_model]

        Returns the transformed tensor, plus the attention scores when
        ``res_attention`` is enabled.
        """
        hidden = src

        # --- attention sublayer ---
        if self.pre_norm:
            hidden = self.norm_attn(hidden)
        if self.res_attention:
            attn_out, attn, scores = self.self_attn(hidden, hidden, hidden, prev)
        else:
            attn_out, attn = self.self_attn(hidden, hidden, hidden)
        if self.store_attn:
            self.attn = attn
        hidden = hidden + self.dropout_attn(attn_out)   # residual connection with residual dropout
        if not self.pre_norm:
            hidden = self.norm_attn(hidden)

        # --- feed-forward sublayer ---
        if self.pre_norm:
            hidden = self.norm_ffn(hidden)
        ff_out = self.ff(hidden)
        hidden = hidden + self.dropout_ffn(ff_out)      # residual connection with residual dropout
        if not self.pre_norm:
            hidden = self.norm_ffn(hidden)

        if self.res_attention:
            return hidden, scores
        return hidden
class Decoder(nn.Module):
    """Sequential stack of ``d_layers`` :class:`DecoderLayer` modules."""

    def __init__(self, d_layers, patch_len, d_model, n_heads, d_ff=None, attn_dropout=0.2, dropout=0.1):
        super().__init__()
        self.decoder_layers = nn.ModuleList(
            [DecoderLayer(patch_len, d_model, n_heads, d_ff, attn_dropout, dropout) for _ in range(d_layers)]
        )

    def forward(self, x, cross):
        """Feed ``x`` through every layer in order; each layer attends to the same ``cross``."""
        hidden = x
        for decoder_layer in self.decoder_layers:
            hidden = decoder_layer(hidden, cross)
        return hidden
class DecoderLayer(nn.Module):
    """Decoder layer: causal self-attention, cross-attention over the encoder
    output, then a 1x1-convolution MLP; each stage is normalised and added to
    its input."""

    def __init__(self, patch_len, d_model, n_heads, d_ff=None, attn_dropout = 0.2, dropout=0.5, norm="BatchNorm"):
        super().__init__()
        self.self_attention = MultiheadAttention(d_model, n_heads, res_attention=False, attn_dropout=attn_dropout)
        self.cross_attention = MultiheadAttention(d_model, n_heads, attn_dropout=attn_dropout, rope_type=True)

        batch_norm = 'batch' in norm.lower()

        def build_norm():
            # BatchNorm1d works on dim 1, hence the transposes around it.
            if batch_norm:
                return nn.Sequential(Transpose(1,2), nn.BatchNorm1d(d_model), Transpose(1,2))
            return nn.LayerNorm(d_model)

        self.norm1 = build_norm()
        self.norm2 = build_norm()
        self.norm3 = build_norm()

        self.dropout = nn.Dropout(dropout)
        self.MLP1 = CMlp(in_features = d_model, hidden_features = d_ff, out_features = d_model, drop=dropout)

    def forward(self, x, cross):
        """
        x:     tensor [batch x n_vars x num_patch x d_model]
        cross: tensor reshaped here to [batch*n_vars x -1 x d_model]

        Returns a tensor with the same shape as ``x``.
        """
        batch, n_vars, num_patch, d_model = x.shape
        queries = x.reshape(batch * n_vars, num_patch, d_model)
        memory = cross.reshape(batch * n_vars, -1, d_model)

        # causal self-attention over the output patches
        causal_mask = causal_attention_mask(num_patch).to(queries.device)
        self_out, _ = self.self_attention(queries, attn_mask=causal_mask)
        after_self = self.norm1(self_out) + queries

        # cross-attention into the encoder output
        cross_out, _ = self.cross_attention(after_self, memory, memory)
        after_cross = self.dropout(self.norm2(cross_out)) + after_self

        # feed-forward
        mlp_out = self.MLP1(after_cross)
        result = self.norm3(mlp_out) + after_cross

        return result.reshape(batch, n_vars, num_patch, d_model)
def causal_attention_mask(seq_length):
    """
    Build an additive causal attention mask.

    Entry (i, j) is added to the attention score of query position i and key
    position j: it is 0 when j <= i (visible) and -inf when j > i (hidden).

    Args:
        seq_length (int): length of the sequence

    Returns:
        torch.Tensor: mask of shape (seq_length, seq_length)
    """
    blocked = torch.full((seq_length, seq_length), float('-inf'))
    # Keep -inf strictly above the diagonal; everything else becomes 0.
    return blocked.triu(diagonal=1)
class CMlp(nn.Module):
    """Two-layer MLP implemented with kernel-size-1 ``Conv1d`` layers.

    ``hidden_features`` and ``out_features`` fall back to ``in_features`` when
    they are falsy.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features
        self.fc1 = nn.Conv1d(in_features, hidden_features, 1)
        self.act = act_layer()
        self.fc2 = nn.Conv1d(hidden_features, out_features, 1)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Apply the MLP along the last axis of a 3-D tensor.

        The last two axes are swapped before and after the convolutions
        because ``Conv1d`` expects channels on dim 1.
        """
        channels_first = x.permute(0,2,1)
        hidden = self.drop(self.act(self.fc1(channels_first)))
        projected = self.drop(self.fc2(hidden))
        return projected.permute(0,2,1)
class Transpose(nn.Module):
    """Module wrapper around ``Tensor.transpose`` so it can sit inside ``nn.Sequential``."""

    def __init__(self, *dims, contiguous=False):
        super().__init__()
        self.dims = dims
        self.contiguous = contiguous

    def forward(self, x):
        """Swap the configured dims; optionally return a contiguous copy."""
        swapped = x.transpose(*self.dims)
        if self.contiguous:
            return swapped.contiguous()
        return swapped
class MultiheadAttention(nn.Module):
    def __init__(self, d_model, n_heads, d_k=None, d_v=None, res_attention=False, attn_dropout=0., proj_dropout=0., qkv_bias=True, lsa=False, rope_type=False):
        """Multi Head Attention Layer

        Input shape:
            Q:       [batch_size (bs) x max_q_len x d_model]
            K, V:    [batch_size (bs) x q_len x d_model]
            mask:    [q_len x q_len]

        ``d_k`` / ``d_v`` default to ``d_model // n_heads``.
        """
        super().__init__()
        if d_k is None:
            d_k = d_model // n_heads
        if d_v is None:
            d_v = d_model // n_heads
        self.n_heads = n_heads
        self.d_k = d_k
        self.d_v = d_v

        # Input projections for queries, keys and values
        self.W_Q = nn.Linear(d_model, d_k * n_heads, bias=qkv_bias)
        self.W_K = nn.Linear(d_model, d_k * n_heads, bias=qkv_bias)
        self.W_V = nn.Linear(d_model, d_v * n_heads, bias=qkv_bias)

        # Scaled dot-product attention shared by all heads
        self.res_attention = res_attention
        self.sdp_attn = ScaledDotProductAttention(d_model, n_heads, attn_dropout=attn_dropout,
                                                  res_attention=self.res_attention, lsa=lsa, rope_type=rope_type)

        # Output projection
        self.to_out = nn.Sequential(nn.Linear(n_heads * d_v, d_model), nn.Dropout(proj_dropout))

    def forward(self, Q:Tensor, K:Optional[Tensor]=None, V:Optional[Tensor]=None, prev:Optional[Tensor]=None,
                key_padding_mask:Optional[Tensor]=None, attn_mask:Optional[Tensor]=None):
        """Attend from ``Q`` to ``K``/``V`` (both default to ``Q``).

        Returns ``(output, attn_weights)``, plus the attention scores as a
        third element when ``res_attention`` is enabled.
        """
        bs = Q.size(0)
        K = Q if K is None else K
        V = Q if V is None else V

        # Project and split into heads
        q_s = self.W_Q(Q).view(bs, -1, self.n_heads, self.d_k).transpose(1,2)     # [bs x n_heads x max_q_len x d_k]
        k_s = self.W_K(K).view(bs, -1, self.n_heads, self.d_k).permute(0,2,3,1)   # [bs x n_heads x d_k x q_len]
        v_s = self.W_V(V).view(bs, -1, self.n_heads, self.d_v).transpose(1,2)     # [bs x n_heads x q_len x d_v]

        mask_kwargs = dict(key_padding_mask=key_padding_mask, attn_mask=attn_mask)
        if self.res_attention:
            context, attn_weights, attn_scores = self.sdp_attn(q_s, k_s, v_s, prev=prev, **mask_kwargs)
        else:
            context, attn_weights = self.sdp_attn(q_s, k_s, v_s, **mask_kwargs)

        # Merge the heads back: [bs x q_len x n_heads * d_v]
        merged = context.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * self.d_v)
        projected = self.to_out(merged)

        if self.res_attention:
            return projected, attn_weights, attn_scores
        return projected, attn_weights
class ScaledDotProductAttention(nn.Module):
    r"""Scaled Dot-Product Attention module (Attention is all you need by Vaswani et al., 2017) with optional residual attention from previous layer
    (Realformer: Transformer likes residual attention by He et al, 2020) and locality self attention (Vision Transformer for Small-Size Datasets
    by Lee et al, 2021). Rotary position encoding is applied to queries and keys before the scores are computed."""

    def __init__(self, d_model, n_heads, attn_dropout=0., res_attention=False, lsa=False, rope_type=False):
        super().__init__()
        self.attn_dropout = nn.Dropout(attn_dropout)
        self.res_attention = res_attention
        per_head_dim = d_model // n_heads
        # The scale is only trainable when locality self attention (lsa) is on.
        self.scale = nn.Parameter(torch.tensor(per_head_dim ** -0.5), requires_grad=lsa)
        self.lsa = lsa
        self.rope_type = rope_type

    def forward(self, q:Tensor, k:Tensor, v:Tensor, prev:Optional[Tensor]=None, key_padding_mask:Optional[Tensor]=None, attn_mask:Optional[Tensor]=None):
        '''
        Input shape:
            q               : [bs x n_heads x max_q_len x d_k]
            k               : [bs x n_heads x d_k x seq_len]
            v               : [bs x n_heads x seq_len x d_v]
            prev            : [bs x n_heads x q_len x seq_len]
            key_padding_mask: [bs x seq_len]
            attn_mask       : [1 x seq_len x seq_len]
        Output shape:
            output:  [bs x n_heads x q_len x d_v]
            attn   : [bs x n_heads x q_len x seq_len]
            scores : [bs x n_heads x q_len x seq_len]
        '''
        # Rotary embedding works on keys laid out like the queries, so the
        # keys are transposed before and after the rotation.
        rotate = RoPE_decoder if self.rope_type else RoPE
        q, k_rows = rotate(q, k.permute(0,1,3,2))
        k = k_rows.permute(0,1,3,2)

        # Similarity scores for all query/key pairs
        attn_scores = torch.matmul(q, k) * self.scale

        # Optional pre-softmax scores from the previous layer
        if prev is not None:
            attn_scores = attn_scores + prev

        # Optional attention mask: boolean masks fill with -inf, others are added
        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                attn_scores.masked_fill_(attn_mask, -np.inf)
            else:
                attn_scores += attn_mask

        # Optional key padding mask
        if key_padding_mask is not None:
            attn_scores.masked_fill_(key_padding_mask.unsqueeze(1).unsqueeze(2), -np.inf)

        # Normalise, apply dropout and weight the values
        attn_weights = self.attn_dropout(F.softmax(attn_scores, dim=-1))
        output = torch.matmul(attn_weights, v)

        if self.res_attention:
            return output, attn_weights, attn_scores
        return output, attn_weights
def RoPE(q, k):
    """Apply rotary position encoding to ``q`` and ``k``.

    q, k: (bs, head, max_len, output_dim). Both tensors are rotated with the
    same table of ``q.shape[2]`` positions.
    """
    batch_size, nums_head, max_len = q.shape[0], q.shape[1], q.shape[2]
    output_dim = q.shape[-1]

    # (bs, head, max_len, output_dim), sin in even columns and cos in odd ones
    pos_emb = sinusoidal_position_embedding(batch_size, nums_head, max_len, output_dim, q.device, factor=1)

    # Neighbouring feature pairs share one angle, so every cos/sin value is
    # duplicated: (1, 2, 3) -> (1, 1, 2, 2, 3, 3).
    cos_pos = pos_emb[..., 1::2].repeat_interleave(2, dim=-1)
    sin_pos = pos_emb[..., ::2].repeat_interleave(2, dim=-1)

    def rotate_pairs(t):
        # (x0, x1, x2, x3, ...) -> (-x1, x0, -x3, x2, ...)
        return torch.stack([-t[..., 1::2], t[..., ::2]], dim=-1).reshape(t.shape)

    q_rot = q * cos_pos + rotate_pairs(q) * sin_pos
    k_rot = k * cos_pos + rotate_pairs(k) * sin_pos
    return q_rot, k_rot
def RoPE_decoder(q, k):
    """Rotary position encoding for cross-attention.

    q, k: (bs, head, len, output_dim). One table of ``k_len + q_len``
    positions is built; keys use the first ``k_len`` positions and queries
    the last ``q_len`` positions.
    """
    batch_size, nums_head = q.shape[0], q.shape[1]
    q_max_len, k_max_len = q.shape[2], k.shape[2]
    output_dim = q.shape[-1]

    # (bs, head, k_len + q_len, output_dim), sin in even columns and cos in odd ones
    pos_emb = sinusoidal_position_embedding(batch_size, nums_head, k_max_len + q_max_len, output_dim, q.device, factor=1)

    # Neighbouring feature pairs share one angle, so every cos/sin value is
    # duplicated: (1, 2, 3) -> (1, 1, 2, 2, 3, 3).
    cos_pos = pos_emb[..., 1::2].repeat_interleave(2, dim=-1)
    sin_pos = pos_emb[..., ::2].repeat_interleave(2, dim=-1)

    def rotate_pairs(t):
        # (x0, x1, x2, x3, ...) -> (-x1, x0, -x3, x2, ...)
        return torch.stack([-t[..., 1::2], t[..., ::2]], dim=-1).reshape(t.shape)

    q_cos, q_sin = cos_pos[:,:,-q_max_len:,:], sin_pos[:,:,-q_max_len:,:]
    k_cos, k_sin = cos_pos[:,:,:k_max_len,:], sin_pos[:,:,:k_max_len,:]

    q_rot = q * q_cos + rotate_pairs(q) * q_sin
    k_rot = k * k_cos + rotate_pairs(k) * k_sin
    return q_rot, k_rot
def sinusoidal_position_embedding(batch_size, nums_head, max_len, output_dim, device, factor=1.0):
    """Build a sinusoidal position table of shape (bs, head, max_len, output_dim).

    Along the last axis, even columns hold ``sin(pos * theta_i)`` and odd
    columns hold ``cos(pos * theta_i)`` with
    ``theta_i = 10000 ** (-2 * i / output_dim)``.  When ``factor > 1`` the
    positions are sampled on a finer grid and then sub-sampled back to
    ``max_len`` rows.
    """
    # (num_positions, 1)
    positions = torch.arange(0, max_len * factor, 1 / factor, dtype=torch.float).unsqueeze(-1)
    # (output_dim // 2,)
    pair_ids = torch.arange(0, output_dim // 2, dtype=torch.float)
    inv_freq = torch.pow(10000, -2 * pair_ids / output_dim)

    # (num_positions, output_dim // 2, 2): [..., 0] is sin, [..., 1] is cos
    angles = positions * inv_freq
    table = torch.stack([torch.sin(angles), torch.cos(angles)], dim=-1)

    # Tile over batch and heads, then interleave sin/cos along the last axis
    tiling = (batch_size, nums_head) + (1,) * table.dim()
    table = table.repeat(tiling)
    table = table.reshape(batch_size, nums_head, -1, output_dim).to(device)

    # For factor > 1, pick max_len evenly spaced rows from the finer grid
    if factor > 1.0:
        picks = torch.linspace(0, table.shape[2] - 1, max_len).long()
        table = table[:, :, picks, :]
    return table
class PretrainHead(nn.Module):
    """Linear head that maps every ``d_model`` token back to a patch of length ``patch_len``."""

    def __init__(self, d_model, patch_len, dropout):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.linear = nn.Linear(d_model, patch_len)

    def forward(self, x):
        """
        x: tensor [bs x nvars x d_model x num_patch]
        output: tensor [bs x num_patch x nvars x patch_len]
        """
        tokens = x.transpose(2,3)                       # [bs x nvars x num_patch x d_model]
        patches = self.linear(self.dropout(tokens))     # [bs x nvars x num_patch x patch_len]
        return patches.permute(0,2,1,3)                 # [bs x num_patch x nvars x patch_len]
class decoder_PredictHead(nn.Module):
    """Prediction head whose stored projection (``d_model -> target_patch_len``)
    is resampled to the requested patch length on every call."""

    def __init__(self, d_model, patch_len, target_patch_len, dropout):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.linear = nn.Linear(d_model, target_patch_len)
        self.d_model = d_model

    def forward(self, x, patch_len):
        """
        x: tensor [bs x nvars x d_model x num_patch]
        output: tensor [bs x (num_patch * patch_len) x nvars]
        """
        # Bias-free projection carrying the stored weights resampled to ``patch_len``
        projection = nn.Linear(self.d_model, patch_len, bias=False)
        stored = self.linear.weight.data.T
        projection.weight.data = resample_patchemb(old=stored, new_patch_len=patch_len).T

        tokens = x.transpose(2,3)                         # [bs x nvars x num_patch x d_model]
        patches = projection(self.dropout(tokens))        # [bs x nvars x num_patch x patch_len]
        patches = patches.permute(0,2,3,1)                # [bs x num_patch x patch_len x nvars]

        # Concatenate the patches along the time axis
        n_batch, n_vars = patches.shape[0], patches.shape[3]
        return patches.reshape(n_batch, -1, n_vars)
def resample_patchemb(old: torch.Tensor, new_patch_len: int):
    """Resample patch-embedding weights to a different patch length.

    A linear-interpolation matrix that stretches a length-``patch_size``
    signal to ``new_patch_len`` is built, and the weights are multiplied by
    its pseudo-inverse and scaled by ``sqrt(new_patch_len / patch_size)``.

    Args:
        old: 2-D weight tensor of shape ``(d_model, patch_size)``.
        new_patch_len: patch length to resample to.

    Returns:
        Tensor of shape ``(d_model, new_patch_len)`` with the dtype of
        ``old``.  ``old`` itself is returned when the length already matches.

    Raises:
        AssertionError: if ``old`` is not 2-D.
    """
    assert old.dim() == 2, "输入张量应为2D (d_model, patch_size)"
    if old.size(1) == new_patch_len:
        return old

    # The basis used to be hard-coded to float32, which made the matmul below
    # fail for half/bfloat16/float64 weights.  Compute in the weight's own
    # dtype when it is float32/float64, otherwise in float32, and cast the
    # result back at the end.
    out_dtype = old.dtype
    if out_dtype in (torch.float32, torch.float64):
        compute_dtype = out_dtype
    else:
        compute_dtype = torch.float32

    kernels = old.T.to(compute_dtype)                 # (patch_size, d_model)
    old_len = kernels.size(0)
    factor = new_patch_len / old_len

    # Interpolating the identity yields the resize operator:
    # row i is basis vector e_i stretched to ``new_patch_len`` samples.
    basis_vectors = torch.eye(old_len, dtype=compute_dtype, device=old.device)
    resized_basis = F.interpolate(basis_vectors.unsqueeze(0), size=new_patch_len, mode='linear').squeeze(0)

    # Pseudo-inverse of the (patch_size, new_patch_len) operator -> (new_patch_len, patch_size)
    resize_mat_pinv = torch.linalg.pinv(resized_basis)

    resampled_kernels = resize_mat_pinv @ kernels * math.sqrt(factor)
    return resampled_kernels.T.to(out_dtype)
def get_activation_fn(activation):
    """Return an activation module.

    A callable is invoked with no arguments and its result returned;
    otherwise ``activation`` is matched case-insensitively against
    ``"relu"`` and ``"gelu"``.

    Raises:
        ValueError: if the name is not one of the supported activations.
    """
    if callable(activation):
        return activation()

    known = {"relu": nn.ReLU, "gelu": nn.GELU}
    factory = known.get(activation.lower())
    if factory is None:
        raise ValueError(f'{activation} is not available. You can use "relu", "gelu", or a callable')
    return factory()