# FastSplatStyler / pooling.py
# Initial upload by incrl (commit 5b557cf, verified).
import torch
from torch_scatter import scatter
def avgPoolKernel(x, edge_index, selections, cluster, kernel_size=2, even_dirs=(0, 1, 7, 8)):
    """Average-pool node features over a kernel_size x kernel_size neighborhood.

    Messages are averaged along ``edge_index`` enough times to cover the
    receptive field, then one representative node per cluster is kept.

    Args:
        x: Node feature tensor of shape (num_nodes, num_features).
        edge_index: (2, num_edges) tensor; messages flow from row 1 to row 0.
        selections: Per-edge direction labels, used to filter edges for the
            extra half-pass required by even kernels.
        cluster: Per-node cluster assignment (one entry per node in ``x``).
        kernel_size: Side length of the pooling kernel.
        even_dirs: Direction labels kept for the even-kernel extra pass
            (per the original comments, the bottom-right directions).

    Returns:
        Pooled features, one row per cluster.
    """
    is_even = kernel_size % 2 == 0
    # Each full message pass grows the receptive field by one ring of neighbors.
    full_passes = kernel_size // 2 - int(is_even)
    # Assumes the lowest node index is the top-left position in each cluster.
    indices = torch.arange(len(x), device=x.device)
    # Find the minimum node index in each cluster; those rows are kept at the end.
    picks = scatter(indices, cluster, dim=0, reduce='min')
    # Send avg pool messages the appropriate number of times.
    for _ in range(full_passes):
        message = x[edge_index[1]]
        x = scatter(message, edge_index[0], dim=0, reduce='mean')  # Aggregate
    # Even kernel sizes are not symmetric and are lopsided towards the bottom
    # right corner: repeat the process one more time going to the bottom right.
    if is_even:
        # Prefilter edge_index to the even-pass directions.
        keep = torch.zeros_like(selections, dtype=torch.bool).to(x.device)
        for i in even_dirs:
            keep |= selections == i
        even_edge_index = edge_index[:, keep]
        message = x[even_edge_index[1]]
        x = scatter(message, even_edge_index[0], dim=0, reduce='mean')  # Aggregate
    # Take the previously selected representative nodes.
    return x[picks]
def maxPoolKernel(x, edge_index, selections, cluster, kernel_size=2, even_dirs=(0, 1, 7, 8)):
    """Max-pool node features over a kernel_size x kernel_size neighborhood.

    Messages are max-reduced along ``edge_index`` enough times to cover the
    receptive field, then one representative node per cluster is kept.

    Args:
        x: Node feature tensor of shape (num_nodes, num_features).
        edge_index: (2, num_edges) tensor; messages flow from row 1 to row 0.
        selections: Per-edge direction labels, used to filter edges for the
            extra half-pass required by even kernels.
        cluster: Per-node cluster assignment (one entry per node in ``x``).
        kernel_size: Side length of the pooling kernel.
        even_dirs: Direction labels kept for the even-kernel extra pass
            (per the original comments, the bottom-right directions).

    Returns:
        Pooled features, one row per cluster.
    """
    is_even = kernel_size % 2 == 0
    # Each full message pass grows the receptive field by one ring of neighbors.
    full_passes = kernel_size // 2 - int(is_even)
    # Assumes the lowest node index is the top-left position in each cluster.
    indices = torch.arange(len(x), device=x.device)
    # Find the minimum node index in each cluster; those rows are kept at the end.
    picks = scatter(indices, cluster, dim=0, reduce='min')
    # Send max pool messages the appropriate number of times.
    for _ in range(full_passes):
        message = x[edge_index[1]]
        x = scatter(message, edge_index[0], dim=0, reduce='max')  # Aggregate
    # Even kernel sizes are not symmetric and are lopsided towards the bottom
    # right corner: repeat the process one more time going to the bottom right.
    if is_even:
        # Prefilter edge_index to the even-pass directions.
        keep = torch.zeros_like(selections, dtype=torch.bool).to(x.device)
        for i in even_dirs:
            keep |= selections == i
        even_edge_index = edge_index[:, keep]
        message = x[even_edge_index[1]]
        x = scatter(message, even_edge_index[0], dim=0, reduce='max')  # Aggregate
    # Take the previously selected representative nodes.
    return x[picks]
