# FastSplatStyler / selectionConv.py
# incrl -- Initial Upload (attempt 2)
# commit 5b557cf (verified)
from typing import Union, Tuple
from torch_geometric.typing import OptTensor, OptPairTensor, Adj, Size
import torch
from torch import Tensor
from torch.nn.functional import conv2d
from torch_sparse import SparseTensor, matmul
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils.loop import contains_self_loops
from torch.nn.parameter import Parameter
from torch_geometric.nn.inits import glorot, zeros
from math import sqrt
import numpy as np
def intersect1d(tensor1, tensor2):
    """Intersect two 1-D tensors via ``numpy.intersect1d``.

    Both tensors are moved to the CPU, intersected with NumPy
    (``return_indices=True``), and the three resulting arrays are converted
    back to tensors on the device of ``tensor1``.

    Returns:
        A 3-tuple ``(values, idx1, idx2)``: the sorted common values, and the
        positions of the first occurrence of each common value in ``tensor1``
        and ``tensor2`` respectively.
    """
    target_device = tensor1.device
    lhs = tensor1.cpu().numpy()
    rhs = tensor2.cpu().numpy()
    arrays = np.intersect1d(lhs, rhs, return_indices=True)
    common, idx1, idx2 = (torch.tensor(arr).to(target_device) for arr in arrays)
    return common, idx1, idx2
def setdiff1d(tensor1, tensor2):
    """Return the values of ``tensor1`` that do not occur in ``tensor2``.

    The computation is delegated to ``numpy.setdiff1d`` on CPU copies of the
    inputs (so the result is sorted and free of duplicates); the result is
    returned as a tensor on the device of ``tensor1``.
    """
    lhs = tensor1.cpu().numpy()
    rhs = tensor2.cpu().numpy()
    remaining = np.setdiff1d(lhs, rhs)
    return torch.tensor(remaining).to(tensor1.device)
