enver1323 commited on
Commit
8db2ca6
·
1 Parent(s): 1b75d67

feat: upload model

Browse files
Files changed (4) hide show
  1. README.md +10 -0
  2. config.json +27 -0
  3. model.py +707 -0
  4. model.safetensors +3 -0
README.md ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ tags:
3
+ - model_hub_mixin
4
+ - pytorch_model_hub_mixin
5
+ ---
6
+
7
+ This model has been pushed to the Hub using the [PyTorchModelHubMixin](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.PyTorchModelHubMixin) integration:
8
+ - Code: [More Information Needed]
9
+ - Paper: [More Information Needed]
10
+ - Docs: [More Information Needed]
config.json ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "activation": "gelu",
3
+ "affine": false,
4
+ "attn_dropout": 0.0,
5
+ "c_in": 1,
6
+ "c_out": null,
7
+ "classification": true,
8
+ "d_ff": 2048,
9
+ "d_model": 512,
10
+ "decomposition": false,
11
+ "dropout": 0.05,
12
+ "individual": false,
13
+ "kernel_size": 25,
14
+ "n_heads": 8,
15
+ "n_layers": 2,
16
+ "norm": "BatchNorm",
17
+ "padding_patch": true,
18
+ "patch_len": 16,
19
+ "pre_norm": false,
20
+ "pred_dim": 2,
21
+ "res_attention": true,
22
+ "revin": true,
23
+ "seq_len": 82,
24
+ "store_attn": false,
25
+ "stride": 8,
26
+ "subtract_last": false
27
+ }
model.py ADDED
@@ -0,0 +1,707 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+ from typing import Optional
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import Module
6
+ import torch.nn.functional as F
7
+ from torch import Tensor
8
+ from huggingface_hub import PyTorchModelHubMixin
9
+
10
+ warnings.filterwarnings("ignore", category=UserWarning)
11
+
12
+
13
class Transpose(Module):
    """Swap two (or more) tensor dimensions as an ``nn.Module``.

    Wraps ``Tensor.transpose`` so a dimension swap can be placed inside an
    ``nn.Sequential``; optionally forces a contiguous copy afterwards.
    """

    def __init__(self, *dims, contiguous=False):
        super(Transpose, self).__init__()
        self.dims, self.contiguous = dims, contiguous

    def forward(self, x):
        out = x.transpose(*self.dims)
        return out.contiguous() if self.contiguous else out

    def __repr__(self):
        joined = ", ".join(str(d) for d in self.dims)
        if self.contiguous:
            return f"{self.__class__.__name__}(dims={joined}).contiguous()"
        return f"{self.__class__.__name__}({joined})"
31
+
32
+
33
# Activation classes that get_act_fn() can resolve from a string name.
pytorch_acts = [
    nn.ELU,
    nn.LeakyReLU,
    nn.PReLU,
    nn.ReLU,
    nn.ReLU6,
    nn.SELU,
    nn.CELU,
    nn.GELU,
    nn.Sigmoid,
    nn.Softplus,
    nn.Tanh,
    nn.Softmax,
]
# Lower-cased class names, index-aligned with pytorch_acts.
pytorch_act_names = [cls.__name__.lower() for cls in pytorch_acts]
48
+
49
+
50
def get_act_fn(act, **act_kwargs):
    """Resolve an activation spec into an ``nn.Module`` instance.

    Args:
        act: ``None`` (returns ``None``), an ``nn.Module`` instance
            (returned as-is), a callable/class (instantiated with
            ``act_kwargs``), or a case-insensitive class name such as
            ``"gelu"`` looked up in ``pytorch_act_names``.
        **act_kwargs: keyword arguments forwarded to the activation
            constructor when ``act`` is callable or a registered name.

    Returns:
        The activation module, or ``None`` when ``act`` is ``None``.

    Raises:
        ValueError: if ``act`` is a string not present in
            ``pytorch_act_names`` (previously surfaced as the opaque
            ``list.index`` error).
    """
    if act is None:
        return None
    if isinstance(act, nn.Module):
        return act
    if callable(act):
        return act(**act_kwargs)
    try:
        idx = pytorch_act_names.index(act.lower())
    except ValueError:
        raise ValueError(
            f"Unknown activation {act!r}; expected one of {pytorch_act_names}"
        ) from None
    return pytorch_acts[idx](**act_kwargs)
