| import torch |
|
|
| from einops import rearrange |
| from torch import nn |
|
|
| from .blocks import AdaRMSNorm |
| from .transformer import Attention, FeedForward, RotaryEmbedding, LayerNorm |
|
|
| def checkpoint(function, *args, **kwargs): |
| kwargs.setdefault("use_reentrant", False) |
| return torch.utils.checkpoint.checkpoint(function, *args, **kwargs) |
|
|
| |
class ContinuousLocalTransformer(nn.Module):
    """Stack of pre-norm residual blocks with windowed (local) self-attention.

    Operates on continuous features shaped (batch, seq, dim). Optionally:
    input/output projections (`dim_in`/`dim_out`), global conditioning via
    AdaRMSNorm (`cond_dim > 0`), cross-attention conditioning
    (`cross_attn_cond_dim > 0`), and prepended conditioning tokens.
    Every sub-module call is wrapped in gradient checkpointing.
    """

    def __init__(
        self,
        *,
        dim,
        depth,
        dim_in = None,
        dim_out = None,
        causal = False,
        local_attn_window_size = 64,
        heads = 8,
        ff_mult = 2,
        cond_dim = 0,
        cross_attn_cond_dim = 0,
        **kwargs
    ):
        super().__init__()

        dim_head = dim // heads

        self.layers = nn.ModuleList()

        # Identity when no channel change is requested at the boundary.
        self.project_in = nn.Identity() if dim_in is None else nn.Linear(dim_in, dim)
        self.project_out = nn.Identity() if dim_out is None else nn.Linear(dim, dim_out)

        self.local_attn_window_size = local_attn_window_size
        self.cond_dim = cond_dim
        self.cross_attn_cond_dim = cross_attn_cond_dim

        self.rotary_pos_emb = RotaryEmbedding(max(dim_head // 2, 32))

        def build_norm():
            # Conditioned RMSNorm when a global conditioning vector is used,
            # plain LayerNorm otherwise.
            return AdaRMSNorm(dim, cond_dim, eps=1e-8) if cond_dim > 0 else LayerNorm(dim)

        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                build_norm(),
                Attention(
                    dim=dim,
                    dim_heads=dim_head,
                    causal=causal,
                    zero_init_output=True,
                    natten_kernel_size=local_attn_window_size,
                ),
                Attention(
                    dim=dim,
                    dim_heads=dim_head,
                    dim_context=cross_attn_cond_dim,
                    zero_init_output=True,
                ) if self.cross_attn_cond_dim > 0 else nn.Identity(),
                build_norm(),
                FeedForward(dim=dim, mult=ff_mult, no_bias=True),
            ]))

    def forward(self, x, mask = None, cond = None, cross_attn_cond = None, cross_attn_cond_mask = None, prepend_cond = None):
        """Transform (b, n, dim_in) -> (b, n', dim_out); n' grows by prepended tokens."""
        x = checkpoint(self.project_in, x)

        if prepend_cond is not None:
            # Prepended conditioning tokens attend like ordinary sequence tokens.
            x = torch.cat([prepend_cond, x], dim=1)

        rotary = self.rotary_pos_emb.forward_from_seq_len(x.shape[1])

        for self_norm, self_attn, cross_attn, ff_norm, ff in self.layers:
            # --- local self-attention sub-block (pre-norm + residual) ---
            skip = x
            x = checkpoint(self_norm, x, cond) if cond is not None else checkpoint(self_norm, x)
            x = checkpoint(self_attn, x, mask=mask, rotary_pos_emb=rotary) + skip

            # Optional cross-attention to external conditioning, residual around it.
            if cross_attn_cond is not None:
                x = checkpoint(cross_attn, x, context=cross_attn_cond, context_mask=cross_attn_cond_mask) + x

            # --- feed-forward sub-block (pre-norm + residual) ---
            skip = x
            x = checkpoint(ff_norm, x, cond) if cond is not None else checkpoint(ff_norm, x)
            x = checkpoint(ff, x) + skip

        return checkpoint(self.project_out, x)
