| |
| |
| |
| |
| |
| |
| |
|
|
| import math |
| import torch |
| from einops import rearrange |
|
|
| from .base import BaseModule |
|
|
|
|
class Mish(BaseModule):
    """Mish activation: ``x * tanh(softplus(x))``."""

    def forward(self, x):
        softplus_x = torch.nn.functional.softplus(x)
        return x * torch.tanh(softplus_x)
|
|
|
|
class Upsample(BaseModule):
    """2x spatial upsampling via a stride-2 transposed convolution.

    Channel count is preserved (dim -> dim).
    """

    def __init__(self, dim):
        super(Upsample, self).__init__()
        # kernel=4, stride=2, padding=1 exactly doubles H and W.
        self.conv = torch.nn.ConvTranspose2d(dim, dim, 4, 2, 1)

    def forward(self, x):
        return self.conv(x)
|
|
|
|
class Downsample(BaseModule):
    """2x spatial downsampling via a stride-2 convolution.

    Channel count is preserved (dim -> dim).
    """

    def __init__(self, dim):
        super(Downsample, self).__init__()
        # kernel=3, stride=2, padding=1 halves H and W (rounding up).
        self.conv = torch.nn.Conv2d(dim, dim, 3, 2, 1)

    def forward(self, x):
        return self.conv(x)
|
|
|
|
class Rezero(BaseModule):
    """ReZero wrapper: scales ``fn``'s output by a learned scalar gate.

    The gate starts at zero, so the wrapped branch initially contributes
    nothing and its influence is learned during training.
    """

    def __init__(self, fn):
        super(Rezero, self).__init__()
        self.fn = fn
        self.g = torch.nn.Parameter(torch.zeros(1))

    def forward(self, x):
        return self.g * self.fn(x)
|
|
|
|
class Block(BaseModule):
    """Basic conv unit: 3x3 conv -> GroupNorm -> Mish."""

    def __init__(self, dim, dim_out, groups=8):
        super(Block, self).__init__()
        self.block = torch.nn.Sequential(
            torch.nn.Conv2d(dim, dim_out, 3, padding=1),
            torch.nn.GroupNorm(groups, dim_out),
            Mish(),
        )

    def forward(self, x):
        return self.block(x)
|
|
|
|
class ResnetBlock(BaseModule):
    """Residual block with a time-embedding injection between two conv Blocks.

    The time embedding is projected to ``dim_out`` channels and added as a
    per-channel bias after the first Block. A 1x1 conv matches the skip
    connection's channels when ``dim != dim_out``.
    """

    def __init__(self, dim, dim_out, time_emb_dim, groups=8):
        super(ResnetBlock, self).__init__()
        self.mlp = torch.nn.Sequential(
            Mish(),
            torch.nn.Linear(time_emb_dim, dim_out),
        )

        self.block1 = Block(dim, dim_out, groups=groups)
        self.block2 = Block(dim_out, dim_out, groups=groups)
        # Identity skip when channels already match, else 1x1 projection.
        self.res_conv = (torch.nn.Conv2d(dim, dim_out, 1)
                         if dim != dim_out else torch.nn.Identity())

    def forward(self, x, time_emb):
        hidden = self.block1(x)
        # Broadcast the (B, C) time embedding over the spatial dims.
        hidden = hidden + self.mlp(time_emb).unsqueeze(-1).unsqueeze(-1)
        hidden = self.block2(hidden)
        return hidden + self.res_conv(x)
|
|
|
|
class LinearAttention(BaseModule):
    """Multi-head linear attention over 2D feature maps.

    Computes attention in O(N) in the number of spatial positions by
    contracting keys with values first (the "linear attention" trick),
    rather than forming an N x N attention matrix.
    """

    def __init__(self, dim, heads=4, dim_head=32, q_norm=True):
        super(LinearAttention, self).__init__()
        self.heads = heads
        inner_dim = heads * dim_head
        self.to_qkv = torch.nn.Conv2d(dim, inner_dim * 3, 1, bias=False)
        self.to_out = torch.nn.Conv2d(inner_dim, dim, 1)
        # When True, queries are also softmax-normalized (over channels).
        self.q_norm = q_norm

    def forward(self, x):
        height, width = x.shape[2], x.shape[3]
        q, k, v = rearrange(self.to_qkv(x),
                            'b (qkv heads c) h w -> qkv b heads c (h w)',
                            heads=self.heads, qkv=3)
        # Keys normalized across spatial positions.
        k = k.softmax(dim=-1)
        if self.q_norm:
            # Queries normalized across the head-channel dimension.
            q = q.softmax(dim=-2)
        # Aggregate values against keys first: (d x e) context per head.
        context = torch.einsum('bhdn,bhen->bhde', k, v)
        out = torch.einsum('bhde,bhdn->bhen', context, q)
        out = rearrange(out, 'b heads c (h w) -> b (heads c) h w',
                        heads=self.heads, h=height, w=width)
        return self.to_out(out)