def stridePoolCluster(x, cluster):
    """Downsample by keeping one representative node per cluster.

    The representative is the node with the smallest index in its cluster,
    assumed to be the cluster's top-left position.
    """
    node_ids = torch.arange(x.shape[0], device=x.device)
    # Smallest node index within each cluster.
    representatives = scatter(node_ids, cluster, dim=0, reduce='min')
    return x[representatives]
def maxPoolCluster(x, cluster):
    """Element-wise maximum of node features within each cluster."""
    return scatter(x, cluster, dim=0, reduce='max')
def avgPoolCluster(x, cluster, edge_index=None, edge_weight=None):
    """Mean of node features within each cluster.

    ``edge_index`` and ``edge_weight`` are accepted for interface
    compatibility with the other pooling functions but are unused.
    """
    return scatter(x, cluster, dim=0, reduce='mean')
def unpoolInterpolated(x, cluster, up_edge_index, up_interps=None):
    """Upsample cluster features to nodes using interpolation weights.

    Falls back to ``unpoolEdgeAverage`` when no weights are supplied. Each
    fine node receives the weighted average of the coarse features reachable
    through ``up_edge_index``, normalized by the summed weight per node.
    """
    if up_interps is None:
        return unpoolEdgeAverage(x, cluster, up_edge_index)
    # Weighted contributions of each edge's source cluster.
    weights = up_interps.unsqueeze(1)
    source_clusters = cluster[up_edge_index[1]]
    contributions = x[source_clusters] * weights
    # Sum contributions per fine node, then normalize by total weight.
    totals = scatter(contributions, up_edge_index[0], dim=0, reduce='add')
    norm = scatter(up_interps, up_edge_index[0], dim=0)
    return totals / norm.unsqueeze(1)
def unpoolBilinear(x, cluster, up_edge_index, up_selections, selection_dirs=(0, 1, 7, 8)):
    """Bilinearly upsample cluster features to fine-level nodes.

    Only edges whose direction label is in ``selection_dirs`` participate.
    Duplicate (node, cluster) pairs are collapsed so each contributing
    cluster counts exactly once, then contributions are averaged per node.

    Args:
        x: Cluster feature tensor of shape (num_clusters, num_features).
        cluster: Cluster assignment of the nodes referenced by
            ``up_edge_index[1]``.
        up_edge_index: (2, num_edges) upsampling edges; row 0 indexes fine
            nodes, row 1 indexes the coarser-level nodes.
        up_selections: Per-edge direction labels.
        selection_dirs: Direction labels kept for the interpolation.

    Returns:
        Upsampled features, one row per fine node.
    """
    # Remove edges that won't be used for the bilinear interpolation calculation.
    keep = torch.zeros_like(up_selections, dtype=torch.bool).to(x.device)
    for i in selection_dirs:
        keep |= up_selections == i
    ref_edge_index = up_edge_index[:, keep]
    # Collapse duplicate (fine node, cluster) pairs so each cluster
    # contributes exactly once per node.
    cluster_index = torch.vstack((ref_edge_index[0], cluster[ref_edge_index[1]]))
    cluster_index = torch.unique(cluster_index, dim=1)
    return scatter(x[cluster_index[1]], cluster_index[0], dim=0, reduce='mean')
def unpoolEdgeAverage(x, cluster, up_edge_index, weighted=True):
    """Upsample cluster features by averaging over incoming edges.

    With ``weighted=True`` every edge counts once, so clusters with more
    connections to a node get proportionally more weight (works best with
    dense data). With ``weighted=False`` duplicate (node, cluster) pairs are
    collapsed first, so each connected cluster is weighted equally.
    """
    if weighted:
        targets = up_edge_index[0]
        sources = cluster[up_edge_index[1]]
    else:
        # Deduplicate (fine node, cluster) pairs before averaging.
        pairs = torch.vstack((up_edge_index[0], cluster[up_edge_index[1]]))
        pairs = torch.unique(pairs, dim=1)
        targets, sources = pairs[0], pairs[1]
    return scatter(x[sources], targets, dim=0, reduce='mean')
def unpoolCluster(x, cluster):
    """Broadcast each cluster's feature row back to all of its member nodes."""
    return x[cluster]