59
+
60
+
61
class RevIN(nn.Module):
    """Reversible instance normalization over the last (time) dimension.

    ``forward(x, mode)`` normalizes when ``mode`` is truthy and reverses the
    stored normalization otherwise.  Statistics are computed per normalize
    call and kept on the module, so a denormalize must follow a matching
    normalize on the same module.
    """

    def __init__(
        self,
        c_in: int,
        affine: bool = True,
        subtract_last: bool = False,
        dim: int = 2,
        eps: float = 1e-5,
    ):
        super().__init__()
        self.c_in = c_in
        self.affine = affine
        self.subtract_last = subtract_last
        self.dim = dim
        self.eps = eps
        if self.affine:
            # Per-channel learnable scale/shift applied after standardization.
            self.weight = nn.Parameter(torch.ones(1, c_in, 1))
            self.bias = nn.Parameter(torch.zeros(1, c_in, 1))

    def forward(self, x: Tensor, mode: Tensor):
        # Truthy mode -> normalize; falsy -> undo the stored normalization.
        return self.normalize(x) if mode else self.denormalize(x)

    def normalize(self, x):
        # Center on the last timestep or on the per-series mean.
        if self.subtract_last:
            self.sub = x[..., -1].unsqueeze(-1).detach()
        else:
            self.sub = torch.mean(x, dim=-1, keepdim=True).detach()
        # eps keeps the divisor strictly positive for constant series.
        self.std = (
            torch.std(x, dim=-1, keepdim=True, unbiased=False).detach() + self.eps
        )
        x = x.sub(self.sub).div(self.std)
        if self.affine:
            x = x.mul(self.weight).add(self.bias)
        return x

    def denormalize(self, x):
        # Exact inverse of normalize(), using the stored statistics.
        if self.affine:
            x = x.sub(self.bias).div(self.weight)
        return x.mul(self.std).add(self.sub)
118
+
119
+
120
class MovingAverage(nn.Module):
    """Moving average over the last dimension with replication padding, so
    the output has the same length as the input."""

    def __init__(
        self,
        kernel_size: int,
    ):
        super().__init__()
        # Split the (kernel_size - 1) pad across both ends; the right side
        # receives the extra element when kernel_size is even.
        left = (kernel_size - 1) // 2
        right = kernel_size - 1 - left
        self.padding = torch.nn.ReplicationPad1d((left, right))
        self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=1)

    def forward(self, x: Tensor):
        padded = self.padding(x)
        return self.avg(padded)
133
+
134
+
135
class SeriesDecomposition(nn.Module):
    """Split a series into (residual, trend) using a moving-average trend."""

    def __init__(
        self,
        kernel_size: int,  # the size of the moving-average window
    ):
        super().__init__()
        self.moving_avg = MovingAverage(kernel_size)

    def forward(self, x: Tensor):
        trend = self.moving_avg(x)
        # The residual (seasonal) part is whatever the trend does not explain.
        return x - trend, trend
