Spaces:
Runtime error
Runtime error
| import functools | |
| import torch | |
| from torch.nn import init | |
| """ | |
| # -------------------------------------------- | |
| # select the network of G, D and F | |
| # -------------------------------------------- | |
| """ | |
# --------------------------------------------
# Generator, netG, G
# --------------------------------------------
def define_G(opt):
    """Build the generator network (netG) chosen by ``opt['netG']['net_type']``.

    Args:
        opt: option mapping. ``opt['netG']`` holds the generator sub-options
            (``'net_type'`` plus whatever constructor keys the selected
            architecture reads below); ``opt['is_train']`` decides whether the
            weights are re-initialized with ``init_weights``.

    Returns:
        The constructed generator module.

    Raises:
        NotImplementedError: ``net_type`` is not a supported architecture.
        KeyError: a required option key is missing.
    """
    g_opt = opt['netG']
    net_type = g_opt['net_type']

    def gather(*fields):
        # Build constructor kwargs from g_opt, preserving the listed order.
        # A field is either an option key used verbatim as the kwarg name, or
        # a (kwarg_name, option_key) pair when the two names differ.
        kwargs = {}
        for field in fields:
            kwarg, key = (field, field) if isinstance(field, str) else field
            kwargs[kwarg] = g_opt[key]
        return kwargs

    # Argument lists shared by several architectures.
    # For DnCNN-style nets, 'nb' is the total number of conv layers.
    denoiser_fields = ('in_nc', 'out_nc', 'nc', 'nb', 'act_mode')
    sr_fields = ('in_nc', 'out_nc', 'nc', 'nb', ('upscale', 'scale'),
                 'act_mode', 'upsample_mode')

    # ---------------- denoising ----------------
    if net_type == 'dncnn':
        from models.network_dncnn import DnCNN as generator_cls
        kwargs = gather(*denoiser_fields)
    elif net_type == 'fdncnn':  # flexible DnCNN
        from models.network_dncnn import FDnCNN as generator_cls
        kwargs = gather(*denoiser_fields)
    elif net_type == 'ffdnet':
        from models.network_ffdnet import FFDNet as generator_cls
        kwargs = gather(*denoiser_fields)

    # ------------- super-resolution -------------
    elif net_type == 'srmd':
        from models.network_srmd import SRMD as generator_cls
        kwargs = gather(*sr_fields)
    elif net_type == 'dpsr':  # super-resolver prior of DPSR
        from models.network_dpsr import MSRResNet_prior as generator_cls
        kwargs = gather(*sr_fields)
    elif net_type == 'msrresnet0':  # modified SRResNet v0.0
        from models.network_msrresnet import MSRResNet0 as generator_cls
        kwargs = gather(*sr_fields)
    elif net_type == 'msrresnet1':  # modified SRResNet v0.1
        from models.network_msrresnet import MSRResNet1 as generator_cls
        kwargs = gather(*sr_fields)
    elif net_type == 'rrdb':
        from models.network_rrdb import RRDB as generator_cls
        kwargs = gather('in_nc', 'out_nc', 'nc', 'nb', 'gc',
                        ('upscale', 'scale'), 'act_mode', 'upsample_mode')
    elif net_type == 'rrdbnet':
        from models.network_rrdbnet import RRDBNet as generator_cls
        kwargs = gather('in_nc', 'out_nc', 'nf', 'nb', 'gc', ('sf', 'scale'))
    elif net_type == 'imdn':  # IMDN
        from models.network_imdn import IMDN as generator_cls
        kwargs = gather(*sr_fields)
    elif net_type == 'usrnet':
        from models.network_usrnet import USRNet as generator_cls
        kwargs = gather('n_iter', 'h_nc', 'in_nc', 'out_nc', 'nc', 'nb',
                        'act_mode', 'downsample_mode', 'upsample_mode')
    elif net_type == 'drunet':  # deep residual U-Net
        from models.network_unet import UNetRes as generator_cls
        kwargs = gather('in_nc', 'out_nc', 'nc', 'nb', 'act_mode',
                        'downsample_mode', 'upsample_mode', 'bias')
    elif net_type == 'swinir':
        from models.network_swinir import SwinIR as generator_cls
        kwargs = gather('upscale', 'in_chans', 'img_size', 'window_size',
                        'img_range', 'depths', 'embed_dim', 'num_heads',
                        'mlp_ratio', 'upsampler', 'resi_connection')
    elif net_type == 'vrt':
        from models.network_vrt import VRT as generator_cls
        kwargs = gather('upscale', 'img_size', 'window_size', 'depths',
                        'indep_reconsts', 'embed_dims', 'num_heads',
                        'spynet_path', 'pa_frames', 'deformable_groups',
                        'nonblind_denoising', 'use_checkpoint_attn',
                        'use_checkpoint_ffn', 'no_checkpoint_attn_blocks',
                        'no_checkpoint_ffn_blocks')
    # TODO: other architectures
    else:
        raise NotImplementedError('netG [{:s}] is not found.'.format(net_type))

    netG = generator_cls(**kwargs)

    # Re-initialize weights only when training.
    if opt['is_train']:
        init_weights(netG,
                     init_type=g_opt['init_type'],
                     init_bn_type=g_opt['init_bn_type'],
                     gain=g_opt['init_gain'])
    return netG
