| | import math |
| |
|
| | import torch |
| | import torch.nn as nn |
| |
|
| |
|
class PhaseModulatedFourierEmbedder(torch.nn.Module):
    """Fourier-feature embedder combining a learned and a fixed-phase branch.

    Each input coordinate is expanded into ``num_freqs`` cosine and
    ``num_freqs`` sine features, and the raw input is prepended. The result has
    ``out_dim = input_dim * (2 * num_freqs + 1)`` channels in its last dimension.
    """

    def __init__(
        self,
        num_freqs: int,
        input_dim: int = 3,
    ):
        """
        Initializes the PhaseModulatedFourierEmbedder.

        Args:
            num_freqs (int): Number of frequencies per input coordinate.
            input_dim (int, optional): Size of the input's last dimension. Defaults to 3.

        Attributes:
            weight (torch.nn.Parameter): Learnable ``(input_dim, num_freqs)`` frequency
                matrix, drawn from a standard normal and scaled by ``sqrt(0.5 * num_freqs)``.
            carrier (torch.Tensor): ``num_freqs`` fixed phase offsets,
                ``2*pi * ((num_freqs / 8) ** t + (1 - t))`` with ``t`` running linearly
                from 1 down to 0. Registered as a non-persistent buffer, so it follows
                ``.to()``/``.cuda()`` but is not stored in ``state_dict``.
            out_dim (int): Output feature size, ``input_dim * (2 * num_freqs + 1)``.
        """
        super().__init__()

        self.weight = nn.Parameter(
            torch.randn(input_dim, num_freqs) * math.sqrt(0.5 * num_freqs)
        )

        # Exponential ramp (num_freqs/8)**t plus a linear ramp (1 - t), scaled to radians.
        carrier = (num_freqs / 8) ** torch.linspace(1, 0, num_freqs)
        carrier = (carrier + torch.linspace(0, 1, num_freqs)) * 2 * torch.pi
        self.register_buffer("carrier", carrier, persistent=False)

        # Raw input + cos features + sin features.
        self.out_dim = input_dim * (num_freqs * 2 + 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Embed ``x`` with phase-modulated Fourier features.

        Args:
            x (torch.Tensor): Input of shape ``(..., input_dim)``. Leading dimensions
                may be zero-sized (e.g. an empty batch).

        Returns:
            torch.Tensor: Tensor of shape ``(..., out_dim)`` with
            ``out_dim = input_dim * (2 * num_freqs + 1)``. The first ``input_dim``
            channels are ``x`` itself. They are followed by ``input_dim * num_freqs``
            channels of ``cos(fm) + cos(pm)`` and ``input_dim * num_freqs`` channels
            of ``sin(fm) + sin(pm)``.
        """
        # (..., input_dim, 1); the trig features are computed from a float32 copy of x.
        m = x.float().unsqueeze(-1)
        # Learned frequency branch: (..., input_dim, num_freqs) -> (..., input_dim * num_freqs).
        # flatten() rather than view(..., -1): with a zero-sized leading dim the -1
        # is ambiguous and view() raises.
        fm = (m * self.weight).flatten(-2)
        # Fixed phase branch: x * pi/2 shifted by each carrier phase.
        pm = (m * 0.5 * torch.pi + self.carrier).flatten(-2)
        embedding = torch.cat([x, fm.cos() + pm.cos(), fm.sin() + pm.sin()], dim=-1)

        return embedding
| |
|