147
+
148
+
149
+ class _ScaledDotProductAttention(nn.Module):
150
+ def __init__(self, d_model, n_heads, attn_dropout=0.0, res_attention=False):
151
+ super().__init__()
152
+ self.attn_dropout = nn.Dropout(attn_dropout)
153
+ self.res_attention = res_attention
154
+ head_dim = d_model // n_heads
155
+ self.scale = nn.Parameter(torch.tensor(head_dim**-0.5), requires_grad=False)
156
+
157
+ def forward(self, q: Tensor, k: Tensor, v: Tensor, prev: Optional[Tensor] = None):
158
+ attn_scores = torch.matmul(q, k) * self.scale
159
+
160
+ if prev is not None:
161
+ attn_scores = attn_scores + prev
162
+
163
+ attn_weights = F.softmax(attn_scores, dim=-1)
164
+ attn_weights = self.attn_dropout(attn_weights)
165
+
166
+ output = torch.matmul(attn_weights, v)
167
+
168
+ if self.res_attention:
169
+ return output, attn_weights, attn_scores
170
+ else:
171
+ return output, attn_weights
172
+
173
+
174
class _MultiheadAttention(nn.Module):
    """Multi-head attention layer built on _ScaledDotProductAttention.

    NOTE(review): the ``d_k``/``d_v`` arguments are accepted but immediately
    overwritten with ``d_model // n_heads`` below — confirm intent.
    """

    def __init__(
        self,
        d_model,
        n_heads,
        d_k=None,
        d_v=None,
        res_attention=False,
        attn_dropout=0.0,
        proj_dropout=0.0,
        qkv_bias=True,
    ):
        super().__init__()
        d_k = d_v = d_model // n_heads
        self.n_heads, self.d_k, self.d_v = n_heads, d_k, d_v

        # Linear projections producing all heads at once.
        self.W_Q = nn.Linear(d_model, d_k * n_heads, bias=qkv_bias)
        self.W_K = nn.Linear(d_model, d_k * n_heads, bias=qkv_bias)
        self.W_V = nn.Linear(d_model, d_v * n_heads, bias=qkv_bias)

        # Scaled dot-product attention shared across heads.
        self.res_attention = res_attention
        self.sdp_attn = _ScaledDotProductAttention(
            d_model,
            n_heads,
            attn_dropout=attn_dropout,
            res_attention=self.res_attention,
        )

        # Project the concatenated heads back to d_model.
        self.to_out = nn.Sequential(
            nn.Linear(n_heads * d_v, d_model), nn.Dropout(proj_dropout)
        )

    def forward(
        self,
        Q: Tensor,
        K: Optional[Tensor] = None,
        V: Optional[Tensor] = None,
        prev: Optional[Tensor] = None,
    ):
        bs = Q.size(0)
        # Self-attention by default: K and V fall back to Q.
        K = Q if K is None else K
        V = Q if V is None else V

        # Split heads.  q/v: [bs x n_heads x seq x d];
        # k is pre-transposed to [bs x n_heads x d_k x seq] for the matmul.
        q = self.W_Q(Q).view(bs, -1, self.n_heads, self.d_k).transpose(1, 2)
        k = self.W_K(K).view(bs, -1, self.n_heads, self.d_k).permute(0, 2, 3, 1)
        v = self.W_V(V).view(bs, -1, self.n_heads, self.d_v).transpose(1, 2)

        # Attention over all heads at once.
        if self.res_attention:
            out, attn_weights, attn_scores = self.sdp_attn(q, k, v, prev=prev)
        else:
            out, attn_weights = self.sdp_attn(q, k, v)

        # Merge heads back: [bs x seq x n_heads * d_v] -> project to d_model.
        out = out.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * self.d_v)
        out = self.to_out(out)

        if self.res_attention:
            return out, attn_weights, attn_scores
        return out, attn_weights
