| | import torch |
| | import torch.nn as nn |
| |
|
| |
|
def count_params(model):
    """Return the total number of scalar elements in ``model``'s parameters.

    Every tensor yielded by ``model.parameters()`` contributes its
    ``numel()`` to the total.
    """
    n_elements = 0
    for param in model.parameters():
        n_elements += param.numel()
    return n_elements
| |
|
| |
|
class ActNorm(nn.Module):
    """Per-channel affine normalization with data-dependent initialization.

    The forward direction computes ``scale * (x + loc)`` with one learnable
    ``loc`` and one learnable ``scale`` per channel. On the first call made
    while the module is in training mode, both parameters are set from that
    batch so the output has zero mean and (approximately) unit standard
    deviation per channel; afterwards they are ordinary trainable parameters.

    Inputs may be 4-D ``(N, C, H, W)`` or 2-D ``(N, C)``. A 2-D input is
    expanded to ``(N, C, 1, 1)`` internally and squeezed back on return.

    Args:
        num_features: Number of channels ``C``.
        logdet: If true, ``forward`` returns ``(h, logdet)`` instead of ``h``.
        affine: Must be truthy; the constructor asserts on it.
        allow_reverse_init: If true, the first call may also be a ``reverse``
            call, which then performs the data-dependent initialization.
    """

    def __init__(self, num_features, logdet=False, affine=True,
                 allow_reverse_init=False):
        assert affine
        super().__init__()
        self.logdet = logdet
        self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
        self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
        self.allow_reverse_init = allow_reverse_init

        # Kept as a buffer (not a plain attribute) so the flag follows the
        # module through state_dict save/load and device moves.
        self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))

    @staticmethod
    def _to_4d(tensor):
        """Expand a 2-D ``(N, C)`` tensor to ``(N, C, 1, 1)``.

        Returns the (possibly expanded) tensor and a flag telling the caller
        whether the two trailing axes must be squeezed away again.
        """
        if len(tensor.shape) == 2:
            return tensor[:, :, None, None], True
        return tensor, False

    def initialize(self, input):
        """Set ``loc`` and ``scale`` from the statistics of a 4-D batch.

        After this call ``scale * (input + loc)`` has zero mean and roughly
        unit standard deviation in every channel of ``input``.
        """
        with torch.no_grad():
            # One row per channel holding all batch and spatial values.
            flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
            mean = flatten.mean(1).view(1, -1, 1, 1)
            std = flatten.std(1).view(1, -1, 1, 1)

            self.loc.data.copy_(-mean)
            # The epsilon guards against division by zero for constant channels.
            self.scale.data.copy_(1 / (std + 1e-6))

    def forward(self, input, reverse=False):
        """Normalize ``input``, or invert the normalization if ``reverse``.

        Returns:
            ``h`` with the same shape as ``input``. When ``self.logdet`` is
            true (and ``reverse`` is false), returns ``(h, logdet)`` where
            ``logdet`` has shape ``(N,)`` and every entry equals
            ``H * W * sum(log|scale|)``.
        """
        if reverse:
            return self.reverse(input)

        input, squeeze = self._to_4d(input)
        _, _, height, width = input.shape

        if self.training and self.initialized.item() == 0:
            self.initialize(input)
            self.initialized.fill_(1)

        h = self.scale * (input + self.loc)

        if squeeze:
            h = h.squeeze(-1).squeeze(-1)

        if self.logdet:
            log_abs = torch.log(torch.abs(self.scale))
            logdet = height * width * torch.sum(log_abs)
            # Same value for every sample; match input's dtype and device.
            logdet = logdet * torch.ones(input.shape[0]).to(input)
            return h, logdet

        return h

    def reverse(self, output):
        """Invert ``forward``: compute ``output / scale - loc``.

        Raises:
            RuntimeError: If the module is in training mode, has not been
                initialized yet, and ``allow_reverse_init`` is false.
        """
        # Expand 2-D inputs BEFORE any initialization: ``initialize`` permutes
        # four axes and would fail on a raw ``(N, C)`` tensor.
        output, squeeze = self._to_4d(output)

        if self.training and self.initialized.item() == 0:
            if not self.allow_reverse_init:
                raise RuntimeError(
                    "Initializing ActNorm in reverse direction is "
                    "disabled by default. Use allow_reverse_init=True to enable."
                )
            self.initialize(output)
            self.initialized.fill_(1)

        h = output / self.scale - self.loc

        if squeeze:
            h = h.squeeze(-1).squeeze(-1)
        return h
| |
|
| |
|
class AbstractEncoder(nn.Module):
    """Base ``nn.Module`` for encoders; subclasses must override ``encode``."""

    def __init__(self):
        super().__init__()

    def encode(self, *args, **kwargs):
        """Placeholder that always raises ``NotImplementedError``."""
        raise NotImplementedError
| |
|
| |
|
class Labelator(AbstractEncoder):
    """Net2Net interface for a class-conditional model.

    Args:
        n_classes: Stored on the instance; not used by ``encode`` itself.
        quantize_interface: Selects the return format of ``encode``.
    """

    def __init__(self, n_classes, quantize_interface=True):
        super().__init__()
        self.quantize_interface = quantize_interface
        self.n_classes = n_classes

    def encode(self, c):
        """Insert a new axis at position 1 of ``c`` and return it.

        Returns:
            The reshaped labels alone when ``quantize_interface`` is false.
            Otherwise the triple ``(labels, None, [None, None, labels.long()])``
            (presumably mirroring a quantizer's return signature; confirm
            against callers).
        """
        labels = c[:, None]
        if not self.quantize_interface:
            return labels
        return labels, None, [None, None, labels.long()]
| |
|
| |
|
class SOSProvider(AbstractEncoder):
    """Encoder that ignores the content of its input and emits a fixed token.

    Args:
        sos_token: Token id written into every output position (the name
            suggests a start-of-sequence marker; confirm against callers).
        quantize_interface: Selects the return format of ``encode``.
    """

    def __init__(self, sos_token, quantize_interface=True):
        super().__init__()
        self.sos_token = sos_token
        self.quantize_interface = quantize_interface

    def encode(self, x):
        """Return a ``(x.shape[0], 1)`` int64 tensor filled with ``sos_token``.

        Only the batch size and device of ``x`` are used.

        Returns:
            The token tensor ``c`` alone when ``quantize_interface`` is
            false, otherwise the triple ``(c, None, [None, None, c])``.
        """
        # Build the ids directly as int64 on x's device. The previous
        # ``torch.ones(...) * sos_token`` went through float32, which rounds
        # ids above 2**24, and allocated on the CPU before copying.
        # ``int()`` truncates toward zero exactly like ``.long()`` did.
        c = torch.full(
            (x.shape[0], 1),
            int(self.sos_token),
            dtype=torch.long,
            device=x.device,
        )
        if self.quantize_interface:
            return c, None, [None, None, c]
        return c
| |
|