|
|
class TransformerDownsampleBlock1D(nn.Module):
    """A local-transformer stage followed by a ratio-fold temporal downsample.

    Works channels-last: (batch, seq, channels). The sequence is shortened by
    `downsample_ratio` by folding consecutive frames into the channel axis and
    projecting back down to `embed_dim`.
    """

    def __init__(
        self,
        in_channels,
        embed_dim = 768,
        depth = 3,
        heads = 12,
        downsample_ratio = 2,
        local_attn_window_size = 64,
        **kwargs
    ):
        super().__init__()

        self.downsample_ratio = downsample_ratio

        self.transformer = ContinuousLocalTransformer(
            dim=embed_dim,
            depth=depth,
            heads=heads,
            local_attn_window_size=local_attn_window_size,
            **kwargs
        )

        # Lift channels only when the incoming width differs from the working width.
        self.project_in = nn.Identity() if in_channels == embed_dim else nn.Linear(in_channels, embed_dim, bias=False)

        # Mixes the r-folded channels (embed_dim * r) back down to embed_dim.
        self.project_down = nn.Linear(embed_dim * self.downsample_ratio, embed_dim, bias=False)

    def forward(self, x):
        x = checkpoint(self.project_in, x)
        x = self.transformer(x)
        # (b, n*r, c) -> (b, n, c*r): shorten time by r, widen channels by r.
        x = rearrange(x, "b (n r) c -> b n (c r)", r=self.downsample_ratio)
        x = checkpoint(self.project_down, x)
        return x
|
|
class TransformerUpsampleBlock1D(nn.Module):
    """A ratio-fold temporal upsample followed by a local-transformer stage.

    Works channels-last: (batch, seq, channels). The channel axis is widened by
    `upsample_ratio` and then unfolded into `upsample_ratio` extra frames,
    lengthening the sequence before the transformer refines it.
    """

    def __init__(
        self,
        in_channels,
        embed_dim,
        depth = 3,
        heads = 12,
        upsample_ratio = 2,
        local_attn_window_size = 64,
        **kwargs
    ):
        super().__init__()

        self.upsample_ratio = upsample_ratio

        self.transformer = ContinuousLocalTransformer(
            dim=embed_dim,
            depth=depth,
            heads=heads,
            local_attn_window_size=local_attn_window_size,
            **kwargs
        )

        # Lift channels only when the incoming width differs from the working width.
        self.project_in = nn.Identity() if in_channels == embed_dim else nn.Linear(in_channels, embed_dim, bias=False)

        # Widens channels by r so they can be unfolded into r new frames.
        self.project_up = nn.Linear(embed_dim, embed_dim * self.upsample_ratio, bias=False)

    def forward(self, x):
        x = checkpoint(self.project_in, x)
        x = checkpoint(self.project_up, x)
        # (b, n, c*r) -> (b, n*r, c): lengthen time by r, narrow channels by r.
        x = rearrange(x, "b n (c r) -> b (n r) c", r=self.upsample_ratio)
        x = self.transformer(x)
        return x