252
+
253
+
254
class Flatten_Head(nn.Module):
    """Flatten the (d_model, n_patch) tail of the encoder output and project
    it to ``pred_dim`` — either one shared head or one head per variable."""

    def __init__(self, individual, n_vars, nf, pred_dim):
        super().__init__()
        # A (tuple, list) pred_dim collapses to its last entry.
        if isinstance(pred_dim, (tuple, list)):
            pred_dim = pred_dim[-1]
        self.individual = individual
        self.n = n_vars if individual else 1
        self.nf, self.pred_dim = nf, pred_dim

        def make_head():
            return nn.Sequential(nn.Flatten(start_dim=-2), nn.Linear(nf, pred_dim))

        if individual:
            # One independent head per variable/channel.
            self.layers = nn.ModuleList(make_head() for _ in range(self.n))
        else:
            self.layer = make_head()

    def forward(self, x: Tensor):
        """
        Args:
            x: [bs x nvars x d_model x n_patch]
            output: [bs x nvars x pred_dim]
        """
        if not self.individual:
            return self.layer(x)
        per_var = [head(x[:, i]) for i, head in enumerate(self.layers)]
        return torch.stack(per_var, dim=1)
289
+
290
+
291
class _TSTiEncoderLayer(nn.Module):
    """Single PatchTST encoder layer: multi-head self-attention plus a
    position-wise feed-forward block, each wrapped in dropout + Add & Norm.

    When ``res_attention`` is set, the raw (pre-softmax) attention scores are
    also returned so the next layer can add them to its own.
    """

    def __init__(
        self,
        q_len,  # patch-sequence length (not referenced in this layer's body)
        d_model,
        n_heads,
        d_k=None,  # per-head query/key dim; defaults to d_model // n_heads
        d_v=None,  # per-head value dim; defaults to d_model // n_heads
        d_ff=256,  # feed-forward hidden width
        store_attn=False,  # keep the last attention weights on self.attn
        norm="BatchNorm",  # any value containing "batch" -> BatchNorm1d, else LayerNorm
        attn_dropout=0,
        dropout=0.0,
        bias=True,
        activation="gelu",  # resolved via get_act_fn
        res_attention=False,
        pre_norm=False,  # norm before each sublayer instead of after
    ):
        super().__init__()
        assert (
            not d_model % n_heads
        ), f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
        d_k = d_model // n_heads if d_k is None else d_k
        d_v = d_model // n_heads if d_v is None else d_v

        # Multi-Head attention
        # NOTE(review): _MultiheadAttention overwrites the d_k/d_v it
        # receives with d_model // n_heads internally — confirm intent.
        self.res_attention = res_attention
        self.self_attn = _MultiheadAttention(
            d_model,
            n_heads,
            d_k,
            d_v,
            attn_dropout=attn_dropout,
            proj_dropout=dropout,
            res_attention=res_attention,
        )

        # Add & Norm (attention sublayer).  BatchNorm1d expects
        # [bs x d_model x q_len], hence the Transpose sandwich.
        self.dropout_attn = nn.Dropout(dropout)
        if "batch" in norm.lower():
            self.norm_attn = nn.Sequential(
                Transpose(1, 2), nn.BatchNorm1d(d_model), Transpose(1, 2)
            )
        else:
            self.norm_attn = nn.LayerNorm(d_model)

        # Position-wise Feed-Forward
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff, bias=bias),
            get_act_fn(activation),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model, bias=bias),
        )

        # Add & Norm (feed-forward sublayer)
        self.dropout_ffn = nn.Dropout(dropout)
        if "batch" in norm.lower():
            self.norm_ffn = nn.Sequential(
                Transpose(1, 2), nn.BatchNorm1d(d_model), Transpose(1, 2)
            )
        else:
            self.norm_ffn = nn.LayerNorm(d_model)

        self.pre_norm = pre_norm
        self.store_attn = store_attn

    def forward(self, src: Tensor, prev: Optional[Tensor] = None):
        """
        Args:
            src: [bs x q_len x d_model]
            prev: raw attention scores from the previous layer
                (only meaningful when ``res_attention`` is set)

        Returns:
            src, or (src, scores) when ``res_attention`` is set.
        """

        # Multi-Head attention sublayer
        if self.pre_norm:
            src = self.norm_attn(src)
        ## Multi-Head attention
        if self.res_attention:
            src2, attn, scores = self.self_attn(src, src, src, prev)
        else:
            src2, attn = self.self_attn(src, src, src)
        if self.store_attn:
            self.attn = attn  # exposed for inspection when store_attn is set
        ## Add & Norm
        src = src + self.dropout_attn(
            src2
        )  # Add: residual connection with residual dropout
        if not self.pre_norm:
            src = self.norm_attn(src)

        # Feed-forward sublayer
        if self.pre_norm:
            src = self.norm_ffn(src)
        ## Position-wise Feed-Forward
        src2 = self.ff(src)
        ## Add & Norm
        src = src + self.dropout_ffn(
            src2
        )  # Add: residual connection with residual dropout
        if not self.pre_norm:
            src = self.norm_ffn(src)

        if self.res_attention:
            return src, scores
        else:
            return src
