import torch
import torch.nn as nn

# -------------------------
# Timestep embeddings
# -------------------------


class GaussianFourierProjection(nn.Module):
    """
    Gaussian Fourier features for a continuous time value t (typically in [0, 1]).

    Draws ``embed_dim // 2`` fixed random frequencies ``W = randn * scale`` and
    maps t to ``[sin(W t), cos(W t)]``, i.e. ``embed_dim`` features in total
    (first half sines, second half cosines).

    Args:
        embed_dim: Number of output features. Must be even.
        scale: Multiplier applied to standard-normal frequencies (their std).

    Raises:
        ValueError: If ``embed_dim`` is odd.
    """

    def __init__(self, embed_dim: int, scale: float):
        super().__init__()
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if embed_dim % 2 != 0:
            raise ValueError("embed_dim must be even.")
        self.embed_dim = embed_dim
        # Fixed (non-trainable) random frequencies. Registered as a *persistent*
        # buffer so they are saved in state_dict(): downstream layers are trained
        # against these exact frequencies, and re-drawing them on reload would
        # silently change the embedding.
        self.register_buffer("W", torch.randn(embed_dim // 2) * scale)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        # Checkpoints written when W was a non-persistent buffer contain no entry
        # for it; keep the current frequencies (the old behavior) rather than
        # failing a strict load with a missing key.
        key = prefix + "W"
        if key not in state_dict:
            state_dict[key] = self.W.detach().clone()
        super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
                                      missing_keys, unexpected_keys, error_msgs)

    def forward(self, t) -> torch.Tensor:
        """
        Args:
            t: Time values as a tensor of any shape ``(*)`` (e.g. ``(B,)``) or a
                Python scalar; integer or floating point.

        Returns:
            Tensor of shape ``(*, embed_dim)`` with W's dtype, on W's device.
        """
        # Follow W's dtype/device so the module keeps working after .to(),
        # .double(), .half(), etc. (a hard .float() breaks low-precision models).
        t = torch.as_tensor(t, dtype=self.W.dtype, device=self.W.device)
        angles = t.unsqueeze(-1) * self.W  # (*, 1) * (embed_dim // 2,) -> (*, embed_dim // 2)
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)


class TimeEmbedding(nn.Module):
    """
    Map time values to a ``hidden_dim``-dimensional embedding.

    Gaussian Fourier features (``fourier_dim`` of them) followed by a two-layer
    MLP with a SiLU activation in between.

    Args:
        hidden_dim: Width of the MLP hidden layer and of the output.
        fourier_dim: Number of Fourier features; must be even.
        scale: Multiplier for the random Fourier frequencies.

    Raises:
        ValueError: If ``fourier_dim`` is odd.
    """

    def __init__(self, hidden_dim: int, fourier_dim: int, scale: float):
        super().__init__()
        # Checked here too so the error names this module's own parameter.
        if fourier_dim % 2 != 0:
            raise ValueError("fourier_dim must be even for sine/cosine pairs.")
        self.fourier = GaussianFourierProjection(fourier_dim, scale)
        self.mlp = nn.Sequential(
            nn.Linear(fourier_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

    def forward(self, t) -> torch.Tensor:
        """Embed ``t`` of shape ``(*)`` (e.g. ``(B,)``) into ``(*, hidden_dim)``."""
        ft = self.fourier(t)  # (*, fourier_dim)
        return self.mlp(ft)  # (*, hidden_dim)