| from torch import nn, Tensor |
| from torch.nn import functional as F |
|
|
| from typing import Union |
| from functools import partial |
|
|
| from .utils import _init_weights |
| from .refine import ConvRefine, LightConvRefine, LighterConvRefine |
|
|
|
|
class ConvUpsample(nn.Module):
    """Optional bilinear upsampling followed by a ``ConvRefine`` block.

    Args:
        in_channels: Forwarded to ``ConvRefine`` as ``in_channels``.
        out_channels: Forwarded to ``ConvRefine`` as ``out_channels``.
        scale_factor: Factor handed to ``F.interpolate``. Must be >= 1; a
            value that is not greater than 1 disables resizing entirely.
        norm_layer: Forwarded unchanged to ``ConvRefine``.
            NOTE(review): the default is the ``nn.BatchNorm2d`` class itself,
            so this is presumably a layer factory rather than an instance --
            confirm against ``ConvRefine``.
        activation: Forwarded unchanged to ``ConvRefine``.
            NOTE(review): the default ``nn.ReLU`` object is created once at
            definition time and is therefore the same object for every
            instance that relies on the default.
        groups: Forwarded to ``ConvRefine`` as ``groups``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        scale_factor: int = 2,
        norm_layer: Union[nn.BatchNorm2d, nn.GroupNorm, None] = nn.BatchNorm2d,
        activation: nn.Module = nn.ReLU(inplace=True),
        groups: int = 1,
    ) -> None:
        super().__init__()
        assert scale_factor >= 1, f"Scale factor should be greater than or equal to 1, but got {scale_factor}"
        self.scale_factor = scale_factor

        if scale_factor > 1:
            # Bilinear resizing has no parameters, so a pre-configured
            # functional call is stored instead of a submodule.
            resize_options = {
                "scale_factor": scale_factor,
                "mode": "bilinear",
                "align_corners": False,
                "recompute_scale_factor": False,
                "antialias": False,
            }
            self.upsample = partial(F.interpolate, **resize_options)
        else:
            # Nothing to resize: pass the input straight through.
            self.upsample = nn.Identity()

        refine_options = {
            "in_channels": in_channels,
            "out_channels": out_channels,
            "norm_layer": norm_layer,
            "activation": activation,
            "groups": groups,
        }
        self.refine = ConvRefine(**refine_options)

        # Run the project's initialiser over this module and all submodules.
        self.apply(_init_weights)

    def forward(self, x: Tensor) -> Tensor:
        """Resize ``x`` (when enabled) and feed the result to the refine block."""
        return self.refine(self.upsample(x))
|
|
|
|
class LightConvUpsample(nn.Module):
    """Optional bilinear upsampling followed by a ``LightConvRefine`` block.

    Args:
        in_channels: Forwarded to ``LightConvRefine`` as ``in_channels``.
        out_channels: Forwarded to ``LightConvRefine`` as ``out_channels``.
        scale_factor: Factor handed to ``F.interpolate``. Must be >= 1; a
            value that is not greater than 1 disables resizing entirely.
        norm_layer: Forwarded unchanged to ``LightConvRefine``.
            NOTE(review): the default is the ``nn.BatchNorm2d`` class itself,
            so this is presumably a layer factory rather than an instance --
            confirm against ``LightConvRefine``.
        activation: Forwarded unchanged to ``LightConvRefine``.
            NOTE(review): the default ``nn.ReLU`` object is created once at
            definition time and is therefore the same object for every
            instance that relies on the default.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        scale_factor: int = 2,
        norm_layer: Union[nn.BatchNorm2d, nn.GroupNorm, None] = nn.BatchNorm2d,
        activation: nn.Module = nn.ReLU(inplace=True),
    ) -> None:
        super().__init__()
        assert scale_factor >= 1, f"Scale factor should be greater than or equal to 1, but got {scale_factor}"
        self.scale_factor = scale_factor

        if scale_factor > 1:
            # Bilinear resizing has no parameters, so a pre-configured
            # functional call is stored instead of a submodule.
            resize_options = {
                "scale_factor": scale_factor,
                "mode": "bilinear",
                "align_corners": False,
                "recompute_scale_factor": False,
                "antialias": False,
            }
            self.upsample = partial(F.interpolate, **resize_options)
        else:
            # Nothing to resize: pass the input straight through.
            self.upsample = nn.Identity()

        refine_options = {
            "in_channels": in_channels,
            "out_channels": out_channels,
            "norm_layer": norm_layer,
            "activation": activation,
        }
        self.refine = LightConvRefine(**refine_options)

        # Run the project's initialiser over this module and all submodules.
        self.apply(_init_weights)

    def forward(self, x: Tensor) -> Tensor:
        """Resize ``x`` (when enabled) and feed the result to the refine block."""
        return self.refine(self.upsample(x))
|
|
|
|
class LighterConvUpsample(nn.Module):
    """Optional bilinear upsampling followed by a ``LighterConvRefine`` block.

    Args:
        in_channels: Forwarded to ``LighterConvRefine`` as ``in_channels``.
        out_channels: Forwarded to ``LighterConvRefine`` as ``out_channels``.
        scale_factor: Factor handed to ``F.interpolate``. Must be >= 1; a
            value that is not greater than 1 disables resizing entirely.
        norm_layer: Forwarded unchanged to ``LighterConvRefine``.
            NOTE(review): the default is the ``nn.BatchNorm2d`` class itself,
            so this is presumably a layer factory rather than an instance --
            confirm against ``LighterConvRefine``.
        activation: Forwarded unchanged to ``LighterConvRefine``.
            NOTE(review): the default ``nn.ReLU`` object is created once at
            definition time and is therefore the same object for every
            instance that relies on the default.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        scale_factor: int = 2,
        norm_layer: Union[nn.BatchNorm2d, nn.GroupNorm, None] = nn.BatchNorm2d,
        activation: nn.Module = nn.ReLU(inplace=True),
    ) -> None:
        super().__init__()
        assert scale_factor >= 1, f"Scale factor should be greater than or equal to 1, but got {scale_factor}"
        self.scale_factor = scale_factor

        if scale_factor > 1:
            # Bilinear resizing has no parameters, so a pre-configured
            # functional call is stored instead of a submodule.
            resize_options = {
                "scale_factor": scale_factor,
                "mode": "bilinear",
                "align_corners": False,
                "recompute_scale_factor": False,
                "antialias": False,
            }
            self.upsample = partial(F.interpolate, **resize_options)
        else:
            # Nothing to resize: pass the input straight through.
            self.upsample = nn.Identity()

        refine_options = {
            "in_channels": in_channels,
            "out_channels": out_channels,
            "norm_layer": norm_layer,
            "activation": activation,
        }
        self.refine = LighterConvRefine(**refine_options)

        # Run the project's initialiser over this module and all submodules.
        self.apply(_init_weights)

    def forward(self, x: Tensor) -> Tensor:
        """Resize ``x`` (when enabled) and feed the result to the refine block."""
        return self.refine(self.upsample(x))
|
|