# --------------------------------------------
# Discriminator, netD, D
# --------------------------------------------
def define_D(opt):
    """Build the discriminator (netD) chosen by ``opt['netD']['net_type']``.

    The weights are always (re-)initialized with ``init_weights``, using
    ``init_type``, ``init_bn_type`` and ``init_gain`` from ``opt['netD']``.

    Args:
        opt: option mapping; ``opt['netD']`` holds the discriminator
            sub-options (``'net_type'`` plus the constructor keys read below).

    Returns:
        The constructed discriminator module.

    Raises:
        NotImplementedError: ``net_type`` is not a supported discriminator.
        KeyError: a required option key is missing.
    """
    d_opt = opt['netD']
    net_type = d_opt['net_type']

    vgg_variants = ('discriminator_vgg_96',
                    'discriminator_vgg_128',
                    'discriminator_vgg_192')

    if net_type in vgg_variants:
        if net_type == 'discriminator_vgg_96':
            from models.network_discriminator import Discriminator_VGG_96 as discriminator_cls
        elif net_type == 'discriminator_vgg_128':
            from models.network_discriminator import Discriminator_VGG_128 as discriminator_cls
        else:
            from models.network_discriminator import Discriminator_VGG_192 as discriminator_cls
        # The VGG-style constructors take the activation as ``ac_type``.
        kwargs = {'in_nc': d_opt['in_nc'],
                  'base_nc': d_opt['base_nc'],
                  'ac_type': d_opt['act_mode']}
    elif net_type == 'discriminator_vgg_128_SN':
        from models.network_discriminator import Discriminator_VGG_128_SN as discriminator_cls
        kwargs = {}  # constructed without arguments
    elif net_type == 'discriminator_patchgan':
        from models.network_discriminator import Discriminator_PatchGAN as discriminator_cls
        kwargs = {'input_nc': d_opt['in_nc'],
                  'ndf': d_opt['base_nc'],
                  'n_layers': d_opt['n_layers'],
                  'norm_type': d_opt['norm_type']}
    elif net_type == 'discriminator_unet':
        from models.network_discriminator import Discriminator_UNet as discriminator_cls
        kwargs = {'input_nc': d_opt['in_nc'],
                  'ndf': d_opt['base_nc']}
    else:
        raise NotImplementedError('netD [{:s}] is not found.'.format(net_type))

    netD = discriminator_cls(**kwargs)

    init_weights(netD,
                 init_type=d_opt['init_type'],
                 init_bn_type=d_opt['init_bn_type'],
                 gain=d_opt['init_gain'])
    return netD
# --------------------------------------------
# VGGfeature, netF, F
# --------------------------------------------
def define_F(opt, use_bn=False):
    """Build the VGG feature extractor (netF) and put it in eval mode.

    Per the original note, this is the pretrained VGG19-54 feature, taken
    before the ReLU: layer index 34, or 49 for the batch-norm variant.

    Args:
        opt: option mapping; a truthy ``opt['gpu_ids']`` selects the CUDA
            device, otherwise the CPU.
        use_bn: use the batch-norm VGG variant.

    Returns:
        The ``VGGFeatureExtractor`` instance, switched to eval mode.
    """
    device = torch.device('cpu' if not opt['gpu_ids'] else 'cuda')
    from models.network_feature import VGGFeatureExtractor

    feature_layer = 49 if use_bn else 34

    netF = VGGFeatureExtractor(feature_layer=feature_layer,
                               use_bn=use_bn,
                               use_input_norm=True,
                               device=device)
    # Not trained, but gradients must still back-propagate to the input.
    netF.eval()
    return netF