|
|
|
|
class Residual(BaseModule):
    """Adds a skip connection around an arbitrary callable: ``fn(x) + x``."""

    def __init__(self, fn):
        super(Residual, self).__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        return self.fn(x, *args, **kwargs) + x
|
|
|
|
def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
):
    """Create sinusoidal timestep embeddings (DDPM-style).

    :param timesteps: 1-D tensor of N indices, one per batch element
        (may be fractional).
    :param embedding_dim: dimension of the output embedding.
    :param flip_sin_to_cos: if True, swap the sin and cos halves.
    :param downscale_freq_shift: offset subtracted from ``half_dim`` in the
        frequency denominator.
    :param scale: multiplier applied to the phase before sin/cos.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x embedding_dim] tensor of positional embeddings.
    """
    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

    half_dim = embedding_dim // 2
    freq_index = torch.arange(half_dim, dtype=torch.float32,
                              device=timesteps.device)
    exponent = -math.log(max_period) * freq_index
    exponent = exponent / (half_dim - downscale_freq_shift)

    # Per-channel angular frequencies, geometrically spaced.
    frequencies = torch.exp(exponent)
    angles = timesteps.float()[:, None] * frequencies[None, :]
    angles = scale * angles

    # First half sin, second half cos.
    emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

    if flip_sin_to_cos:
        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)

    # Zero-pad one trailing channel for odd embedding sizes.
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb
|
|
|
|
class Timesteps(BaseModule):
    """Module wrapper around ``get_timestep_embedding`` with fixed settings."""

    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
        super().__init__()
        self.num_channels = num_channels
        self.flip_sin_to_cos = flip_sin_to_cos
        self.downscale_freq_shift = downscale_freq_shift

    def forward(self, timesteps):
        return get_timestep_embedding(
            timesteps,
            self.num_channels,
            flip_sin_to_cos=self.flip_sin_to_cos,
            downscale_freq_shift=self.downscale_freq_shift,
        )
|
|
|
|
class PitchPosEmb(BaseModule):
    """Sinusoidal positional embedding for a (batch, length) pitch sequence.

    Flattens the input, embeds each value with ``get_timestep_embedding``,
    then reshapes to (batch, dim, length) for conv-style consumers.
    """

    def __init__(self, dim, flip_sin_to_cos=False, downscale_freq_shift=0):
        super(PitchPosEmb, self).__init__()
        self.dim = dim
        self.flip_sin_to_cos = flip_sin_to_cos
        self.downscale_freq_shift = downscale_freq_shift

    def forward(self, x):
        # x: (b, l) — assumed per the rearrange pattern below.
        batch, length = x.shape
        flat = rearrange(x, 'b l -> (b l)')
        emb = get_timestep_embedding(
            flat,
            self.dim,
            flip_sin_to_cos=self.flip_sin_to_cos,
            downscale_freq_shift=self.downscale_freq_shift,
        )
        return rearrange(emb, '(b l) d -> b d l', b=batch, l=length)
|
|
|
|
class TimbreBlock(BaseModule):
    """Convolutional encoder that pools a (B, 1, H, W) map into (B, out_dim).

    Six Conv/InstanceNorm/GLU stages progressively widen the channels, then a
    1x1 conv maps to ``out_dim`` and a global average over the spatial dims
    produces one vector per batch element.
    """

    def __init__(self, out_dim):
        super(TimbreBlock, self).__init__()
        base_dim = out_dim // 4

        def conv_glu(in_ch, out_ch):
            # Conv produces 2*out_ch channels; GLU gates them back to out_ch.
            return torch.nn.Sequential(
                torch.nn.Conv2d(in_ch, 2 * out_ch, 3, 1, 1),
                torch.nn.InstanceNorm2d(2 * out_ch, affine=True),
                torch.nn.GLU(dim=1),
            )

        # Attribute names kept as-is so checkpoint state_dict keys match.
        self.block11 = conv_glu(1, base_dim)
        self.block12 = conv_glu(base_dim, base_dim)
        self.block21 = conv_glu(base_dim, 2 * base_dim)
        self.block22 = conv_glu(2 * base_dim, 2 * base_dim)
        self.block31 = conv_glu(2 * base_dim, 4 * base_dim)
        self.block32 = conv_glu(4 * base_dim, 4 * base_dim)
        self.final_conv = torch.nn.Conv2d(4 * base_dim, out_dim, 1)

    def forward(self, x):
        y = x
        for stage in (self.block11, self.block12,
                      self.block21, self.block22,
                      self.block31, self.block32):
            y = stage(y)
        y = self.final_conv(y)

        # Global average pool over the spatial dimensions.
        return y.sum((2, 3)) / (y.shape[2] * y.shape[3])