class SelectionConv(MessagePassing):
    """Convolution on a graph whose edges carry a direction label ("selection").

    Every edge ``(source, target)`` is tagged with an integer selection.  For
    a 3x3 kernel the labels are laid out as::

        4 3 2
        5 0 1
        6 7 8

    with 0 being the centre.  Each selection ``s`` owns an
    ``[in_channels, out_channels]`` weight matrix; the layer output at a node
    is the sum over its outgoing edges of ``x[target] @ weight[selection]``
    (optionally scaled by a per-edge interpolation weight) plus a bias.

    Kernels larger than 3x3 are evaluated by starting from the centre edges
    and repeatedly stepping along the left/right/up/down selections
    (5/1/3/7).
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3, dilation = 1, padding_mode = 'zeros', **kwargs):
        """Create the layer.

        Args:
            in_channels: Number of feature channels per input node.
            out_channels: Number of feature channels per output node.
            kernel_size: Side length of the (square) kernel.
            dilation: Number of edge hops that make up one kernel step.
            padding_mode: One of ``'zeros'``, ``'normalize'``, ``'constant'``,
                ``'replicate'`` or ``'reflect'``.  The value is checked when
                ``forward`` is called, not here.
            **kwargs: Forwarded to ``MessagePassing.__init__``.
        """
        super().__init__(**kwargs)

        self.kernel_size = kernel_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.dilation = dilation
        self.padding_mode = padding_mode
        self.selection_count = kernel_size * kernel_size

        # One [in_channels, out_channels] matrix per selection.
        self.weight = Parameter(torch.randn(self.selection_count, in_channels, out_channels, dtype=torch.float))
        torch.nn.init.uniform_(self.weight, a=-0.1, b=0.1)

        self.bias = Parameter(torch.randn(out_channels, dtype=torch.float))
        torch.nn.init.uniform_(self.bias, a=0.0, b=0.1)

    def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj, selections: Tensor, interps = None) -> Tensor:
        """Apply the selection-based convolution.

        Args:
            x: Node features; indexed by node along dim 0 and multiplied with
                the ``[in_channels, out_channels]`` selection weights.
            edge_index: Edges, with sources in row 0 and targets in row 1.
            selections: Selection label of every edge (same ordering as the
                columns of ``edge_index``).
            interps: Optional per-edge factors (same ordering as
                ``selections``) multiplied into each edge's contribution.

        Returns:
            Tensor of shape ``[x.shape[0], out_channels]``.

        Raises:
            ValueError: If ``padding_mode`` is unknown, is ``'circular'``, or
                is ``'reflect'`` with a kernel other than 1x1 / 3x3.
        """
        # Fail fast instead of after the whole computation has been done.
        self._sel_check_padding()

        num_nodes = x.shape[0]
        all_nodes = torch.arange(num_nodes).to(x.device)
        out = torch.zeros((num_nodes, self.out_channels)).to(x.device)

        # Constant padding uses the average of all node features.
        x_mean = torch.mean(x, dim=0) if self.padding_mode == 'constant' else None

        # For 'normalize': how many selections actually reached each node.
        dir_count = None
        if self.padding_mode == 'normalize':
            dir_count = torch.zeros((num_nodes, 1)).to(x.device)

        if self.kernel_size == 1 or self.kernel_size == 3:
            # Every kernel tap corresponds directly to one selection label.
            for s in range(self.selection_count):
                cur_dir = torch.where(selections == s)[0]
                source, target, factors = self._sel_edges(edge_index, cur_dir, interps)
                # Dilation: keep walking in the same direction.
                source, target, factors = self._sel_walk(
                    source, target, factors, edge_index, cur_dir, self.dilation - 1)

                self._sel_accumulate(out, x, s, source, target, factors)
                self._sel_pad_missing(out, x, x_mean, dir_count, all_nodes, source, s)
                if self.padding_mode == 'reflect':
                    self._sel_reflect_missing(out, x, edge_index, selections, interps,
                                              all_nodes, source, s)
        else:
            half = self.kernel_size // 2
            # Plain Python ints: they are only used as step counts.
            offsets = range(-half, half + 1)

            right = torch.where(selections == 1)[0]
            left = torch.where(selections == 5)[0]
            down = torch.where(selections == 7)[0]
            up = torch.where(selections == 3)[0]
            center = torch.where(selections == 0)[0]

            # Reach every kernel tap by stepping horizontally, then vertically,
            # starting from the centre edges.
            for i in range(self.kernel_size):
                for j in range(self.kernel_size):
                    s = i * self.kernel_size + j
                    x_loc = offsets[j]
                    y_loc = offsets[i]

                    source, target, factors = self._sel_edges(edge_index, center, interps)
                    if x_loc != 0:
                        source, target, factors = self._sel_walk(
                            source, target, factors, edge_index,
                            left if x_loc < 0 else right, self.dilation * abs(x_loc))
                    if y_loc != 0:
                        source, target, factors = self._sel_walk(
                            source, target, factors, edge_index,
                            up if y_loc < 0 else down, self.dilation * abs(y_loc))

                    self._sel_accumulate(out, x, s, source, target, factors)
                    self._sel_pad_missing(out, x, x_mean, dir_count, all_nodes, source, s)

        # 'zeros' needs nothing further (missing edges simply contribute 0);
        # 'constant', 'replicate' and 'reflect' were handled per selection.
        if self.padding_mode == 'normalize':
            out *= self.selection_count / (dir_count + 1e-8)

        out += self.bias
        return out

    def _sel_check_padding(self):
        """Raise ``ValueError`` for padding modes ``forward`` cannot handle."""
        mode = self.padding_mode
        if mode == 'circular':
            raise ValueError("Circular padding cannot be generalized on a graph. Instead, create a graph with edges connecting to the wrapped around nodes")
        if mode not in ('zeros', 'normalize', 'constant', 'replicate', 'reflect'):
            raise ValueError(f"Unknown padding mode: {self.padding_mode}")
        if mode == 'reflect' and not (self.kernel_size == 1 or self.kernel_size == 3):
            raise ValueError("Reflect padding not yet implemented for larger kernels")

    @staticmethod
    def _sel_edges(edge_index, edge_ids, interps):
        """Return (sources, targets, column-shaped interps or None) of the given edges."""
        factors = None
        if interps is not None:
            factors = torch.unsqueeze(interps[edge_ids], dim=1)
        return edge_index[0, edge_ids], edge_index[1, edge_ids], factors

    @staticmethod
    def _sel_walk(source, target, factors, edge_index, step_edges, steps):
        """Advance every (source -> target) path ``steps`` hops along ``step_edges``.

        Paths whose current target has no outgoing edge in ``step_edges`` are
        dropped.
        """
        step_source = edge_index[0, step_edges]
        step_target = edge_index[1, step_edges]
        for _ in range(steps):
            _, keep, nxt = intersect1d(target, step_source)
            source = source[keep]
            target = step_target[nxt]
            if factors is not None:
                factors = factors[keep]
        return source, target, factors

    def _sel_accumulate(self, out, x, s, source, target, factors):
        """Add ``x[target] @ weight[s]`` (times ``factors`` if given) into ``out[source]``."""
        contribution = torch.matmul(x[target], self.weight[s])
        if factors is not None:
            contribution = factors * contribution
        # index_add_ sums correctly even when ``source`` holds duplicates.
        out.index_add_(0, source, contribution)

    def _sel_pad_missing(self, out, x, x_mean, dir_count, all_nodes, source, s):
        """Handle 'constant' / 'replicate' / 'normalize' padding for selection ``s``."""
        mode = self.padding_mode
        if mode == 'constant':
            missed_nodes = setdiff1d(all_nodes, source)
            # The fill value is a single [out_channels] row.  index_add_ does
            # not broadcast, so use indexed addition instead; this is safe
            # because setdiff1d returns each node at most once.
            out[missed_nodes] += torch.matmul(x_mean, self.weight[s])
        elif mode == 'replicate':
            # A node without a neighbour in this direction uses its own value.
            missed_nodes = setdiff1d(all_nodes, source)
            out.index_add_(0, missed_nodes, torch.matmul(x[missed_nodes], self.weight[s]))
        elif mode == 'normalize':
            dir_count[torch.unique(source)] += 1

    @staticmethod
    def _sel_opposite(s):
        """Return the selection pointing the opposite way (3x3 layout; centre maps to itself)."""
        if s == 0:
            return 0
        return (s + 3) % 8 + 1  # 1<->5, 2<->6, 3<->7, 4<->8

    def _sel_reflect_missing(self, out, x, edge_index, selections, interps, all_nodes, source, s):
        """'reflect' padding: nodes lacking selection ``s`` borrow the opposite neighbour."""
        missed_nodes = setdiff1d(all_nodes, source)

        op_dir = torch.where(selections == self._sel_opposite(s))[0]
        op_source, op_target, op_factors = self._sel_edges(edge_index, op_dir, interps)

        # Only take edges that start at one of the missed nodes.
        _, keep, _ = intersect1d(op_source, missed_nodes)
        op_source = op_source[keep]
        op_target = op_target[keep]
        if op_factors is not None:
            op_factors = op_factors[keep]

        op_source, op_target, op_factors = self._sel_walk(
            op_source, op_target, op_factors, edge_index, op_dir, self.dilation - 1)

        # The reflected value is still weighted with the kernel tap ``s``.
        self._sel_accumulate(out, x, s, op_source, op_target, op_factors)

    @torch.no_grad()
    def copy_weightsNxN(self,weight,bias=None):
        """Copy an NxN convolution kernel into the selection weights.

        Assumes ``weight`` comes in as [output channels, input channels, row,
        col].  ``bias`` is accepted for signature parity and ignored here.
        Runs under ``torch.no_grad()`` because the parameters are written in
        place.
        """
        width = int(sqrt(self.selection_count))
        for i in range(self.selection_count):
            self.weight[i] = weight[:, :, i // width, i % width].permute(1, 0)

    @torch.no_grad()
    def copy_weights3x3(self,weight,bias=None):
        """Copy a 3x3 convolution kernel into the selection weights.

        Assumes ``weight`` comes in as [output channels, input channels, row,
        col] with a 3x3 spatial extent.  Selection ordering::

            4 3 2
            5 0 1
            6 7 8

        ``bias`` is accepted for signature parity and ignored here.  Runs
        under ``torch.no_grad()`` because the parameters are written in place.
        """
        # selection -> (row, col) of the dense kernel
        layout = {
            0: (1, 1),
            1: (1, 2),
            2: (0, 2),
            3: (0, 1),
            4: (0, 0),
            5: (1, 0),
            6: (2, 0),
            7: (2, 1),
            8: (2, 2),
        }
        for selection, (row, col) in layout.items():
            self.weight[selection] = weight[:, :, row, col].permute(1, 0)

    @torch.no_grad()
    def copy_weights1x1(self, weight, bias=None):
        """Copy a 1x1 convolution kernel ([out, in, 1, 1]) into selection 0.

        ``bias`` is accepted for signature parity and ignored here.
        """
        self.weight[0] = weight[:, :, 0, 0].permute(1, 0)

    @torch.no_grad()
    def copy_weights(self,weight,bias=None):
        """Load the weights (and bias) of a dense 2-D convolution into this layer.

        Args:
            weight: Kernel laid out as [output channels, input channels, row, col].
            bias: Optional bias with ``out_channels`` entries; when ``None``
                the layer's bias is set to zero.

        Values are copied into the existing parameters, so the layer never
        shares storage with the tensors passed in and plain tensors as well
        as ``Parameter`` objects are accepted.
        """
        if self.kernel_size == 3:
            self.copy_weights3x3(weight, bias)
        elif self.kernel_size == 1:
            self.copy_weights1x1(weight, bias)
        else:
            self.copy_weightsNxN(weight, bias)

        if bias is None:
            self.bias.zero_()
        else:
            # ``self.bias = bias`` would raise TypeError for a plain tensor and
            # would alias the source parameter otherwise; copy the values.
            self.bias.copy_(bias)

    def __repr__(self):
        return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
                                   self.out_channels)