| | |
| | import torch |
| | from torch import nn |
| | from torch.nn import functional as F |
| |
|
| |
|
class Encoding(nn.Module):
    """Encoding Layer: a learnable residual encoder.

    Every spatial position of the input is softly assigned to each of
    ``num_codes`` learnable codewords, and the residuals (feature minus
    codeword) are summed over all positions, weighted by that assignment.

    Input is of shape (batch_size, channels, height, width).
    Output is of shape (batch_size, num_codes, channels).

    Args:
        channels: dimension of the features or feature channels
        num_codes: number of code words
    """

    def __init__(self, channels: int, num_codes: int) -> None:
        super().__init__()

        self.channels, self.num_codes = channels, num_codes
        std = 1. / ((num_codes * channels)**0.5)

        # Codewords of shape (num_codes, channels), drawn from U(-std, std).
        self.codewords = nn.Parameter(
            torch.empty(num_codes, channels,
                        dtype=torch.float).uniform_(-std, std),
            requires_grad=True)

        # One smoothing factor per codeword, drawn from U(-1, 0). While a
        # factor is negative, a smaller squared distance yields a larger
        # logit and therefore a larger softmax weight in ``forward``.
        self.scale = nn.Parameter(
            torch.empty(num_codes, dtype=torch.float).uniform_(-1, 0),
            requires_grad=True)

    @staticmethod
    def scaled_l2(x: torch.Tensor, codewords: torch.Tensor,
                  scale: torch.Tensor) -> torch.Tensor:
        """Return the scaled squared L2 distance of features to codewords.

        Args:
            x: features of shape (batch_size, num_positions, channels).
            codewords: tensor of shape (num_codes, channels).
            scale: tensor holding ``num_codes`` scaling factors.

        Returns:
            Tensor of shape (batch_size, num_positions, num_codes).
        """
        num_codes, channels = codewords.size()
        reshaped_scale = scale.view((1, 1, num_codes))
        reshaped_codewords = codewords.view((1, 1, num_codes, channels))

        # Broadcasting (B, N, 1, C) against (1, 1, K, C) yields the
        # (B, N, K, C) residuals without an explicit ``expand``.
        residuals = x.unsqueeze(2) - reshaped_codewords
        return reshaped_scale * residuals.pow(2).sum(dim=3)

    @staticmethod
    def aggregate(assignment_weights: torch.Tensor, x: torch.Tensor,
                  codewords: torch.Tensor) -> torch.Tensor:
        """Sum the weighted residuals over all positions.

        Args:
            assignment_weights: tensor of shape
                (batch_size, num_positions, num_codes).
            x: features of shape (batch_size, num_positions, channels).
            codewords: tensor of shape (num_codes, channels).

        Returns:
            Tensor of shape (batch_size, num_codes, channels).
        """
        num_codes, channels = codewords.size()
        reshaped_codewords = codewords.view((1, 1, num_codes, channels))

        residuals = x.unsqueeze(2) - reshaped_codewords
        # Weight each residual by its assignment, then reduce over positions.
        return (assignment_weights.unsqueeze(3) * residuals).sum(dim=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a (batch_size, channels, height, width) feature map.

        Returns:
            Tensor of shape (batch_size, num_codes, channels).
        """
        assert x.dim() == 4 and x.size(1) == self.channels, (
            f'expected a 4D input with {self.channels} channels, '
            f'got shape {tuple(x.shape)}')

        # (B, C, H, W) -> (B, H*W, C). ``flatten`` is used instead of
        # ``view`` because ``view`` raises on non-contiguous inputs whose
        # spatial strides cannot be merged; ``flatten`` copies only then.
        x = x.flatten(2).transpose(1, 2).contiguous()

        # Soft-assignment of every position over the codewords (dim=2).
        assignment_weights = F.softmax(
            self.scaled_l2(x, self.codewords, self.scale), dim=2)

        return self.aggregate(assignment_weights, x, self.codewords)

    def __repr__(self) -> str:
        repr_str = self.__class__.__name__
        repr_str += f'(Nx{self.channels}xHxW =>Nx{self.num_codes}' \
                    f'x{self.channels})'
        return repr_str
| |
|