| |
| |
| from torch import nn, Tensor |
| import torch.nn.functional as F |
| from torch.hub import load_state_dict_from_url |
| from typing import Optional |
|
|
| from ..utils import make_vgg_layers, vgg_cfgs, vgg_urls |
| from ..utils import _init_weights |
|
|
|
|
|
|
class VGG(nn.Module):
    """VGG feature extractor followed by a small convolutional head.

    The input is passed through ``features``, optionally resized with
    bilinear interpolation so that the overall down-sampling factor becomes
    ``reduction`` instead of ``encoder_reduction``, and finally through
    ``reg_layer`` (two 3x3 convolutions, 512 -> 256 -> 128 channels, each
    followed by ReLU).

    Attributes:
        features: Backbone module supplied by the caller.
        reg_layer: Convolutional head applied after the optional resize.
        encoder_reduction: Down-sampling factor attributed to ``features``;
            always 16.
        reduction: Down-sampling factor requested for the output.
        channels: Number of channels produced by ``reg_layer`` (128).
    """

    def __init__(
        self,
        features: nn.Module,
        reduction: Optional[int] = None,
    ) -> None:
        """Build the backbone wrapper.

        Args:
            features: Backbone module. NOTE(review): assumed to emit
                512-channel maps at 1/16 of the input resolution, since the
                head's first convolution expects 512 channels -- confirm
                against the module that builds it.
            reduction: Desired overall down-sampling factor. ``None`` keeps
                the encoder's own factor (16), which disables the resize.

        Raises:
            ValueError: If ``reduction`` is given but is not positive.
        """
        super().__init__()

        # Fail fast: a zero reduction would otherwise only surface as a
        # ZeroDivisionError on the first forward pass, and a negative one
        # as an error raised from inside the interpolation call.
        if reduction is not None and reduction <= 0:
            raise ValueError(f"reduction must be a positive number, got {reduction!r}")

        self.features = features
        self.reg_layer = nn.Sequential(
            nn.Conv2d(512, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )
        # Only the head is initialised here; ``features`` is left as given.
        self.reg_layer.apply(_init_weights)

        self.encoder_reduction = 16
        self.reduction = self.encoder_reduction if reduction is None else reduction
        self.channels = 128

    def forward(self, x: Tensor) -> Tensor:
        """Run the backbone, the optional resize, and the head.

        Args:
            x: Input tensor accepted by ``features``.

        Returns:
            The output of ``reg_layer``.
        """
        x = self.features(x)
        if self.encoder_reduction != self.reduction:
            # Resize so the effective stride matches the requested reduction
            # (ratio > 1 up-samples, ratio < 1 down-samples).
            scale = self.encoder_reduction / self.reduction
            # align_corners=False is the behaviour used when the argument is
            # omitted; spelling it out keeps the result identical and explicit.
            x = F.interpolate(x, scale_factor=scale, mode="bilinear", align_corners=False)
        return self.reg_layer(x)
|
|
|
|
def _load_weights(model: VGG, url: str) -> VGG:
    """Fetch a checkpoint from ``url`` and load it into ``model`` non-strictly.

    Keys that are absent from the checkpoint, and checkpoint keys the model
    does not have, are tolerated (``strict=False``) and reported on stdout.

    Args:
        model: Module that receives the weights; it is modified in place.
        url: Location of the serialized state dict.

    Returns:
        The same ``model`` object, after loading.
    """
    # Deserialize onto the CPU so the call also works on hosts that lack the
    # device the checkpoint was saved from. ``load_state_dict`` copies the
    # values into the model's existing parameters, so the model's own device
    # placement is unaffected.
    state_dict = load_state_dict_from_url(url, map_location="cpu")
    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
    print("Loading pre-trained weights")
    if missing_keys:
        print(f"Missing keys: {missing_keys}")
    if unexpected_keys:
        print(f"Unexpected keys: {unexpected_keys}")
    return model
|
|
|
|
def vgg11(reduction: int = 8) -> VGG:
    """Create a ``VGG`` from layer config ``"A"`` (no batch norm) and load its weights.

    Args:
        reduction: Passed to ``VGG`` as its ``reduction`` argument.

    Returns:
        The model after ``_load_weights`` has processed ``vgg_urls["vgg11"]``.
    """
    backbone = make_vgg_layers(vgg_cfgs["A"])
    net = VGG(features=backbone, reduction=reduction)
    return _load_weights(model=net, url=vgg_urls["vgg11"])
|
|
def vgg11_bn(reduction: int = 8) -> VGG:
    """Create a ``VGG`` from layer config ``"A"`` with batch norm and load its weights.

    Args:
        reduction: Passed to ``VGG`` as its ``reduction`` argument.

    Returns:
        The model after ``_load_weights`` has processed ``vgg_urls["vgg11_bn"]``.
    """
    backbone = make_vgg_layers(vgg_cfgs["A"], batch_norm=True)
    net = VGG(features=backbone, reduction=reduction)
    return _load_weights(model=net, url=vgg_urls["vgg11_bn"])
|
|
def vgg13(reduction: int = 8) -> VGG:
    """Create a ``VGG`` from layer config ``"B"`` (no batch norm) and load its weights.

    Args:
        reduction: Passed to ``VGG`` as its ``reduction`` argument.

    Returns:
        The model after ``_load_weights`` has processed ``vgg_urls["vgg13"]``.
    """
    backbone = make_vgg_layers(vgg_cfgs["B"])
    net = VGG(features=backbone, reduction=reduction)
    return _load_weights(model=net, url=vgg_urls["vgg13"])
|
|
def vgg13_bn(reduction: int = 8) -> VGG:
    """Create a ``VGG`` from layer config ``"B"`` with batch norm and load its weights.

    Args:
        reduction: Passed to ``VGG`` as its ``reduction`` argument.

    Returns:
        The model after ``_load_weights`` has processed ``vgg_urls["vgg13_bn"]``.
    """
    backbone = make_vgg_layers(vgg_cfgs["B"], batch_norm=True)
    net = VGG(features=backbone, reduction=reduction)
    return _load_weights(model=net, url=vgg_urls["vgg13_bn"])
|
|
def vgg16(reduction: int = 8) -> VGG:
    """Create a ``VGG`` from layer config ``"D"`` (no batch norm) and load its weights.

    Args:
        reduction: Passed to ``VGG`` as its ``reduction`` argument.

    Returns:
        The model after ``_load_weights`` has processed ``vgg_urls["vgg16"]``.
    """
    backbone = make_vgg_layers(vgg_cfgs["D"])
    net = VGG(features=backbone, reduction=reduction)
    return _load_weights(model=net, url=vgg_urls["vgg16"])
|
|
def vgg16_bn(reduction: int = 8) -> VGG:
    """Create a ``VGG`` from layer config ``"D"`` with batch norm and load its weights.

    Args:
        reduction: Passed to ``VGG`` as its ``reduction`` argument.

    Returns:
        The model after ``_load_weights`` has processed ``vgg_urls["vgg16_bn"]``.
    """
    backbone = make_vgg_layers(vgg_cfgs["D"], batch_norm=True)
    net = VGG(features=backbone, reduction=reduction)
    return _load_weights(model=net, url=vgg_urls["vgg16_bn"])
|
|
def vgg19(reduction: int = 8) -> VGG:
    """Create a ``VGG`` from layer config ``"E"`` (no batch norm) and load its weights.

    Args:
        reduction: Passed to ``VGG`` as its ``reduction`` argument.

    Returns:
        The model after ``_load_weights`` has processed ``vgg_urls["vgg19"]``.
    """
    backbone = make_vgg_layers(vgg_cfgs["E"])
    net = VGG(features=backbone, reduction=reduction)
    return _load_weights(model=net, url=vgg_urls["vgg19"])
|
|
def vgg19_bn(reduction: int = 8) -> VGG:
    """Create a ``VGG`` from layer config ``"E"`` with batch norm and load its weights.

    Args:
        reduction: Passed to ``VGG`` as its ``reduction`` argument.

    Returns:
        The model after ``_load_weights`` has processed ``vgg_urls["vgg19_bn"]``.
    """
    backbone = make_vgg_layers(vgg_cfgs["E"], batch_norm=True)
    net = VGG(features=backbone, reduction=reduction)
    return _load_weights(model=net, url=vgg_urls["vgg19_bn"])
|
|