396
+
397
+
398
class _TSTiEncoder(nn.Module):  # i means channel-independent
    """Channel-independent Transformer encoder over patches.

    Every channel is treated as an independent sequence of patches: the
    batch and channel dimensions are folded together before the encoder
    layers and unfolded again afterwards.
    """

    def __init__(
        self,
        c_in,  # number of channels (not referenced in this body)
        patch_num,  # patches per channel = attention sequence length
        patch_len,  # length of each patch (input feature size)
        n_layers=3,
        d_model=128,
        n_heads=16,
        d_k=None,
        d_v=None,
        d_ff=256,
        norm="BatchNorm",
        attn_dropout=0.0,
        dropout=0.0,
        act="gelu",
        store_attn=False,
        res_attention=True,
        pre_norm=False,
    ):

        super().__init__()

        self.patch_num = patch_num
        self.patch_len = patch_len

        # Input encoding
        q_len = patch_num
        self.W_P = nn.Linear(
            patch_len, d_model
        )  # Eq 1: projection of feature vectors onto a d-dim vector space
        self.seq_len = q_len

        # Positional encoding: learnable, initialized uniformly in [-0.02, 0.02]
        W_pos = torch.empty((q_len, d_model))
        nn.init.uniform_(W_pos, -0.02, 0.02)
        self.W_pos = nn.Parameter(W_pos)

        # Residual dropout
        self.dropout = nn.Dropout(dropout)

        # Encoder
        self.layers = nn.ModuleList(
            [
                _TSTiEncoderLayer(
                    q_len,
                    d_model,
                    n_heads=n_heads,
                    d_k=d_k,
                    d_v=d_v,
                    d_ff=d_ff,
                    norm=norm,
                    attn_dropout=attn_dropout,
                    dropout=dropout,
                    activation=act,
                    res_attention=res_attention,
                    pre_norm=pre_norm,
                    store_attn=store_attn,
                )
                for i in range(n_layers)
            ]
        )
        self.res_attention = res_attention

    def forward(self, x: Tensor):
        """
        Args:
            x: [bs x nvars x patch_len x patch_num]

        Returns:
            [bs x nvars x d_model x patch_num]
        """

        n_vars = x.shape[1]
        # Input encoding
        x = x.permute(0, 1, 3, 2)  # x: [bs x nvars x patch_num x patch_len]
        x = self.W_P(x)  # x: [bs x nvars x patch_num x d_model]

        # Fold channels into the batch dim so each channel attends independently.
        x = torch.reshape(
            x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3])
        )  # x: [bs * nvars x patch_num x d_model]
        x = self.dropout(x + self.W_pos)  # x: [bs * nvars x patch_num x d_model]

        # Encoder; with res_attention, raw scores are threaded through layers.
        if self.res_attention:
            scores = None
            for mod in self.layers:
                x, scores = mod(x, prev=scores)
        else:
            for mod in self.layers:
                x = mod(x)
        x = torch.reshape(
            x, (-1, n_vars, x.shape[-2], x.shape[-1])
        )  # x: [bs x nvars x patch_num x d_model]
        x = x.permute(0, 1, 3, 2)  # x: [bs x nvars x d_model x patch_num]

        return x
