| """Self-contained TimesFM 2.x wrapper compatible with the TimesFM interface.""" |
|
|
| from __future__ import annotations |
|
|
| import dataclasses |
| import math |
|
|
| import torch |
| import torch.nn.functional as F |
| from torch import nn |
|
|
| try: |
| from safetensors.torch import load_file as _load_safetensors |
| except ImportError: |
| _load_safetensors = None |
|
|
# Smallest normalization scale used as-is: when normalizing, `revin` divides by
# 1.0 instead of any sigma below this value so near-constant series do not blow up.
_TOLERANCE = 1e-6
|
|
|
|
@dataclasses.dataclass(frozen=True)
class ResidualBlockConfig:
    """Hyper-parameters of a two-layer residual MLP block (``ResidualBlock``)."""

    input_dims: int  # width of the block input
    hidden_dims: int  # width of the hidden layer
    output_dims: int  # width of the block output
    use_bias: bool  # whether the block's linear layers carry a bias term
    activation: str  # "relu", "swish" or "none"
|
|
|
|
@dataclasses.dataclass(frozen=True)
class TransformerConfig:
    """Hyper-parameters of a single transformer layer (``Transformer``)."""

    model_dims: int  # residual-stream width
    hidden_dims: int  # feed-forward hidden width
    num_heads: int  # attention heads; must divide model_dims
    attention_norm: str  # only "rms" is accepted by Transformer
    feedforward_norm: str  # only "rms" is accepted by Transformer
    qk_norm: str  # "rms" normalizes queries/keys per head; other values disable it
    use_bias: bool  # bias flag applied to the feed-forward linear layers
    use_rotary_position_embeddings: bool  # apply rotary embeddings to queries/keys
    ff_activation: str  # "relu", "swish" or "none"
    fuse_qkv: bool  # one fused projection for q/k/v instead of three
|
|
|
|
@dataclasses.dataclass(frozen=True)
class StackedTransformersConfig:
    """A stack of ``num_layers`` transformer layers built from one shared config."""

    num_layers: int  # number of stacked layers
    transformer: TransformerConfig  # configuration reused by every layer
|
|
|
|
@dataclasses.dataclass(frozen=True)
class TimesFM2Definition:
    """Framework-agnostic description of TimesFM 2.5 (200M parameters).

    Bundles the patching/decoding constants and the configurations of every
    sub-network that ``TimesFM2Core`` builds.
    """

    # Longest lookback (in time steps) that ``TimesFM2`` accepts.
    context_limit: int = 16384
    # Time steps per input patch.
    input_patch_len: int = 32
    # Time steps of forecast produced by the point head per patch.
    output_patch_len: int = 128
    # Time steps covered by the quantile head per patch.
    output_quantile_len: int = 1024
    # Quantile levels; the heads emit len(quantiles) + 1 channels per step.
    quantiles: tuple[float, ...] = (
        0.1,
        0.2,
        0.3,
        0.4,
        0.5,
        0.6,
        0.7,
        0.8,
        0.9,
    )
    # Output channel fed back during autoregressive decoding and returned as
    # the point forecast. NOTE(review): presumably the 0.5 quantile if channel 0
    # is a separate mean forecast — confirm against the checkpoint.
    decode_index: int = 5
    # Patch tokenizer: its input is a patch of values concatenated with its
    # mask (2 * input_patch_len = 64), projected to the model width.
    tokenizer: ResidualBlockConfig = dataclasses.field(
        default_factory=lambda: ResidualBlockConfig(
            input_dims=64,
            hidden_dims=1280,
            output_dims=1280,
            use_bias=True,
            activation="swish",
        )
    )
    # Transformer backbone: 20 layers sharing one layer configuration.
    stacked_transformers: StackedTransformersConfig = dataclasses.field(
        default_factory=lambda: StackedTransformersConfig(
            num_layers=20,
            transformer=TransformerConfig(
                model_dims=1280,
                hidden_dims=1280,
                num_heads=16,
                attention_norm="rms",
                feedforward_norm="rms",
                qk_norm="rms",
                use_bias=False,
                use_rotary_position_embeddings=True,
                ff_activation="swish",
                fuse_qkv=True,
            ),
        )
    )
    # Point head: 1280 = output_patch_len (128) * 10 output channels.
    output_projection_point: ResidualBlockConfig = dataclasses.field(
        default_factory=lambda: ResidualBlockConfig(
            input_dims=1280,
            hidden_dims=1280,
            output_dims=1280,
            use_bias=False,
            activation="swish",
        )
    )
    # Quantile head: 10240 = output_quantile_len (1024) * 10 output channels.
    output_projection_quantiles: ResidualBlockConfig = dataclasses.field(
        default_factory=lambda: ResidualBlockConfig(
            input_dims=1280,
            hidden_dims=1280,
            output_dims=10240,
            use_bias=False,
            activation="swish",
        )
    )
