from math import sqrt
from typing import Union

import numpy as np
import torch
from torch import Tensor
from torch.nn.parameter import Parameter
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.typing import OptPairTensor, Adj
def intersect1d(tensor1, tensor2):
    """Intersection of two 1D tensors, plus the indices of the common values
    in each tensor (wraps np.intersect1d)."""
    device = tensor1.device
    result, ind1, ind2 = np.intersect1d(tensor1.cpu().numpy(), tensor2.cpu().numpy(), return_indices=True)
    return torch.tensor(result).to(device), torch.tensor(ind1).to(device), torch.tensor(ind2).to(device)


def setdiff1d(tensor1, tensor2):
    """Unique values of tensor1 that are not in tensor2 (wraps np.setdiff1d)."""
    device = tensor1.device
    result = np.setdiff1d(tensor1.cpu().numpy(), tensor2.cpu().numpy())
    return torch.tensor(result).to(device)
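
# Example (hypothetical values):
#   intersect1d(torch.tensor([1, 2, 3, 5]), torch.tensor([2, 3, 4]))
#   -> (tensor([2, 3]), tensor([1, 2]), tensor([0, 1]))
#   setdiff1d(torch.tensor([1, 2, 3, 5]), torch.tensor([2, 3, 4]))
#   -> tensor([1, 5])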

class SelectionConv(MessagePassing):
    def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3,
                 dilation: int = 1, padding_mode: str = 'zeros', **kwargs):
        super(SelectionConv, self).__init__(**kwargs)
        self.kernel_size = kernel_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.dilation = dilation
        self.padding_mode = padding_mode
        self.selection_count = kernel_size * kernel_size
        # One [in_channels, out_channels] weight matrix per selection direction
        self.weight = Parameter(torch.randn(self.selection_count, in_channels, out_channels, dtype=torch.float))
        torch.nn.init.uniform_(self.weight, a=-0.1, b=0.1)
        self.bias = Parameter(torch.randn(out_channels, dtype=torch.float))
        torch.nn.init.uniform_(self.bias, a=0.0, b=0.1)
    def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
                selections: Tensor, interps=None) -> Tensor:
        """Apply the selection-based convolution to node features x, where
        selections assigns each edge in edge_index a kernel direction and
        interps optionally weights each edge's contribution."""
        all_nodes = torch.arange(x.shape[0]).to(x.device)
        if self.padding_mode == 'constant':
            # Constant value is the average over all of the nodes
            x_mean = torch.mean(x, dim=0)
        out = torch.zeros((x.shape[0], self.out_channels)).to(x.device)
        if self.padding_mode == 'normalize':
            dir_count = torch.zeros((x.shape[0], 1)).to(x.device)
        if self.kernel_size == 1 or self.kernel_size == 3:
            # Find the appropriate node for each selection by stepping through connecting edges
            for s in range(self.selection_count):
                cur_dir = torch.where(selections == s)[0]
                cur_source = edge_index[0, cur_dir]
                cur_target = edge_index[1, cur_dir]
                if interps is not None:
                    cur_interps = interps[cur_dir]
                    cur_interps = torch.unsqueeze(cur_interps, dim=1)
                if self.dilation > 1:
                    # Step dilation-1 additional edges in the same direction
                    for _ in range(1, self.dilation):
                        vals, ind1, ind2 = intersect1d(cur_target, edge_index[0, cur_dir])
                        cur_source = cur_source[ind1]
                        cur_target = edge_index[1, cur_dir][ind2]
                        if interps is not None:
                            cur_interps = cur_interps[ind1]
                # Main calculation
                if interps is None:
                    result = torch.matmul(x[cur_target], self.weight[s])
                else:
                    result = cur_interps * torch.matmul(x[cur_target], self.weight[s])
                # Accumulate, handling duplicate source indices
                out.index_add_(0, cur_source, result)
                if self.padding_mode == 'constant':
                    missed_nodes = setdiff1d(all_nodes, cur_source)
                    # missed_nodes is unique, so direct indexed addition is safe;
                    # the 1D matmul result broadcasts across the missed rows
                    out[missed_nodes] += torch.matmul(x_mean, self.weight[s])
                if self.padding_mode == 'replicate':
                    missed_nodes = setdiff1d(all_nodes, cur_source)
                    out.index_add_(0, missed_nodes, torch.matmul(x[missed_nodes], self.weight[s]))
                if self.padding_mode == 'reflect':
                    missed_nodes = setdiff1d(all_nodes, cur_source)
                    # Directions 1-8 are ordered counter-clockwise, so the
                    # opposite direction is s + 4, wrapped back into 1..8
                    opposite = s + 4
                    if opposite > 8:
                        opposite = opposite % 9 + 1
                    op_dir = torch.where(selections == opposite)[0]
                    op_source = edge_index[0, op_dir]
                    op_target = edge_index[1, op_dir]
                    if interps is not None:
                        op_interps = interps[op_dir]
                        op_interps = torch.unsqueeze(op_interps, dim=1)
                    # Only take edges whose sources are missed nodes
                    vals, ind1, ind2 = intersect1d(op_source, missed_nodes)
                    op_source = op_source[ind1]
                    op_target = op_target[ind1]
                    if interps is not None:
                        op_interps = op_interps[ind1]
                    if self.dilation > 1:
                        for _ in range(1, self.dilation):
                            vals, ind1, ind2 = intersect1d(op_target, edge_index[0, op_dir])
                            op_source = op_source[ind1]
                            op_target = edge_index[1, op_dir][ind2]
                            if interps is not None:
                                op_interps = op_interps[ind1]
                    # Main calculation for the reflected neighbors
                    if interps is None:
                        result = torch.matmul(x[op_target], self.weight[s])
                    else:
                        result = op_interps * torch.matmul(x[op_target], self.weight[s])
                    out.index_add_(0, op_source, result)
                if self.padding_mode == 'normalize':
                    dir_count[torch.unique(cur_source)] += 1
        else:
            width = self.kernel_size // 2
            horiz = torch.arange(-width, width + 1).to(x.device)
            vert = torch.arange(-width, width + 1).to(x.device)
            right = torch.where(selections == 1)[0]
            left = torch.where(selections == 5)[0]
            down = torch.where(selections == 7)[0]
            up = torch.where(selections == 3)[0]
            center = torch.where(selections == 0)[0]
            # Find the appropriate node for each selection by stepping through connecting edges
            s = 0
            for i in range(self.kernel_size):
                for j in range(self.kernel_size):
                    x_loc = horiz[j]
                    y_loc = vert[i]
                    cur_source = edge_index[0, center]  # Starting location
                    cur_target = edge_index[1, center]
                    if interps is not None:
                        cur_interps = interps[center]
                        cur_interps = torch.unsqueeze(cur_interps, dim=1)
                    if x_loc < 0:
                        for _ in range(self.dilation * abs(x_loc)):
                            vals, ind1, ind2 = intersect1d(cur_target, edge_index[0, left])
                            cur_source = cur_source[ind1]
                            cur_target = edge_index[1, left][ind2]
                            if interps is not None:
                                cur_interps = cur_interps[ind1]
                    if x_loc > 0:
                        for _ in range(self.dilation * abs(x_loc)):
                            vals, ind1, ind2 = intersect1d(cur_target, edge_index[0, right])
                            cur_source = cur_source[ind1]
                            cur_target = edge_index[1, right][ind2]
                            if interps is not None:
                                cur_interps = cur_interps[ind1]
                    if y_loc < 0:
                        for _ in range(self.dilation * abs(y_loc)):
                            vals, ind1, ind2 = intersect1d(cur_target, edge_index[0, up])
                            cur_source = cur_source[ind1]
                            cur_target = edge_index[1, up][ind2]
                            if interps is not None:
                                cur_interps = cur_interps[ind1]
                    if y_loc > 0:
                        for _ in range(self.dilation * abs(y_loc)):
                            vals, ind1, ind2 = intersect1d(cur_target, edge_index[0, down])
                            cur_source = cur_source[ind1]
                            cur_target = edge_index[1, down][ind2]
                            if interps is not None:
                                cur_interps = cur_interps[ind1]
                    # Main calculation
                    if interps is None:
                        result = torch.matmul(x[cur_target], self.weight[s])
                    else:
                        result = cur_interps * torch.matmul(x[cur_target], self.weight[s])
                    # Accumulate, handling duplicate source indices
                    out.index_add_(0, cur_source, result)
                    if self.padding_mode == 'constant':
                        missed_nodes = setdiff1d(all_nodes, cur_source)
                        out[missed_nodes] += torch.matmul(x_mean, self.weight[s])
                    if self.padding_mode == 'replicate':
                        missed_nodes = setdiff1d(all_nodes, cur_source)
                        out.index_add_(0, missed_nodes, torch.matmul(x[missed_nodes], self.weight[s]))
                    if self.padding_mode == 'reflect':
                        raise ValueError("Reflect padding not yet implemented for larger kernels")
                    if self.padding_mode == 'normalize':
                        dir_count[torch.unique(cur_source)] += 1
                    s += 1
        if self.padding_mode == 'zeros':
            pass  # Already accounted for in the graph structure; no further computation needed
        elif self.padding_mode == 'normalize':
            # Rescale each node by how many selection directions actually reached it
            out *= self.selection_count / (dir_count + 1e-8)
        elif self.padding_mode == 'constant':
            pass  # Handled above
        elif self.padding_mode == 'replicate':
            pass  # Handled above
        elif self.padding_mode == 'reflect':
            pass  # Handled above
        elif self.padding_mode == 'circular':
            raise ValueError("Circular padding cannot be generalized to a graph. "
                             "Instead, create a graph with edges connecting to the wrapped-around nodes.")
        else:
            raise ValueError(f"Unknown padding mode: {self.padding_mode}")
        # Add the bias
        out += self.bias
        return out
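
    # Example call (hypothetical shapes): given node features x of shape
    # [num_nodes, in_channels], edge_index of shape [2, num_edges], and a
    # per-edge selection tensor of shape [num_edges],
    #     layer = SelectionConv(in_channels, out_channels, kernel_size=3)
    #     out = layer(x, edge_index, selections)  # [num_nodes, out_channels]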
    def copy_weightsNxN(self, weight, bias=None):
        """Copy an NxN kernel; assumes weight comes in as [out_channels, in_channels, row, col]."""
        width = int(sqrt(self.selection_count))
        # In-place writes to a leaf Parameter must happen outside autograd
        with torch.no_grad():
            for i in range(self.selection_count):
                self.weight[i] = weight[:, :, i // width, i % width].permute(1, 0)
    def copy_weights3x3(self, weight, bias=None):
        """Copy a 3x3 kernel; assumes weight comes in as [out_channels, in_channels, row, col].

        Selection ordering around the center node:
            4 3 2
            5 0 1
            6 7 8
        PyTorch's conv2d computes a cross-correlation, so each kernel entry
        maps directly to the selection pointing the same way (no flip needed).
        """
        with torch.no_grad():
            self.weight[1] = weight[:, :, 1, 2].permute(1, 0)
            self.weight[3] = weight[:, :, 0, 1].permute(1, 0)
            self.weight[5] = weight[:, :, 1, 0].permute(1, 0)
            self.weight[7] = weight[:, :, 2, 1].permute(1, 0)
            self.weight[2] = weight[:, :, 0, 2].permute(1, 0)
            self.weight[4] = weight[:, :, 0, 0].permute(1, 0)
            self.weight[6] = weight[:, :, 2, 0].permute(1, 0)
            self.weight[8] = weight[:, :, 2, 2].permute(1, 0)
            self.weight[0] = weight[:, :, 1, 1].permute(1, 0)
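
    # For example (hypothetical kernel): copying the Sobel-style kernel
    #     [[-1, 0, 1],
    #      [-2, 0, 2],
    #      [-1, 0, 1]]
    # places the -2 tap on selection 5 (the left neighbor) and the +2 tap on
    # selection 1 (the right neighbor).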
    def copy_weights1x1(self, weight, bias=None):
        with torch.no_grad():
            self.weight[0] = weight[:, :, 0, 0].permute(1, 0)
    def copy_weights(self, weight, bias=None):
        if self.kernel_size == 3:
            self.copy_weights3x3(weight, bias)
        elif self.kernel_size == 1:
            self.copy_weights1x1(weight, bias)
        else:
            self.copy_weightsNxN(weight, bias)
        with torch.no_grad():
            if bias is None:
                self.bias[:] = 0.0
            else:
                self.bias.copy_(bias)
    def __repr__(self):
        return f'{self.__class__.__name__}({self.in_channels}, {self.out_channels})'
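

if __name__ == "__main__":
    # Sanity-check sketch (an illustration, not part of the layer): build the
    # selection graph for a small regular image grid and check that
    # SelectionConv with copied weights reproduces torch.nn.Conv2d with zero
    # padding. The neighbor offsets below are an assumption based on the
    # selection ordering documented in copy_weights3x3.
    H, W, C_in, C_out = 6, 5, 3, 4
    offsets = {0: (0, 0), 1: (0, 1), 2: (-1, 1), 3: (-1, 0), 4: (-1, -1),
               5: (0, -1), 6: (1, -1), 7: (1, 0), 8: (1, 1)}
    sources, targets, sels = [], [], []
    for i in range(H):
        for j in range(W):
            for s, (di, dj) in offsets.items():
                ni, nj = i + di, j + dj
                if 0 <= ni < H and 0 <= nj < W:  # absent edges act as zero padding
                    sources.append(i * W + j)
                    targets.append(ni * W + nj)
                    sels.append(s)
    edge_index = torch.tensor([sources, targets], dtype=torch.long)
    selections = torch.tensor(sels, dtype=torch.long)

    conv = torch.nn.Conv2d(C_in, C_out, kernel_size=3, padding=1)
    sconv = SelectionConv(C_in, C_out, kernel_size=3, padding_mode='zeros')
    sconv.copy_weights(conv.weight, conv.bias)

    img = torch.randn(1, C_in, H, W)
    feat = img[0].permute(1, 2, 0).reshape(H * W, C_in)  # one node per pixel
    with torch.no_grad():
        expected = conv(img)[0].permute(1, 2, 0).reshape(H * W, C_out)
        actual = sconv(feat, edge_index, selections)
    # Should be near zero if the grid convention above matches the layer's
    print('max abs difference:', (expected - actual).abs().max().item())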