| """ | |
| # -------------------------------------------- | |
| # weights initialization | |
| # -------------------------------------------- | |
| """ | |
def init_weights(net, init_type='xavier_uniform', init_bn_type='uniform', gain=1):
    """
    # Kai Zhang, https://github.com/cszn/KAIR
    #
    # Re-initialize, in place, every submodule of `net` whose class name
    # contains 'Conv' or 'Linear' (weights + zeroed bias) or 'BatchNorm2d'
    # (affine weight/bias), by walking the module tree with net.apply().
    #
    # Args:
    #   net: torch.nn.Module to initialize.
    #   init_type:
    #       default, none: skip re-initialization entirely
    #       normal; uniform; xavier_normal; xavier_uniform;
    #       kaiming_normal; kaiming_uniform; orthogonal
    #   init_bn_type:
    #       uniform; constant
    #   gain:
    #       scale factor for the conv/linear weights, e.g. 0.2
    #
    # Raises:
    #   NotImplementedError: unknown init_type / init_bn_type. The error is
    #   raised when the first matching submodule is visited.
    """

    def _weight_tensor(m):
        # Return the tensor that actually holds the trainable weight.
        # torch.nn.utils.spectral_norm moves it to `weight_orig` and recomputes
        # `weight` from it before every forward pass. Writing into `weight`
        # would therefore be silently discarded.
        if hasattr(m, 'weight_orig'):
            return m.weight_orig
        # The same applies to torch.nn.utils.parametrize-based wrappers
        # (e.g. torch.nn.utils.parametrizations.spectral_norm).
        parametrizations = getattr(m, 'parametrizations', None)
        if parametrizations is not None and 'weight' in parametrizations:
            original = getattr(parametrizations['weight'], 'original', None)
            if original is not None:
                return original
        return getattr(m, 'weight', None)

    def init_fn(m, init_type='xavier_uniform', init_bn_type='uniform', gain=1):
        classname = m.__class__.__name__
        if 'Conv' in classname or 'Linear' in classname:
            weight = _weight_tensor(m)
            if not isinstance(weight, torch.Tensor):
                # Matched by name only (e.g. a container such as 'ConvBlock').
                # It has no weight of its own; its child layers are visited
                # separately by net.apply().
                return
            w = weight.data
            if init_type == 'normal':
                init.normal_(w, 0, 0.1)
                w.clamp_(-1, 1).mul_(gain)
            elif init_type == 'uniform':
                init.uniform_(w, -0.2, 0.2)
                w.mul_(gain)
            elif init_type == 'xavier_normal':
                init.xavier_normal_(w, gain=gain)
                w.clamp_(-1, 1)
            elif init_type == 'xavier_uniform':
                init.xavier_uniform_(w, gain=gain)
            elif init_type == 'kaiming_normal':
                init.kaiming_normal_(w, a=0, mode='fan_in', nonlinearity='relu')
                w.clamp_(-1, 1).mul_(gain)
            elif init_type == 'kaiming_uniform':
                init.kaiming_uniform_(w, a=0, mode='fan_in', nonlinearity='relu')
                w.mul_(gain)
            elif init_type == 'orthogonal':
                init.orthogonal_(w, gain=gain)
            else:
                raise NotImplementedError('Initialization method [{:s}] is not implemented'.format(init_type))
            bias = getattr(m, 'bias', None)
            if bias is not None:
                bias.data.zero_()
        elif 'BatchNorm2d' in classname:
            # Wrappers whose name merely contains 'BatchNorm2d' may lack `affine`.
            # Treat them like non-affine batch norm: nothing to initialize here.
            affine = getattr(m, 'affine', False)
            if init_bn_type == 'uniform':  # preferred
                if affine:
                    init.uniform_(m.weight.data, 0.1, 1.0)
                    init.constant_(m.bias.data, 0.0)
            elif init_bn_type == 'constant':
                if affine:
                    init.constant_(m.weight.data, 1.0)
                    init.constant_(m.bias.data, 0.0)
            else:
                raise NotImplementedError('Initialization method [{:s}] is not implemented'.format(init_bn_type))

    if init_type not in ['default', 'none']:
        print('Initialization method [{:s} + {:s}], gain is [{:.2f}]'.format(init_type, init_bn_type, gain))
        fn = functools.partial(init_fn, init_type=init_type, init_bn_type=init_bn_type, gain=gain)
        net.apply(fn)
    else:
        print('Pass this initialization! Initialization was done during network definition!')