|
|
|
|
@dataclasses.dataclass(frozen=False)
class DecodeCache:
    """Per-layer key/value cache for autoregressive decoding.

    Mutable on purpose: ``MultiHeadAttention.forward`` writes new keys/values
    into it, advances ``next_index`` in place and replaces ``num_masked``.
    """

    next_index: torch.Tensor  # int32 (batch,): next cache slot to write
    num_masked: torch.Tensor  # int32 (batch,): masked patches seen so far
    key: torch.Tensor  # (batch, cache_size, num_heads, head_dim) as allocated in TimesFM2Core.decode
    value: torch.Tensor  # (batch, cache_size, num_heads, head_dim) as allocated in TimesFM2Core.decode
|
|
|
|
def update_running_stats(
    n: torch.Tensor,
    mu: torch.Tensor,
    sigma: torch.Tensor,
    x: torch.Tensor,
    mask: torch.Tensor,
) -> tuple[tuple[torch.Tensor, torch.Tensor, torch.Tensor], tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
    """Merge one new patch into running (count, mean, std) statistics.

    Args:
      n, mu, sigma: running count, mean and (population) std so far.
      x: new values; statistics are taken over the last axis.
      mask: same shape as ``x``; True marks values to ignore.

    Returns:
      The merged ``(n, mu, sigma)`` triple, returned twice (as a pair) for
      callers that unpack a (carry, output) structure. Empty inputs yield
      zero mean/std instead of NaN.
    """
    valid = torch.logical_not(mask)
    patch_count = torch.sum(valid.to(x.dtype), dim=-1)
    patch_empty = patch_count == 0
    patch_denom = torch.where(patch_empty, 1.0, patch_count)

    # Mean and variance of the unmasked values in the new patch.
    patch_mean = torch.where(patch_empty, 0.0, torch.sum(x * valid, dim=-1) / patch_denom)
    squared_dev = (x - patch_mean.unsqueeze(-1)) ** 2
    patch_var = torch.where(patch_empty, 0.0, torch.sum(squared_dev * valid, dim=-1) / patch_denom)
    patch_std = torch.sqrt(patch_var)

    total = n + patch_count
    total_empty = total == 0
    total_denom = torch.where(total_empty, 1.0, total)

    merged_mean = torch.where(total_empty, 0.0, (n * mu + patch_mean * patch_count) / total_denom)

    # Parallel-variance merge: within-group spread plus the shift of each
    # group's mean relative to the merged mean.
    within_old = n * sigma.pow(2)
    within_new = patch_count * patch_std.pow(2)
    shift_old = n * (mu - merged_mean).pow(2)
    shift_new = patch_count * (patch_mean - merged_mean).pow(2)
    merged_var = torch.where(
        total_empty, 0.0, (within_old + within_new + shift_old + shift_new) / total_denom
    )
    merged_std = torch.sqrt(torch.clamp(merged_var, min=0.0))

    stats = (total, merged_mean, merged_std)
    return stats, stats
|
|
|
|
def revin(x: torch.Tensor, mu: torch.Tensor, sigma: torch.Tensor, reverse: bool = False) -> torch.Tensor:
    """Reversible instance normalization.

    ``mu``/``sigma`` with one or two fewer dimensions than ``x`` get that many
    trailing singleton axes so they broadcast over ``x``'s trailing axes.

    Forward (``reverse=False``) returns ``(x - mu) / sigma``, dividing by 1.0
    wherever ``sigma`` is below ``_TOLERANCE``. Reverse returns
    ``x * sigma + mu``.
    """
    rank_gap = len(x.shape) - len(mu.shape)
    if rank_gap in (1, 2):
        broadcast_index = (...,) + (None,) * rank_gap
        mu = mu[broadcast_index]
        sigma = sigma[broadcast_index]

    if not reverse:
        safe_sigma = torch.where(sigma < _TOLERANCE, torch.ones_like(sigma), sigma)
        return (x - mu) / safe_sigma
    return x * sigma + mu
|
|
|
|
class ResidualBlock(nn.Module):
    """Two-layer MLP with a parallel linear skip connection.

    ``forward(x) = output_layer(act(hidden_layer(x))) + residual_layer(x)``.
    """

    def __init__(self, config: ResidualBlockConfig):
        super().__init__()
        self.activation = self._resolve_activation(config.activation)
        with_bias = config.use_bias
        self.hidden_layer = nn.Linear(config.input_dims, config.hidden_dims, bias=with_bias)
        self.output_layer = nn.Linear(config.hidden_dims, config.output_dims, bias=with_bias)
        self.residual_layer = nn.Linear(config.input_dims, config.output_dims, bias=with_bias)

    @staticmethod
    def _resolve_activation(name: str) -> nn.Module:
        """Build the activation module for ``name``; raise ValueError if unknown."""
        known = (("relu", nn.ReLU), ("swish", nn.SiLU), ("none", nn.Identity))
        for key, module_cls in known:
            if name == key:
                return module_cls()
        raise ValueError(f"Unsupported activation: {name}")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        mlp_branch = self.output_layer(self.activation(self.hidden_layer(x)))
        return mlp_branch + self.residual_layer(x)
|
|
|
|
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last axis with a learned gain.

    The gain ``scale`` is zero-initialized and multiplies the normalized input
    directly (no ``1 + scale`` offset), so the output is all zeros until
    weights are loaded.
    """

    def __init__(self, num_features: int, epsilon: float = 1e-6):
        super().__init__()
        self.scale = nn.Parameter(torch.zeros(num_features))
        self.epsilon = epsilon

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        mean_square = inputs.square().mean(dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.epsilon)
        return inputs * inv_rms * self.scale
|
|
|
|
def make_attn_mask(
    query_length: int,
    num_all_masked_kv: torch.Tensor,
    query_index_offset: torch.Tensor | None = None,
    kv_length: int = 0,
) -> torch.Tensor:
    """Build a boolean causal attention mask (True = may attend).

    Query row ``i`` sits at absolute position ``i + query_index_offset[b]``.
    It may attend to key ``j`` when ``j`` is not after that position and
    ``j >= num_all_masked_kv[b]``, i.e. the leading masked keys are excluded.
    ``kv_length`` of 0 means "same as ``query_length``".

    Returns a tensor broadcastable to (batch, 1, query_length, kv_length).
    """
    device = num_all_masked_kv.device
    keys_total = query_length if kv_length == 0 else kv_length

    query_pos = torch.arange(query_length, device=device)[None, None, :, None]
    if query_index_offset is not None:
        query_pos = query_pos + query_index_offset[:, None, None, None]
    key_pos = torch.arange(keys_total, device=device)[None, None, None, :]
    first_unmasked = num_all_masked_kv[:, None, None, None]

    is_causal = query_pos >= key_pos
    is_unmasked = key_pos >= first_unmasked
    return torch.logical_and(is_causal, is_unmasked)
|
|
|
|
class RotaryPositionalEmbedding(nn.Module):
    """Rotary position embeddings (RoPE) for query/key projections.

    The last axis is split into halves ``(a, b)``. Pair ``(a[k], b[k])`` is
    rotated by the angle ``position / timescale[k]``, where timescales grow
    geometrically from ``min_timescale`` towards ``max_timescale``.
    """

    def __init__(self, embedding_dims: int, min_timescale: float = 1.0, max_timescale: float = 10000.0):
        super().__init__()
        self.embedding_dims = embedding_dims
        self.min_timescale = min_timescale
        self.max_timescale = max_timescale

    def _timescales(self, device: torch.device) -> torch.Tensor:
        """Per-frequency timescales for one half of the embedding (1-D)."""
        exponent = 2 * torch.arange(self.embedding_dims // 2, device=device) / self.embedding_dims
        ratio = self.max_timescale / self.min_timescale
        return (self.min_timescale * ratio ** exponent).to(device)

    def forward(self, inputs: torch.Tensor, position: torch.Tensor | None = None) -> torch.Tensor:
        """Rotate ``inputs`` (rank 3 or 4, last axis == embedding_dims).

        ``position`` defaults to ``0..inputs.shape[1]-1`` shared by the batch.
        """
        if inputs.shape[-1] != self.embedding_dims:
            raise ValueError("Rotary embedding dimension must equal the head dimension.")

        timescale = self._timescales(inputs.device)
        if position is None:
            position = torch.arange(inputs.shape[1], dtype=torch.float32, device=inputs.device)[None, :]

        rank = len(inputs.shape)
        if rank not in (3, 4):
            raise ValueError("Expected rank-3 or rank-4 tensor for rotary embeddings.")
        # Broadcast positions over the trailing axes and timescales over the leading ones.
        for _ in range(rank - 2):
            position = position.unsqueeze(-1)
        timescale = timescale.reshape((1,) * (rank - 1) + (-1,))

        angle = position / timescale
        sin, cos = torch.sin(angle), torch.cos(angle)
        left, right = torch.chunk(inputs, 2, dim=-1)
        return torch.cat([left * cos - right * sin, right * cos + left * sin], dim=-1)
|
|
|
|
class PerDimScale(nn.Module):
    """Learned per-dimension scaling used prior to attention.

    Multiplies the last axis by ``1.442695041 / sqrt(num_dims) * softplus(w)``.
    Since 1.442695041 ~= 1 / ln 2 = 1 / softplus(0), the zero-initialized
    weight gives the usual ``1 / sqrt(num_dims)`` factor.
    """

    def __init__(self, num_dims: int):
        super().__init__()
        self.num_dims = num_dims
        self.per_dim_scale = nn.Parameter(torch.zeros(num_dims))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        inv_softplus_zero = 1.442695041
        factor = inv_softplus_zero / math.sqrt(self.num_dims) * F.softplus(self.per_dim_scale)
        return x * factor
|
|
|
|
class MultiHeadAttention(nn.Module):
    """Multi-head attention supporting fused QKV projections and caching.

    Args:
      num_heads: number of attention heads; must divide ``in_features``.
      in_features: model width of inputs and outputs.
      use_per_dim_scale: scale queries with a learned ``PerDimScale``; the
        attention call itself uses ``scale=1.0``.
      use_rotary_position_embeddings: apply rotary embeddings to queries/keys.
      use_bias: add bias terms to the q/k/v and output projections.
      attention_fn: callable with the ``F.scaled_dot_product_attention``
        calling convention ``(query, key, value, attn_mask=..., scale=...)``,
        applied to (batch, heads, seq, head_dim) tensors.
      qk_norm: "rms" applies RMSNorm to each query/key head; other values
        disable it.
      fuse_qkv: compute q/k/v with one projection (``qkv_proj``) instead of
        three (``query``/``key``/``value``).

    Raises:
      ValueError: if ``in_features`` is not divisible by ``num_heads``.
    """

    def __init__(
        self,
        num_heads: int,
        in_features: int,
        *,
        use_per_dim_scale: bool = True,
        use_rotary_position_embeddings: bool = True,
        use_bias: bool = False,
        attention_fn=F.scaled_dot_product_attention,
        qk_norm: str = "rms",
        fuse_qkv: bool = False,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.in_features = in_features
        self.head_dim = in_features // num_heads
        self.use_bias = use_bias
        self.attention_fn = attention_fn
        self.qk_norm = qk_norm
        self.fuse_qkv = fuse_qkv

        if in_features % num_heads != 0:
            raise ValueError(f"Model dimension {in_features} must be divisible by {num_heads} heads.")

        # Attribute names determine state_dict keys; keep them stable.
        if fuse_qkv:
            self.qkv_proj = nn.Linear(in_features, 3 * in_features, bias=use_bias)
        else:
            self.query = nn.Linear(in_features, in_features, bias=use_bias)
            self.key = nn.Linear(in_features, in_features, bias=use_bias)
            self.value = nn.Linear(in_features, in_features, bias=use_bias)

        self.out = nn.Linear(in_features, in_features, bias=use_bias)

        if qk_norm == "rms":
            self.query_ln = RMSNorm(self.head_dim)
            self.key_ln = RMSNorm(self.head_dim)
        else:
            self.query_ln = nn.Identity()
            self.key_ln = nn.Identity()

        self.use_rotary_position_embeddings = use_rotary_position_embeddings
        if use_rotary_position_embeddings:
            self.rotary_position_embedding = RotaryPositionalEmbedding(self.head_dim)

        self.use_per_dim_scale = use_per_dim_scale
        if use_per_dim_scale:
            self.per_dim_scale = PerDimScale(self.head_dim)

    def forward(
        self,
        inputs_q: torch.Tensor,
        *,
        decode_cache: DecodeCache | None = None,
        patch_mask: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, DecodeCache | None]:
        """Self-attention over ``inputs_q``, optionally extending a KV cache.

        Args:
          inputs_q: (batch, num_patches, in_features) embeddings.
          decode_cache: when given, the new keys/values are written at
            ``next_index[0]`` (the same slot for every batch row), attention
            runs over the whole cache, and the cache is mutated in place.
          patch_mask: (batch, num_patches) bool, True for masked patches
            (defaults to all False). Only the number of masked patches is
            used: the mask excludes that many leading key positions, so masked
            patches are expected to come first.

        Returns:
          ``(output, decode_cache)`` with output of shape
          (batch, num_patches, in_features).
        """
        batch, num_patches, _ = inputs_q.shape
        if patch_mask is None:
            patch_mask = torch.zeros(batch, num_patches, dtype=torch.bool, device=inputs_q.device)

        head_shape = (batch, num_patches, self.num_heads, self.head_dim)
        if self.fuse_qkv:
            query, key, value = torch.chunk(self.qkv_proj(inputs_q), 3, dim=-1)
        else:
            query, key, value = self.query(inputs_q), self.key(inputs_q), self.value(inputs_q)
        query = query.view(head_shape)
        key = key.view(head_shape)
        value = value.view(head_shape)

        # Masked-patch count and write offset, accumulated across cached calls.
        num_masked = torch.sum(patch_mask.to(torch.int32), dim=-1)
        if decode_cache is None:
            next_index = torch.zeros_like(num_masked, dtype=torch.int32)
        else:
            num_masked = num_masked + decode_cache.num_masked
            next_index = decode_cache.next_index.clone()

        if self.use_rotary_position_embeddings:
            # Positions count from the first unmasked patch.
            position = (
                torch.arange(num_patches, device=inputs_q.device)[None, :]
                + next_index[:, None]
                - num_masked[:, None]
            )
            query = self.rotary_position_embedding(query, position)
            key = self.rotary_position_embedding(key, position)

        query = self.query_ln(query)
        key = self.key_ln(key)

        if self.use_per_dim_scale:
            query = self.per_dim_scale(query)

        if decode_cache is not None:
            _, cache_size, _, _ = decode_cache.value.shape
            start = decode_cache.next_index[0]
            end = start + num_patches

            decode_cache.key[:, start:end] = key
            decode_cache.value[:, start:end] = value

            key = decode_cache.key
            value = decode_cache.value
            decode_cache.next_index += num_patches
            decode_cache.num_masked = num_masked
            attn_mask = make_attn_mask(
                query_length=num_patches,
                num_all_masked_kv=num_masked,
                query_index_offset=next_index,
                kv_length=cache_size,
            )
        else:
            attn_mask = make_attn_mask(query_length=num_patches, num_all_masked_kv=num_masked)

        # Use the configured kernel (previously hard-coded, ignoring attention_fn).
        attn_output = self.attention_fn(
            query.permute(0, 2, 1, 3),
            key.permute(0, 2, 1, 3),
            value.permute(0, 2, 1, 3),
            attn_mask=attn_mask,
            scale=1.0,
        )
        attn_output = attn_output.permute(0, 2, 1, 3).reshape(batch, num_patches, self.in_features)
        return self.out(attn_output), decode_cache
|
|
|
|
class Transformer(nn.Module):
    """Transformer block used by TimesFM.

    Pre-norm attention and feed-forward sub-layers. Each sub-layer output is
    passed through an additional RMSNorm before the residual addition.

    Raises:
      ValueError: if the config requests a normalization other than "rms".
    """

    def __init__(self, config: TransformerConfig):
        super().__init__()
        if config.attention_norm != "rms" or config.feedforward_norm != "rms":
            raise ValueError("Only RMS normalization is supported.")

        self.pre_attn_ln = RMSNorm(config.model_dims)
        self.post_attn_ln = RMSNorm(config.model_dims)
        self.attn = MultiHeadAttention(
            num_heads=config.num_heads,
            in_features=config.model_dims,
            use_per_dim_scale=True,
            use_rotary_position_embeddings=config.use_rotary_position_embeddings,
            # Honor the config's bias flag for attention too, as the
            # feed-forward layers below do (it was previously dropped here).
            use_bias=config.use_bias,
            qk_norm=config.qk_norm,
            fuse_qkv=config.fuse_qkv,
        )

        self.pre_ff_ln = RMSNorm(config.model_dims)
        self.post_ff_ln = RMSNorm(config.model_dims)
        self.ff0 = nn.Linear(config.model_dims, config.hidden_dims, bias=config.use_bias)
        self.ff1 = nn.Linear(config.hidden_dims, config.model_dims, bias=config.use_bias)
        self.activation = ResidualBlock._resolve_activation(config.ff_activation)

    def forward(
        self,
        input_embeddings: torch.Tensor,
        patch_mask: torch.Tensor,
        decode_cache: DecodeCache | None = None,
    ) -> tuple[torch.Tensor, DecodeCache | None]:
        """Apply attention then feed-forward; returns ``(embeddings, decode_cache)``.

        ``patch_mask`` and ``decode_cache`` are forwarded to the attention layer,
        which may mutate the cache in place.
        """
        attn_output, decode_cache = self.attn(
            inputs_q=self.pre_attn_ln(input_embeddings),
            decode_cache=decode_cache,
            patch_mask=patch_mask,
        )
        attn_output = self.post_attn_ln(attn_output) + input_embeddings
        feedforward = self.ff1(self.activation(self.ff0(self.pre_ff_ln(attn_output))))
        output_embeddings = self.post_ff_ln(feedforward) + attn_output
        return output_embeddings, decode_cache
|
|
|
|
class TimesFM2Core(nn.Module):
    """Core TimesFM 2.x backbone without external dependencies.

    Consists of a residual-MLP patch tokenizer, a stack of transformer layers
    and two residual-MLP heads (point forecast and quantile spread).
    ``decode`` adds per-patch reversible normalization and patch-wise
    autoregressive decoding on top.
    """

    def __init__(self, definition: TimesFM2Definition | None = None):
        super().__init__()
        self.config = definition or TimesFM2Definition()

        # Short aliases for dimensions used throughout decoding.
        self.p = self.config.input_patch_len  # time steps per input patch
        self.o = self.config.output_patch_len  # point-head steps per patch
        self.os = self.config.output_quantile_len  # quantile-head steps per patch
        self.m = self.o // self.p  # input patches made from one output patch
        self.x = self.config.stacked_transformers.num_layers
        self.h = self.config.stacked_transformers.transformer.num_heads
        self.md = self.config.stacked_transformers.transformer.model_dims
        self.hd = self.md // self.h  # per-head dimension
        self.q = len(self.config.quantiles) + 1  # output channels per time step
        self.aridx = self.config.decode_index  # channel fed back during decoding

        self.tokenizer = ResidualBlock(self.config.tokenizer)
        self.stacked_xf = nn.ModuleList(
            [Transformer(self.config.stacked_transformers.transformer) for _ in range(self.x)]
        )
        self.output_projection_point = ResidualBlock(self.config.output_projection_point)
        self.output_projection_quantiles = ResidualBlock(self.config.output_projection_quantiles)

    def load_safetensors(self, path: str, strict: bool = True) -> None:
        """Load weights from a ``.safetensors`` file and switch to eval mode.

        Raises:
          ImportError: if the optional ``safetensors`` package is missing.
        """
        if _load_safetensors is None:
            raise ImportError("Install safetensors to load TimesFM2 checkpoints.")
        tensors = _load_safetensors(path)
        self.load_state_dict(tensors, strict=strict)
        self.eval()

    def forward(
        self,
        inputs: torch.Tensor,
        masks: torch.Tensor,
        decode_caches: list[DecodeCache] | None = None,
    ) -> tuple[tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], list[DecodeCache]]:
        """Run tokenizer, transformer stack and both heads.

        Args:
          inputs: (batch, num_patches, p) patch values (normalized by ``decode``).
          masks: same shape; a patch counts as masked when its last element is.
          decode_caches: one cache per layer, or None for no caching.

        Returns:
          ``((input_embeddings, output_embeddings, point_out, quantile_out),
          caches)`` where the caches list holds whatever each layer returned.
        """
        tokenizer_inputs = torch.cat([inputs, masks.to(inputs.dtype)], dim=-1)
        input_embeddings = self.tokenizer(tokenizer_inputs)

        if decode_caches is None:
            decode_caches = [None] * self.x

        output_embeddings = input_embeddings
        new_decode_caches: list[DecodeCache] = []
        for layer, cache in zip(self.stacked_xf, decode_caches):
            output_embeddings, new_cache = layer(output_embeddings, masks[..., -1], cache)
            new_decode_caches.append(new_cache)

        output_ts = self.output_projection_point(output_embeddings)
        output_quantile_spread = self.output_projection_quantiles(output_embeddings)
        return (input_embeddings, output_embeddings, output_ts, output_quantile_spread), new_decode_caches

    def _accumulate_stats(
        self,
        n: torch.Tensor,
        mu: torch.Tensor,
        sigma: torch.Tensor,
        patches: torch.Tensor,
        patch_masks: torch.Tensor,
    ) -> tuple[tuple[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]:
        """Fold (batch, k, p) patches one at a time into running stats.

        Returns the final ``(n, mu, sigma)`` plus (batch, k) stacks of the mean
        and std reached after each patch.
        """
        mus: list[torch.Tensor] = []
        sigmas: list[torch.Tensor] = []
        for i in range(patches.shape[1]):
            (n, mu, sigma), _ = update_running_stats(n, mu, sigma, patches[:, i], patch_masks[:, i])
            mus.append(mu)
            sigmas.append(sigma)
        return (n, mu, sigma), torch.stack(mus, dim=1), torch.stack(sigmas, dim=1)

    def _new_decode_caches(self, batch_size: int, cache_size: int, like: torch.Tensor) -> list[DecodeCache]:
        """Allocate one zeroed KV cache per layer on ``like``'s device/dtype."""

        def buffer() -> torch.Tensor:
            return torch.zeros(batch_size, cache_size, self.h, self.hd, device=like.device, dtype=like.dtype)

        return [
            DecodeCache(
                next_index=torch.zeros(batch_size, dtype=torch.int32, device=like.device),
                num_masked=torch.zeros(batch_size, dtype=torch.int32, device=like.device),
                key=buffer(),
                value=buffer(),
            )
            for _ in range(self.x)
        ]

    def decode(
        self,
        horizon: int,
        inputs: torch.Tensor,
        masks: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
        """Autoregressively decodes a batch of sequences.

        Args:
          horizon: forecast steps requested. Output patches cover ``self.o``
            steps each, so the outputs may extend past ``horizon``; values
            <= 0 skip autoregressive decoding.
          inputs: (batch, context) values; ``context`` must be a multiple of
            ``self.p`` for the patch reshape.
          masks: (batch, context) bool, True where ``inputs`` is padding.

        Returns:
          renormed_outputs: (batch, num_input_patches, o, q) point-head output
            for every context patch, de-normalized.
          renormed_quantile_spread: (batch, os, q) quantile-head output of the
            last context patch, de-normalized.
          ar_renormed_outputs: (batch, num_decode_steps, o, q) outputs of the
            autoregressive steps, or None when no extra step is needed.
        """
        batch_size, context = inputs.shape
        # Clamp so a non-positive horizon cannot shrink the KV cache below the
        # prefill length.
        num_decode_steps = max(0, (horizon - 1) // self.o)
        num_input_patches = context // self.p
        # The KV cache is written in place, which can invalidate tensors that
        # autograd saved in earlier steps, so it is only used with grad off.
        # Without it every AR step re-runs the full (causal) history instead.
        use_cache = not torch.is_grad_enabled()

        patched_inputs = torch.reshape(inputs, (batch_size, -1, self.p))
        patched_masks = torch.reshape(masks, (batch_size, -1, self.p))

        # Patch i is normalized with the statistics of patches 0..i.
        zeros = torch.zeros(batch_size, device=inputs.device, dtype=inputs.dtype)
        (last_n, last_mu, last_sigma), context_mu, context_sigma = self._accumulate_stats(
            zeros, zeros, zeros, patched_inputs, patched_masks
        )

        decode_caches: list[DecodeCache] | None = None
        if use_cache:
            decode_caches = self._new_decode_caches(
                batch_size, num_input_patches + num_decode_steps * self.m, inputs
            )

        normed_inputs = revin(patched_inputs, context_mu, context_sigma, reverse=False)
        normed_inputs = torch.where(
            patched_masks, torch.zeros((), device=inputs.device, dtype=inputs.dtype), normed_inputs
        )
        (_, _, normed_outputs, normed_quantile_spread), decode_caches = self(
            normed_inputs, patched_masks, decode_caches
        )
        # Everything fed to the model so far (used only by the cache-free path).
        history_inputs, history_masks = normed_inputs, patched_masks

        renormed_outputs = torch.reshape(
            revin(normed_outputs, context_mu, context_sigma, reverse=True),
            (batch_size, -1, self.o, self.q),
        )
        renormed_quantile_spread = torch.reshape(
            revin(normed_quantile_spread, context_mu, context_sigma, reverse=True),
            (batch_size, -1, self.os, self.q),
        )[:, -1, ...]

        ar_outputs: list[torch.Tensor] = []
        last_renormed_output = renormed_outputs[:, -1, :, self.aridx]

        for _ in range(num_decode_steps):
            # The last predicted output patch becomes m new input patches.
            new_patched_input = torch.reshape(last_renormed_output, (batch_size, self.m, self.p))
            new_mask = torch.zeros_like(new_patched_input, dtype=torch.bool)

            (last_n, last_mu, last_sigma), new_mu, new_sigma = self._accumulate_stats(
                last_n, last_mu, last_sigma, new_patched_input, new_mask
            )
            new_normed_input = revin(new_patched_input, new_mu, new_sigma, reverse=False)

            if use_cache:
                (_, _, new_normed_output, _), decode_caches = self(new_normed_input, new_mask, decode_caches)
            else:
                # Causal attention makes the outputs for the new patches match
                # the cached path (up to floating-point effects).
                history_inputs = torch.cat([history_inputs, new_normed_input], dim=1)
                history_masks = torch.cat([history_masks, new_mask], dim=1)
                (_, _, full_normed_output, _), _ = self(history_inputs, history_masks)
                new_normed_output = full_normed_output[:, -self.m :]

            new_renormed_output = torch.reshape(
                revin(new_normed_output, new_mu, new_sigma, reverse=True),
                (batch_size, self.m, self.o, self.q),
            )
            ar_outputs.append(new_renormed_output[:, -1, ...])
            last_renormed_output = new_renormed_output[:, -1, :, self.aridx]

        ar_renormed_outputs = torch.stack(ar_outputs, dim=1) if ar_outputs else None
        return renormed_outputs, renormed_quantile_spread, ar_renormed_outputs
|
|
|
|
class TimesFM2(nn.Module):
    """High-level TimesFM 2.x wrapper mirroring the TimesFM interface.

    Uses the last ``lookback`` values of each series to produce a
    ``lookahead``-step forecast. For inference, call under ``torch.no_grad()``
    so the core's KV cache is used.

    Args:
      lookback: trailing context steps used; positive and at most the
        model's context limit.
      lookahead: forecast steps returned; positive.

    Raises:
      ValueError: on a non-positive ``lookback``/``lookahead`` or a
        ``lookback`` above the context limit.
    """

    def __init__(self, lookback: int = 512, lookahead: int = 96):
        super().__init__()
        if lookback <= 0:
            raise ValueError(f"lookback must be positive, received {lookback}.")
        if lookahead <= 0:
            raise ValueError(f"lookahead must be positive, received {lookahead}.")
        self.lookback = lookback
        self.lookahead = lookahead

        # Validate before allocating the backbone's weights.
        definition = TimesFM2Definition()
        if lookback > definition.context_limit:
            raise ValueError(
                f"lookback ({lookback}) exceeds maximum context limit ({definition.context_limit})."
            )
        self.core = TimesFM2Core(definition)

    def load_state_dict(self, state_dict, strict: bool = True, **kwargs):
        """Delegate to the core so checkpoint keys carry no ``core.`` prefix."""
        return self.core.load_state_dict(state_dict, strict=strict, **kwargs)

    def state_dict(self, *args, **kwargs):
        """Delegate to the core so saved keys carry no ``core.`` prefix."""
        return self.core.state_dict(*args, **kwargs)

    def load_safetensors(self, path: str, strict: bool = True) -> None:
        self.core.load_safetensors(path, strict=strict)

    def _prepare_inputs(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Take the last ``lookback`` steps and left-pad to a whole number of patches.

        Returns ``(context, mask)`` of equal shape; ``mask`` is True on padding.

        Raises:
          ValueError: if ``x`` has fewer than ``lookback`` time steps.
        """
        if x.shape[1] < self.lookback:
            raise ValueError(f"Expected at least {self.lookback} context steps, received {x.shape[1]}.")
        context = x[:, -self.lookback :]
        pad_len = (-context.shape[1]) % self.core.p
        if pad_len > 0:
            context = F.pad(context, (pad_len, 0))
        mask = torch.zeros_like(context, dtype=torch.bool)
        mask[:, :pad_len] = True

        limit = self.core.config.context_limit
        if context.shape[1] > limit:
            context = context[:, -limit:]
            mask = mask[:, -limit:]

        return context, mask

    def forward(
        self,
        x: torch.Tensor,
        *,
        return_quantiles: bool = False,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Forecast ``lookahead`` steps for each series in ``x`` (batch, time).

        Returns the (batch, lookahead) point forecast (channel ``decode_index``),
        or ``(point, full)`` with ``full`` of shape (batch, lookahead, q) when
        ``return_quantiles`` is set.
        """
        if x.dim() != 2:
            raise ValueError(f"Expected input tensor of shape (batch, time), received {tuple(x.shape)}.")

        inputs, mask = self._prepare_inputs(x.to(dtype=torch.float32))
        renormed_outputs, _, ar_outputs = self.core.decode(self.lookahead, inputs, mask)
        batch_size = inputs.shape[0]

        # Last context patch's forecast, followed by the autoregressive patches.
        to_cat = [renormed_outputs[:, -1, ...]]
        if ar_outputs is not None:
            to_cat.append(ar_outputs.reshape(batch_size, -1, self.core.q))
        full_forecast = torch.cat(to_cat, dim=1)[:, : self.lookahead, :]

        point_forecast = full_forecast[..., self.core.aridx]
        if return_quantiles:
            return point_forecast, full_forecast
        return point_forecast
|
|
|
|
# Public API of this module; the building blocks above are implementation details.
__all__ = ["TimesFM2", "TimesFM2Core", "TimesFM2Definition"]
|
|