| import torch |
| from torch import nn, Tensor |
| from typing import List |
| from einops import rearrange |
|
|
| from .blocks import conv3x3, conv1x1, Conv2dLayerNorm, _init_weights |
|
|
|
|
class MultiScale(nn.Module):
    """Multi-branch dilated-convolution block fused by per-pixel attention.

    The input is passed through one 1x1 branch (``scale_0``) and one dilated
    3x3 branch per entry of ``scales``. The branch outputs are stacked along a
    new "scale" axis, their mean is prepended as an extra token, and that mean
    token is used as the single query attending over all tokens independently
    at every spatial location. A 1x1-convolution MLP follows. Both stages are
    wrapped in residual connections (see ``forward``).

    Args:
        channels: Number of input and output channels.
        scales: Dilation rates of the 3x3 branches. Stored sorted in
            ``self.scales``. Branches are registered as ``conv_{scale}``, so
            duplicate entries resolve to the same branch module.
        heads: Number of attention heads; must divide ``channels``.
        groups: Convolution groups for the 3x3 branches; must divide
            ``channels``. When ``groups > 1`` a 1x1 convolution follows each
            3x3 convolution.
        mlp_ratio: Hidden width of the MLP as a multiple of ``channels``.
            The product is truncated to an integer.
    """

    def __init__(
        self,
        channels: int,
        scales: List[int],
        heads: int = 8,
        groups: int = 1,
        mlp_ratio: float = 4.0,
    ) -> None:
        super().__init__()
        assert channels > 0, "channels should be a positive integer"
        assert isinstance(scales, (list, tuple)) and len(scales) > 0 and all(scale > 0 for scale in scales), "scales should be a list or tuple of positive integers"
        assert heads > 0 and channels % heads == 0, "heads should be a positive integer and channels should be divisible by heads"
        assert groups > 0 and channels % groups == 0, "groups should be a positive integer and channels should be divisible by groups"
        # Convolution channel counts must be integers, while mlp_ratio is a
        # float; resolve (and validate) the hidden width before building layers.
        hidden_channels = self._mlp_hidden_channels(channels, mlp_ratio)

        scales = sorted(scales)
        self.scales = scales
        # Number of branches: the 1x1 branch plus one per dilation rate.
        self.num_scales = len(scales) + 1
        self.heads = heads
        self.groups = groups

        # --- multi-scale branches -------------------------------------------
        self.scale_0 = nn.Sequential(
            conv1x1(channels, channels, stride=1, bias=False),
            Conv2dLayerNorm(channels),
            nn.GELU(),
        )
        for scale in scales:
            setattr(self, f"conv_{scale}", nn.Sequential(
                conv3x3(
                    in_channels=channels,
                    out_channels=channels,
                    stride=1,
                    groups=groups,
                    dilation=scale,
                    bias=False,
                ),
                # A grouped 3x3 is followed by a 1x1 convolution over all
                # channels; with groups == 1 this slot is a no-op.
                conv1x1(channels, channels, stride=1, bias=False) if groups > 1 else nn.Identity(),
                Conv2dLayerNorm(channels),
                nn.GELU(),
            ))

        # --- attention across the scale axis --------------------------------
        self.norm_attn = Conv2dLayerNorm(channels)
        # One embedding per token: the mean token plus every branch output.
        self.pos_embed = nn.Parameter(torch.randn(1, self.num_scales + 1, channels, 1, 1) / channels ** 0.5)
        self.to_q = conv1x1(channels, channels, stride=1, bias=False)
        self.to_k = conv1x1(channels, channels, stride=1, bias=False)
        self.to_v = conv1x1(channels, channels, stride=1, bias=False)

        # Dot-product scaling factor 1 / sqrt(head_dim). Not to be confused
        # with ``self.scales`` (the dilation rates).
        self.scale = (channels // heads) ** -0.5

        self.attend = nn.Softmax(dim=-1)

        self.to_out = conv1x1(channels, channels, stride=1)

        # --- position-wise MLP ----------------------------------------------
        self.norm_mlp = Conv2dLayerNorm(channels)
        self.mlp = nn.Sequential(
            conv1x1(channels, hidden_channels, stride=1),
            nn.GELU(),
            conv1x1(hidden_channels, channels, stride=1),
        )

        self.apply(_init_weights)

    @staticmethod
    def _mlp_hidden_channels(channels: int, mlp_ratio: float) -> int:
        """Return the integer hidden width of the MLP.

        Args:
            channels: Number of block channels.
            mlp_ratio: Expansion factor applied to ``channels``.

        Returns:
            ``channels * mlp_ratio`` truncated to an ``int``.

        Raises:
            AssertionError: If the truncated width is not positive.
        """
        hidden_channels = int(channels * mlp_ratio)
        assert hidden_channels > 0, "mlp_ratio should yield a positive number of hidden channels"
        return hidden_channels

    def _forward_attn(self, x: Tensor) -> Tensor:
        """Fuse the multi-scale branches with single-query attention.

        Args:
            x: 4-D input tensor, laid out as (B, C, H, W).

        Returns:
            The output of ``to_out`` applied to the attended features,
            rearranged back to a 4-D (B, C, H, W) layout.

        Raises:
            AssertionError: If ``x`` is not 4-dimensional.
        """
        assert len(x.shape) == 4, f"Expected input to have shape (B, C, H, W), but got {x.shape}"
        num_tokens = self.num_scales + 1  # mean token + one token per branch

        # Every branch must return the same shape for the stack to succeed.
        branches = [self.scale_0(x)] + [getattr(self, f"conv_{scale}")(x) for scale in self.scales]

        tokens = torch.stack(branches, dim=1)  # B S C H W
        # Prepend the mean over branches; it becomes token 0 (the query).
        tokens = torch.cat([tokens.mean(dim=1, keepdim=True), tokens], dim=1)
        tokens = tokens + self.pos_embed

        # The 2-D layers expect 4-D input, so fold the token axis into batch.
        flat = rearrange(tokens, "B S C H W -> (B S) C H W")
        flat = self.norm_attn(flat)
        tokens = rearrange(flat, "(B S) C H W -> B S C H W", S=num_tokens)

        q = self.to_q(tokens[:, 0])
        # Keys and values use every token; reuse the already-flattened tensor.
        k = self.to_k(flat)
        v = self.to_v(flat)

        q = rearrange(q, "B (h d) H W -> B h H W 1 d", h=self.heads)
        k = rearrange(k, "(B S) (h d) H W -> B h H W S d", S=num_tokens, h=self.heads)
        v = rearrange(v, "(B S) (h d) H W -> B h H W S d", S=num_tokens, h=self.heads)

        # Per head and per pixel: (1 x d) @ (d x S) -> weights over S tokens.
        attn = q @ k.transpose(-2, -1) * self.scale
        attn = self.attend(attn)
        out = attn @ v

        out = rearrange(out, "B h H W 1 d -> B (h d) H W")

        out = self.to_out(out)
        return out

    def _forward_mlp(self, x: Tensor) -> Tensor:
        """Apply ``norm_mlp`` followed by the 1x1-convolution MLP.

        Args:
            x: 4-D input tensor, laid out as (B, C, H, W).

        Returns:
            The MLP output (without the residual connection).

        Raises:
            AssertionError: If ``x`` is not 4-dimensional.
        """
        assert len(x.shape) == 4, f"Expected input to have shape (B, C, H, W), but got {x.shape}"
        x = self.norm_mlp(x)
        x = self.mlp(x)
        return x

    def forward(self, x: Tensor) -> Tensor:
        """Run the attention stage then the MLP stage, each with a residual.

        Args:
            x: 4-D input tensor, laid out as (B, C, H, W).

        Returns:
            Tensor produced by adding each stage's output to its input.
        """
        x = x + self._forward_attn(x)
        x = x + self._forward_mlp(x)
        return x
| |