import functools

import torch
import torch.nn as nn

from networks import ResnetBlock
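
# Note: ResnetBlock (imported above) is the pix2pixHD-style residual block; it
# computes x + conv_block(x) and accepts the padding_type / activation /
# norm_layer options forwarded to it below.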

class GlobalGenerator(nn.Module):
    def __init__(self, input_nc=3, output_nc=3, ngf=64, n_downsampling=4, n_blocks=9,
                 norm_layer=functools.partial(nn.InstanceNorm2d, affine=False),
                 padding_type='reflect'):
        assert n_blocks >= 0
        super(GlobalGenerator, self).__init__()
        activation = nn.ReLU(True)

        # Stem: reflection-pad then a 7x7 convolution from input_nc to ngf channels.
        model = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0),
                 norm_layer(ngf), activation]

        # Downsampling: each stride-2 conv halves the spatial size and doubles the channels.
        for i in range(n_downsampling):
            mult = 2 ** i
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1),
                      norm_layer(ngf * mult * 2), activation]

        # Residual blocks at the bottleneck resolution (ngf * 2**n_downsampling channels).
        mult = 2 ** n_downsampling
        for i in range(n_blocks):
            model += [ResnetBlock(ngf * mult, padding_type=padding_type,
                                  activation=activation, norm_layer=norm_layer)]

        # Upsampling: transposed convs mirror the downsampling path, halving the
        # channels and doubling the spatial size at each step.
        for i in range(n_downsampling):
            mult = 2 ** (n_downsampling - i)
            model += [nn.ConvTranspose2d(ngf * mult, ngf * mult // 2, kernel_size=3,
                                         stride=2, padding=1, output_padding=1),
                      norm_layer(ngf * mult // 2), activation]

        # Head: 7x7 conv back to output_nc channels, with tanh squashing to [-1, 1].
        model += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
                  nn.Tanh()]
        self.model = nn.Sequential(*model)

    def forward(self, input):
        return self.model(input)
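
# Minimal smoke-test sketch (illustrative, with assumed shapes): with the
# default n_downsampling=4, the stride-2 convs shrink H and W by 16x and the
# transposed convs restore them, so the output shape matches the input.
if __name__ == '__main__':
    net = GlobalGenerator(input_nc=3, output_nc=3, ngf=64, n_downsampling=4, n_blocks=9)
    x = torch.randn(1, 3, 512, 512)  # dummy batch; H and W must be divisible by 16
    with torch.no_grad():
        y = net(x)
    print(y.shape)  # expected: torch.Size([1, 3, 512, 512])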