492
+
493
+
494
class _PatchTST_backbone(nn.Module):
    """PatchTST backbone: RevIN -> patching (Unfold) -> channel-independent
    Transformer encoder -> flatten head, with RevIN denormalization of the
    head output when ``revin`` is set."""

    def __init__(
        self,
        c_in,  # number of input channels/variables
        seq_len,  # input series length
        pred_dim,  # output dimension of the head
        patch_len,  # patch length
        stride,  # patch stride
        n_layers=3,
        d_model=128,
        n_heads=16,
        d_k=None,
        d_v=None,
        d_ff=256,
        norm="BatchNorm",
        attn_dropout=0.0,
        dropout=0.0,
        act="gelu",
        res_attention=True,
        pre_norm=False,
        store_attn=False,
        padding_patch=True,  # NOTE(review): stored but padding is always applied in forward — confirm
        individual=False,  # one head per channel in Flatten_Head
        revin=True,
        affine=True,
        subtract_last=False,
    ):

        super().__init__()

        self.revin = revin
        self.revin_layer = RevIN(c_in, affine=affine, subtract_last=subtract_last)

        self.patch_len = patch_len
        self.stride = stride
        self.padding_patch = padding_patch
        # +1 patch comes from the replication padding added in forward().
        patch_num = int((seq_len - patch_len) / stride + 1) + 1
        self.patch_num = patch_num
        # Pads `stride` replicated elements on the left of the time axis.
        self.padding_patch_layer = nn.ReplicationPad1d((stride, 0))

        # Extracts sliding windows of length patch_len at the given stride.
        self.unfold = nn.Unfold(kernel_size=(1, patch_len), stride=stride)
        self.patch_len = patch_len

        self.backbone = _TSTiEncoder(
            c_in,
            patch_num=patch_num,
            patch_len=patch_len,
            n_layers=n_layers,
            d_model=d_model,
            n_heads=n_heads,
            d_k=d_k,
            d_v=d_v,
            d_ff=d_ff,
            attn_dropout=attn_dropout,
            dropout=dropout,
            act=act,
            res_attention=res_attention,
            pre_norm=pre_norm,
            store_attn=store_attn,
        )

        # Head
        self.head_nf = d_model * patch_num
        self.n_vars = c_in
        self.individual = individual
        self.head = Flatten_Head(self.individual, self.n_vars, self.head_nf, pred_dim)

    def forward(self, z: Tensor):
        """
        Args:
            z: [bs x c_in x seq_len]

        Returns:
            [bs x c_in x pred_dim]
        """

        if self.revin:
            # mode=True -> normalize; stats are stored for the inverse below.
            z = self.revin_layer(z, torch.tensor(True, dtype=torch.bool))

        # Left replication-pad by `stride`, then slice into patches via Unfold.
        z = self.padding_patch_layer(z)
        b, c, s = z.size()
        z = z.reshape(-1, 1, 1, s)  # each channel becomes a 1-row "image"
        z = self.unfold(z)
        z = z.permute(0, 2, 1).reshape(b, c, -1, self.patch_len).permute(0, 1, 3, 2)
        # z: [bs x c_in x patch_len x patch_num]

        z = self.backbone(z)  # [bs x c_in x d_model x patch_num]
        z = self.head(z)  # [bs x c_in x pred_dim]

        if self.revin:
            # mode=False -> denormalize using the stats stored above.
            z = self.revin_layer(z, torch.tensor(False, dtype=torch.bool))
        return z