| |
|
|
class TransformerEncoder1D(nn.Module):
    """Hierarchical 1D encoder built from transformer downsample stages.

    Takes a channels-first sequence (batch, in_channels, length) and returns
    (batch, out_channels, length / prod(ratios)); each stage shortens the
    sequence by its ratio.

    Args:
        in_channels: channel width of the raw input.
        out_channels: channel width of the encoded output (latent dim).
        embed_dims: working width of each stage, shallow to deep.
        heads: attention heads per stage.
        depths: transformer depth per stage.
        ratios: temporal downsample factor per stage.
        local_attn_window_size: local attention window for every stage.
        **kwargs: forwarded to each TransformerDownsampleBlock1D.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        # Tuples, not lists: a mutable default is shared across all calls and
        # could be silently mutated; indexing/iteration behavior is identical.
        embed_dims = (96, 192, 384, 768),
        heads = (12, 12, 12, 12),
        depths = (3, 3, 3, 3),
        ratios = (2, 2, 2, 2),
        local_attn_window_size = 64,
        **kwargs
    ):
        super().__init__()

        layers = []

        for layer in range(len(depths)):
            # Each stage consumes the previous stage's width; the first stage
            # receives embed_dims[0] (project_in maps the raw input there).
            prev_dim = embed_dims[layer - 1] if layer > 0 else embed_dims[0]

            layers.append(
                TransformerDownsampleBlock1D(
                    in_channels = prev_dim,
                    embed_dim = embed_dims[layer],
                    heads = heads[layer],
                    depth = depths[layer],
                    downsample_ratio = ratios[layer],
                    local_attn_window_size = local_attn_window_size,
                    **kwargs
                )
            )

        self.layers = nn.Sequential(*layers)

        self.project_in = nn.Linear(in_channels, embed_dims[0], bias=False)
        self.project_out = nn.Linear(embed_dims[-1], out_channels, bias=False)

    def forward(self, x):
        """Encode (b, in_channels, n) -> (b, out_channels, n / prod(ratios))."""
        # Stages operate channels-last; convert at the boundaries.
        x = rearrange(x, "b c n -> b n c")
        x = checkpoint(self.project_in, x)
        x = self.layers(x)
        x = checkpoint(self.project_out, x)
        x = rearrange(x, "b n c -> b c n")

        return x
|
|
|
|
class TransformerDecoder1D(nn.Module):
    """Hierarchical 1D decoder built from transformer upsample stages.

    Mirrors TransformerEncoder1D: takes a channels-first latent
    (batch, in_channels, length) and returns
    (batch, out_channels, length * prod(ratios)); each stage lengthens the
    sequence by its ratio.

    Args:
        in_channels: channel width of the latent input.
        out_channels: channel width of the reconstructed output.
        embed_dims: working width of each stage, deep to shallow.
        heads: attention heads per stage.
        depths: transformer depth per stage.
        ratios: temporal upsample factor per stage.
        local_attn_window_size: local attention window for every stage.
        **kwargs: forwarded to each TransformerUpsampleBlock1D.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        # Tuples, not lists: a mutable default is shared across all calls and
        # could be silently mutated; indexing/iteration behavior is identical.
        embed_dims = (768, 384, 192, 96),
        heads = (12, 12, 12, 12),
        depths = (3, 3, 3, 3),
        ratios = (2, 2, 2, 2),
        local_attn_window_size = 64,
        **kwargs
    ):
        super().__init__()

        layers = []

        for layer in range(len(depths)):
            # Each stage consumes the previous stage's width; the first stage
            # receives embed_dims[0] (project_in maps the latent there).
            prev_dim = embed_dims[layer - 1] if layer > 0 else embed_dims[0]

            layers.append(
                TransformerUpsampleBlock1D(
                    in_channels = prev_dim,
                    embed_dim = embed_dims[layer],
                    heads = heads[layer],
                    depth = depths[layer],
                    upsample_ratio = ratios[layer],
                    local_attn_window_size = local_attn_window_size,
                    **kwargs
                )
            )

        self.layers = nn.Sequential(*layers)

        self.project_in = nn.Linear(in_channels, embed_dims[0], bias=False)
        self.project_out = nn.Linear(embed_dims[-1], out_channels, bias=False)

    def forward(self, x):
        """Decode (b, in_channels, n) -> (b, out_channels, n * prod(ratios))."""
        # Stages operate channels-last; convert at the boundaries.
        x = rearrange(x, "b c n -> b n c")
        x = checkpoint(self.project_in, x)
        x = self.layers(x)
        x = checkpoint(self.project_out, x)
        x = rearrange(x, "b n c -> b c n")
        return x