Spaces:
Sleeping
Sleeping
| import torch | |
| from torch_scatter import scatter | |
| def avgPoolKernel(x,edge_index,selections,cluster,kernel_size=2,even_dirs=[0,1,7,8]): | |
| is_even = kernel_size % 2 == 0 | |
| full_passes = kernel_size//2 - int(is_even) | |
| # Assumes the lowest number node index is the topleft most position in the cluster | |
| indices = torch.arange(len(x)).to(x.device) | |
| # Find the minimum node index in each cluster and select those x values | |
| picks = scatter(indices, cluster, dim=0, reduce='min') | |
| # Send max pool messages the appropriate number of times | |
| for _ in range(full_passes): | |
| message = x[edge_index[1]] | |
| x = scatter(message,edge_index[0],dim=0,reduce='mean') # Aggregate | |
| # Even kernel_sizes are not symetric and are lopsided towards the bottom right corner | |
| # Repeat the process one more time going to the bottom right | |
| if is_even: | |
| # Prefilter edge_index | |
| keep = torch.zeros_like(selections,dtype=torch.bool).to(x.device) | |
| for i in even_dirs: | |
| keep[torch.where(selections == i)] = True | |
| even_edge_index = edge_index[:,torch.where(keep)[0]] | |
| message = x[even_edge_index[1]] | |
| x = scatter(message,even_edge_index[0],dim=0,reduce='mean') # Aggregate | |
| # Take the previously selected nodes | |
| x = x[picks] | |
| return x | |
| def maxPoolKernel(x,edge_index,selections,cluster,kernel_size=2,even_dirs=[0,1,7,8]): | |
| is_even = kernel_size % 2 == 0 | |
| full_passes = kernel_size//2 - int(is_even) | |
| # Assumes the lowest number node index is the topleft most position in the cluster | |
| indices = torch.arange(len(x)).to(x.device) | |
| # Find the minimum node index in each cluster and select those x values | |
| picks = scatter(indices, cluster, dim=0, reduce='min') | |
| # Send max pool messages the appropriate number of times | |
| for _ in range(full_passes): | |
| message = x[edge_index[1]] | |
| x = scatter(message,edge_index[0],dim=0,reduce='max') # Aggregate | |
| # Even kernel_sizes are not symetric and are lopsided towards the bottom right corner | |
| # Repeat the process one more time going to the bottom right | |
| if is_even: | |
| # Prefilter edge_index | |
| keep = torch.zeros_like(selections,dtype=torch.bool).to(x.device) | |
| for i in even_dirs: | |
| keep[torch.where(selections == i)] = True | |
| even_edge_index = edge_index[:,torch.where(keep)[0]] | |
| message = x[even_edge_index[1]] | |
| x = scatter(message,even_edge_index[0],dim=0,reduce='max') # Aggregate | |
| # Take the previously selected nodes | |
| x = x[picks] | |
| return x | |
| def stridePoolCluster(x,cluster): | |
| # Assumes the lowest number node index is the topleft most position in the cluster | |
| indices = torch.arange(len(x)).to(x.device) | |
| # Find the minimum node index in each cluster and select those x values | |
| picks = scatter(indices, cluster, dim=0, reduce='min') | |
| x = x[picks] | |
| return x | |
| def maxPoolCluster(x,cluster): | |
| x = scatter(x, cluster, dim=0, reduce='max') | |
| return x | |
| def avgPoolCluster(x,cluster,edge_index=None, edge_weight=None): | |
| x = scatter(x, cluster, dim=0, reduce='mean') | |
| return x | |
| def unpoolInterpolated(x,cluster,up_edge_index,up_interps=None): | |
| if up_interps is None: | |
| return unpoolEdgeAverage(x,cluster,up_edge_index) | |
| # Determine node averages based on based on interps | |
| target_clusters = cluster[up_edge_index[1]] | |
| node_vals = x[target_clusters]*up_interps.unsqueeze(1) | |
| x = scatter(node_vals,up_edge_index[0],dim=0,reduce='add') | |
| norm = scatter(up_interps,up_edge_index[0],dim=0) | |
| x/=norm.unsqueeze(1) | |
| return x | |
| def unpoolBilinear(x,cluster,up_edge_index,up_selections,selection_dirs=[0,1,7,8]): | |
| # Remove edges that won't be used for the bilinear interpolation calculation | |
| keep = torch.zeros_like(up_selections,dtype=torch.bool).to(x.device) | |
| for i in selection_dirs: | |
| keep[torch.where(up_selections == i)] = True | |
| ref_edge_index = up_edge_index[:,torch.where(keep)[0]] | |
| cluster_index = torch.vstack((ref_edge_index[0],cluster[ref_edge_index[1]])) | |
| cluster_index = torch.unique(cluster_index,dim=1) | |
| x = scatter(x[cluster_index[1]],cluster_index[0],dim=0,reduce='mean') | |
| return x | |
| def unpoolEdgeAverage(x,cluster,up_edge_index,weighted=True): | |
| # Interpolates based on the number of connections to each previous cluster. Works best with dense data. | |
| # If weighted = False, clusters are weighted equally regardless of the number of connections | |
| if weighted: | |
| target_clusters = cluster[up_edge_index[1]] | |
| x = scatter(x[target_clusters],up_edge_index[0],dim=0,reduce='mean') | |
| else: | |
| cluster_index = torch.vstack((up_edge_index[0],cluster[up_edge_index[1]])) | |
| cluster_index = torch.unique(cluster_index,dim=1) | |
| x = scatter(x[cluster_index[1]],cluster_index[0],dim=0,reduce='mean') | |
| return x | |
| def unpoolCluster(x,cluster): | |
| return x[cluster] | |