582
+
583
+
584
class PatchTST(nn.Module, PyTorchModelHubMixin):
    """PatchTST (patch-based Transformer for time series), exportable to the
    Hugging Face Hub via PyTorchModelHubMixin.

    With ``decomposition`` the input is split into residual + trend series
    (moving average) and two backbones process them; their outputs are
    summed.  Otherwise a single backbone handles the raw series.
    ``pred_dim`` defaults to ``seq_len``.
    """

    def __init__(
        self,
        c_in,  # number of input channels/variables
        c_out,  # NOTE(review): accepted but never used below — confirm it is only kept for the saved config
        seq_len,  # input series length
        pred_dim=None,  # head output dim; None -> seq_len
        n_layers=2,
        n_heads=8,
        d_model=512,
        d_ff=2048,
        dropout=0.05,
        attn_dropout=0.0,
        patch_len=16,
        stride=8,
        padding_patch=True,
        revin=True,
        affine=False,
        individual=False,
        subtract_last=False,
        decomposition=False,  # enable trend/residual decomposition
        kernel_size=25,  # moving-average window for the decomposition
        activation="gelu",
        norm="BatchNorm",
        pre_norm=False,
        res_attention=True,
        store_attn=False,
        classification=False,  # squeeze dim -2 of the output (see forward)
    ):

        super().__init__()

        if pred_dim is None:
            pred_dim = seq_len

        self.decomposition = decomposition
        if self.decomposition:
            # Two parallel backbones: one for the trend, one for the residual.
            self.decomp_module = SeriesDecomposition(kernel_size)
            self.model_trend = _PatchTST_backbone(
                c_in=c_in,
                seq_len=seq_len,
                pred_dim=pred_dim,
                patch_len=patch_len,
                stride=stride,
                n_layers=n_layers,
                d_model=d_model,
                n_heads=n_heads,
                d_ff=d_ff,
                norm=norm,
                attn_dropout=attn_dropout,
                dropout=dropout,
                act=activation,
                res_attention=res_attention,
                pre_norm=pre_norm,
                store_attn=store_attn,
                padding_patch=padding_patch,
                individual=individual,
                revin=revin,
                affine=affine,
                subtract_last=subtract_last,
            )
            self.model_res = _PatchTST_backbone(
                c_in=c_in,
                seq_len=seq_len,
                pred_dim=pred_dim,
                patch_len=patch_len,
                stride=stride,
                n_layers=n_layers,
                d_model=d_model,
                n_heads=n_heads,
                d_ff=d_ff,
                norm=norm,
                attn_dropout=attn_dropout,
                dropout=dropout,
                act=activation,
                res_attention=res_attention,
                pre_norm=pre_norm,
                store_attn=store_attn,
                padding_patch=padding_patch,
                individual=individual,
                revin=revin,
                affine=affine,
                subtract_last=subtract_last,
            )
            self.patch_num = self.model_trend.patch_num
        else:
            # Single backbone over the raw series.
            self.model = _PatchTST_backbone(
                c_in=c_in,
                seq_len=seq_len,
                pred_dim=pred_dim,
                patch_len=patch_len,
                stride=stride,
                n_layers=n_layers,
                d_model=d_model,
                n_heads=n_heads,
                d_ff=d_ff,
                norm=norm,
                attn_dropout=attn_dropout,
                dropout=dropout,
                act=activation,
                res_attention=res_attention,
                pre_norm=pre_norm,
                store_attn=store_attn,
                padding_patch=padding_patch,
                individual=individual,
                revin=revin,
                affine=affine,
                subtract_last=subtract_last,
            )
            self.patch_num = self.model.patch_num
        self.classification = classification

    def forward(self, x):
        """
        Args:
            x: [bs x c_in x seq_len]

        Returns:
            [bs x c_in x pred_dim]; when ``classification`` is set, dim -2
            is squeezed (it is only removed when c_in == 1).
        """
        if self.decomposition:
            res_init, trend_init = self.decomp_module(x)
            res = self.model_res(res_init)
            trend = self.model_trend(trend_init)
            x = res + trend
        else:
            x = self.model(x)

        if self.classification:
            x = x.squeeze(-2)
        return x
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:36ca37c96811a3cd2f528d01f9203be172cc9d9ea38562d6cb30baf682c1f332
3
+ size 25337280