# FastSplatStyler / graph_helpers.py
# incrl's picture
# Initial Upload (attempt 2)
# 5b557cf verified
import torch
from torch_geometric.nn import radius_graph, knn_graph
import torch_geometric as tg
from torch_geometric.utils import subgraph
import utils
from math import sqrt
def getImPos(rows,cols,start_row=0,start_col=0):
    """Enumerate the integer (row, col) coordinates of an image grid.

    Args:
        rows: Number of rows in the grid.
        cols: Number of columns in the grid.
        start_row: Value of the first row coordinate.
        start_col: Value of the first column coordinate.

    Returns:
        A (rows*cols, 2) tensor of [row, col] pairs in row-major order
        (the column coordinate varies fastest).
    """
    row_ids = torch.arange(start_row, start_row + rows)
    col_ids = torch.arange(start_col, start_col + cols)
    # 'ij' indexing gives (rows, cols) grids: the first varies along dim 0
    # (rows), the second along dim 1 (columns).
    row_grid, col_grid = torch.meshgrid(row_ids, col_ids, indexing='ij')
    pairs = torch.stack((row_grid, col_grid), dim=-1)
    return pairs.reshape(rows * cols, 2)
def convertImPos(im_pos,flip_y=True):
    """Convert (row, col) image positions into float (x, y) positions.

    Args:
        im_pos: Tensor whose first two columns are (row, col).
        flip_y: If True, replace y by max(y) - y so that y grows upwards
            (the mathematical convention used by edges2Selections).

    Returns:
        A new float32 tensor with the first two columns swapped to (x, y).
        The input tensor is never modified.
    """
    # Always copy: Tensor.float() returns the input itself when it is already
    # float32, so editing its result in place would corrupt the caller's data.
    pos2D = im_pos.to(torch.float32, copy=True)

    # Switch rows,cols to x,y (the advanced-index read on the right is a copy)
    pos2D[:,[1,0]] = pos2D[:,[0,1]]

    # amax is undefined on an empty tensor, so skip the flip when there are no points
    if flip_y and pos2D.shape[0] > 0:
        pos2D[:,1] = torch.amax(pos2D[:,1]) - pos2D[:,1]

    return pos2D
def grid2Edges(locs):
    """Build a radius graph (with self loops) over grid-spaced locations.

    Assumes locs are already spaced at a distance of 1. On a 2-D unit grid
    the radius of 1.44 is just above sqrt(2), so diagonal neighbours are
    reached while cells two steps away are not.
    """
    reach = 1.44
    return radius_graph(locs, reach, loop=True)
def radius2Edges(locs,r=1.0):
    """Build a radius graph over locs with radius r, self loops included."""
    return radius_graph(locs, r, loop=True)
def knn2Edges(locs,knn=9):
    """Build a k-nearest-neighbour graph over locs with knn neighbours, self loops included."""
    return knn_graph(locs, knn, loop=True)
def surface2Edges(pos3D,normals,up_vector=None,k_neighbors=9):
    """Build a normal-aware k-NN graph and project each edge into a local 2D frame.

    Args:
        pos3D: Node positions; three coordinates per node are read.
        normals: Per-node normal vectors, indexed like pos3D.
        up_vector: Reference direction used to orient the local x axis.
            Defaults to [[0, 1, 0]] on pos3D's device.
        k_neighbors: Number of neighbours requested from knn_graph.

    Returns:
        (edge_index, directions): the culled edge list and, for every edge,
        the offset pos3D[edge_index[1]] - pos3D[edge_index[0]] expressed in
        the (x, y) axes of the frame built from the normal at edge_index[0].

    NOTE(review): assuming utils.cross is the standard cross product, a normal
    parallel to up_vector gives a zero x axis and therefore NaN directions.
    """
    if up_vector is None:
        up_vector = torch.tensor([[0.0,1.0,0.0]]).to(pos3D.device)

    # K nearest neighbours graph (self loops included)
    edge_index = knn_graph(pos3D,k_neighbors,loop=True)

    # Cull edges whose two endpoint normals do not have a positive dot product
    agreement = (normals[edge_index[1]] * normals[edge_index[0]]).sum(dim=1)
    edge_index = edge_index[:, agreement > 0]

    # Per-edge frame via Gram-Schmidt orthogonalization around the normal.
    # utils.cross is used because torch.cross doesn't broadcast properly in
    # some versions of torch.
    frame_normals = normals[edge_index[0]]
    x_dir = utils.cross(up_vector, frame_normals)
    x_dir = x_dir / torch.linalg.norm(x_dir, dim=1, keepdim=True)
    y_dir = utils.cross(frame_normals, x_dir)
    y_dir = y_dir / torch.linalg.norm(y_dir, dim=1, keepdim=True)

    # Rotate the offsets into the frame by multiplying out the rotation matrix.
    # `offsets` stays untouched while the projected values are written into a copy.
    offsets = pos3D[edge_index[1]] - pos3D[edge_index[0]]
    directions = torch.clone(offsets)
    directions[:,0] = offsets[:,0] * x_dir[:,0] + offsets[:,1] * x_dir[:,1] + offsets[:,2] * x_dir[:,2]
    directions[:,1] = offsets[:,0] * y_dir[:,0] + offsets[:,1] * y_dir[:,1] + offsets[:,2] * y_dir[:,2]

    # Drop the z coordinate (the component along the normal is not needed)
    return edge_index, directions[:,:2]
def _selection_vectors(y_down, reference):
    """Return the 2x8 matrix whose column k is the unit direction of selection k+1.

    The matrix is created on reference's device, in reference's dtype when
    that is a floating type (float32 otherwise).
    """
    s = sqrt(2)/2
    # y-up ordering: selection 1 points along +x and numbers increase counter-clockwise
    table = [[1,0],[s,s],[0,1],[-s,s],[-1,0],[-s,-s],[0,-1],[s,-s]]
    if y_down:
        # Mirror the y component for image-style coordinates where y grows downwards
        table = [[x,-y] for x,y in table]
    dtype = reference.dtype if reference.is_floating_point() else torch.float
    return torch.tensor(table,dtype=dtype,device=reference.device).transpose(1,0)


def edges2Selections(edge_index,directions,interpolated=True,bary_d=None,y_down=False):
    """Assign each edge to one of nine selections based on its 2D direction.

    Current Ordering (y up)
    4 3 2
    5 0 1
    6 7 8

    Args:
        edge_index: Edge list; only used when interpolated is True.
        directions: (E, 2) tensor of per-edge direction vectors.
        interpolated: If True, split each edge's weight across neighbouring
            selections and return normalized interpolation weights.
        bary_d: If given (and interpolated), use interpolateSelections_barycentric
            with this scale; otherwise use the angular interpolateSelections.
        y_down: Use the direction table with the y axis mirrored.

    Returns:
        interpolated=True: (edge_index, selections, interps), where edge_index
        may contain additional edges created by the interpolation.
        interpolated=False: selections only, with 0 for all-zero directions.
    """
    # Build the table on the same device/dtype as directions so matmul works
    # for CUDA and float64 inputs too.
    vectorList = _selection_vectors(y_down, directions)

    if interpolated:
        if bary_d is None:
            edge_index,selections,interps = interpolateSelections(edge_index,directions,vectorList)
        else:
            edge_index,selections,interps = interpolateSelections_barycentric(edge_index,directions,bary_d,vectorList)
        interps = normalizeEdges(edge_index,selections,interps)
        return edge_index,selections,interps

    # Pick the table direction with the largest dot product
    scores = torch.matmul(directions.to(vectorList.dtype),vectorList)
    selections = torch.argmax(scores,dim=1) + 1
    # Same cell selection: a zero direction has no meaningful best match
    selections[torch.sum(torch.abs(directions),dim=1) == 0] = 0
    return selections
def makeEdges(prev_sources,prev_targets,prev_selections,sources,targets,selection,reverse=True):
    """Append a batch of edges that share one selection to running Python lists.

    Args:
        prev_sources, prev_targets, prev_selections: Lists that are extended
            in place (and also returned).
        sources, targets: Tensors of node indices; they are flattened and
            paired element-wise.
        selection: Selection value recorded for every appended edge.
        reverse: If True, also append the opposite edge (target -> source)
            with utils.reverse_selection(selection).

    Returns:
        The three (mutated) lists: prev_sources, prev_targets, prev_selections.
    """
    # Convert once so that both directions append plain Python ints; extending
    # a list with a tensor would insert 0-d tensors instead.
    src_list = sources.flatten().tolist()
    dst_list = targets.flatten().tolist()
    count = len(src_list)

    prev_sources += src_list
    prev_targets += dst_list
    prev_selections += count*[selection]

    if reverse:
        prev_sources += dst_list
        prev_targets += src_list
        prev_selections += count*[utils.reverse_selection(selection)]

    return prev_sources,prev_targets,prev_selections
def maskNodes(mask,x):
    """Return the entries of x at the positions where mask is nonzero."""
    keep = torch.nonzero(mask, as_tuple=True)
    return x[keep]
def maskPoints(mask,x,y):
    """Find the points whose four surrounding pixels are all inside the mask.

    Args:
        mask: Mask tensor; after squeezing it is indexed as mask[row, col].
        x: Per-point column coordinates (may be fractional).
        y: Per-point row coordinates (may be fractional).

    Returns:
        1-D tensor with the indices of the points for which the floor and
        floor+1 pixels (clamped to the image bounds) are all set in the mask.
    """
    grid = torch.squeeze(mask)
    last_row = grid.shape[0] - 1
    last_col = grid.shape[1] - 1

    # Integer pixel corners around every point, clamped to stay inside the image
    col_lo = torch.floor(x).long()
    row_lo = torch.floor(y).long()
    col_hi = torch.clamp(col_lo + 1, 0, last_col)
    row_hi = torch.clamp(row_lo + 1, 0, last_row)
    col_lo = torch.clamp(col_lo, 0, last_col)
    row_lo = torch.clamp(row_lo, 0, last_row)

    # A point is kept only when every one of its four corners is masked in
    inside = grid[row_lo, col_lo]
    for corner in (grid[row_hi, col_lo], grid[row_lo, col_hi], grid[row_hi, col_hi]):
        inside = torch.logical_and(inside, corner)

    return torch.nonzero(inside, as_tuple=True)[0]
def maskGraph(mask,edge_index,selections,interps=None):
    """Restrict a graph and its per-edge data to a subset of nodes.

    Args:
        mask: Node subset forwarded to torch_geometric.utils.subgraph.
        edge_index: Edge list of the full graph.
        selections: Per-edge selection values.
        interps: Optional per-edge interpolation weights.

    Returns:
        (edge_index, selections) when interps is None, otherwise
        (edge_index, selections, interps). The returned edge_index uses
        relabelled node indices (relabel_nodes=True) and the per-edge tensors
        are filtered with the edge mask reported by subgraph.
    """
    edge_index,_,edge_mask = subgraph(mask,edge_index,relabel_nodes=True,return_edge_mask=True)

    selections = selections[edge_mask]

    # Compare against None explicitly: the truth value of a multi-element
    # tensor is ambiguous and raises a RuntimeError.
    if interps is not None:
        interps = interps[edge_mask]
        return edge_index, selections, interps

    return edge_index, selections
def interpolateSelections(edge_index,directions,vectorList=None):
    """Split each edge between its two closest selections by angle.

    Every edge keeps its best-matching selection with weight 1 - t and, when
    t exceeds 1e-2, gains a duplicate edge on the adjacent selection with
    weight t, where t = angle_to_best / (angle_to_best + angle_to_second_best).

    Args:
        edge_index: (2, E) edge list; edges with equal endpoints get selection 0.
        directions: (E, 2) floating-point direction of every edge.
        vectorList: Optional 2x8 matrix of unit directions for selections 1..8.
            Defaults to the y-up ordering
                4 3 2
                5 0 1
                6 7 8

    Returns:
        (edge_index, selections, interps) with the original edges first,
        followed by the counter-clockwise and then the clockwise duplicates.
    """
    if vectorList is None:
        s = sqrt(2)/2
        vectorList = torch.tensor([[1,0],[s,s],[0,1],[-s,s],[-1,0],[-s,-s],[0,-1],[s,-s]],dtype=directions.dtype,device=directions.device).transpose(1,0)
    else:
        # matmul needs both operands on the same device with the same dtype
        vectorList = vectorList.to(device=directions.device,dtype=directions.dtype)

    # Normalize directions for simplicity of calculations. Zero-length
    # directions become NaN here and are neutralised by nan_to_num below.
    unit_dirs = directions/torch.linalg.norm(directions,dim=1,keepdim=True)

    values = torch.matmul(unit_dirs,vectorList)
    best = torch.argmax(values,dim=1,keepdim=True)
    best_val = torch.take_along_dim(values,best,dim=1)[:,0]

    # Look at both neighbours of the best selection to see which is closer
    lower_val = torch.take_along_dim(values,(best-1) % 8,dim=1)
    upper_val = torch.take_along_dim(values,(best+1) % 8,dim=1)
    comp_vals = torch.cat((lower_val,upper_val),dim=1)
    second_best_vals = torch.amax(comp_vals,dim=1)
    second_best = torch.argmax(comp_vals,dim=1) # 0 = lower neighbour, 1 = upper neighbour

    # Find the interpolation value (in terms of angles); clamp rounding
    # overshoot so arccos stays defined
    angle_best = torch.arccos(torch.clamp(best_val,max=1.0))
    angle_second_best = torch.arccos(second_best_vals)
    angle_vals = angle_best/(angle_second_best + angle_best)

    # Use negative values for clockwise selections
    angle_vals = torch.where(second_best == 0,-angle_vals,angle_vals)

    # Handle computation problems at the poles
    angle_vals = torch.nan_to_num(angle_vals)

    # Make selections
    selections = best[:,0] + 1

    # Same cell selection
    is_self = edge_index[0] == edge_index[1]
    selections[is_self] = 0
    angle_vals[is_self] = 0

    # Weight that stays on the best selection
    interps = 1 - torch.abs(angle_vals)

    # Add new edges for the neighbouring selections
    pos_interp_locs = torch.where(angle_vals > 1e-2)[0]
    pos_interps = angle_vals[pos_interp_locs]
    pos_edges = edge_index[:,pos_interp_locs]
    pos_selections = selections[pos_interp_locs] + 1
    pos_selections[pos_selections > 8] = 1 # Account for wrap around

    neg_interp_locs = torch.where(angle_vals < -1e-2)[0]
    neg_interps = torch.abs(angle_vals[neg_interp_locs])
    neg_edges = edge_index[:,neg_interp_locs]
    neg_selections = selections[neg_interp_locs] - 1
    neg_selections[neg_selections < 1] = 8 # Account for wrap around

    edge_index = torch.cat((edge_index,pos_edges,neg_edges),dim=1)
    selections = torch.cat((selections,pos_selections,neg_selections),dim=0)
    interps = torch.cat((interps,pos_interps,neg_interps),dim=0)

    return edge_index,selections,interps
def interpolateSelections_barycentric(edge_index,directions,d,vectorList=None):
    """Split each edge between the centre, an axis and a corner selection.

    Each non-self edge is expressed in (u, v) coordinates of the triangle
    spanned by the centre (0,0), an axis neighbour (1,0) and a corner
    neighbour (1,1):

          /|
         / |v
        /__|
         u

    and receives the barycentric weights I0 = 1-u (centre), I1 = u-v (axis
    selections 1,3,5,7) and I2 = v (corner selections 2,4,6,8).

    Args:
        edge_index: (2, E) edge list.
        directions: (E, 2) floating-point direction of every edge.
        d: Scale dividing directions before the (u, v) computation
            (presumably the sample spacing, so that a neighbour one step away
            has u = 1 -- confirm against callers).
        vectorList: Optional 2x8 matrix of unit directions for selections 1..8.
            The default is the table with the y axis pointing down, i.e.
            selection 2 is [sqrt(2)/2, -sqrt(2)/2].

    Returns:
        (edge_index, selections, interps) concatenated in this order: best
        selection per edge, centre copies, upper-neighbour copies,
        lower-neighbour copies, and finally the original self edges
        (selection 0, weight 1).
    """
    if vectorList is None:
        s = sqrt(2)/2
        vectorList = torch.tensor([[1,0],[s,-s],[0,-1],[-s,-s],[-1,0],[-s,s],[0,1],[s,s]],dtype=directions.dtype,device=directions.device).transpose(1,0)
    else:
        # matmul needs both operands on the same device with the same dtype
        vectorList = vectorList.to(device=directions.device,dtype=directions.dtype)

    # Preprune central (self) edges and reappend them at the end
    is_self = edge_index[0] == edge_index[1]
    same_edges = edge_index[:,is_self]
    edge_index = edge_index[:,~is_self]
    directions = directions[~is_self]

    # Normalize directions for simplicity of calculations
    unit_directions = directions/torch.linalg.norm(directions,dim=1,keepdim=True)

    values = torch.matmul(unit_directions,vectorList)
    best = torch.argmax(values,dim=1,keepdim=True)

    # Look at both neighbours of the best selection to see which is closer
    lower_val = torch.take_along_dim(values,(best-1) % 8,dim=1)
    upper_val = torch.take_along_dim(values,(best+1) % 8,dim=1)
    second_best = torch.argmax(torch.cat((lower_val,upper_val),dim=1),dim=1) # 0 = lower, 1 = upper

    # Convert into uv coordinates for the barycentric interpolation
    scaled_directions = torch.abs(directions/d)
    u = torch.amax(scaled_directions,dim=1)
    v = torch.amin(scaled_directions,dim=1)

    # Force coordinates to be within the triangle. u is already expressed in
    # units of d, so the boundary is u = 1: points beyond it are pulled back
    # along their ray (v scaled by 1/u, u set to 1).
    v = v/torch.clamp(u,min=1.0)
    u = torch.clamp(u,max=1.0)

    # Precalculated barycentric values from linear matrix solve
    I0 = 1 - u
    I1 = u - v
    I2 = v

    # Make first selections and proper interps. torch.where builds a new
    # tensor, so I1 stays intact for the neighbour weights below.
    selections = best[:,0] + 1
    interps = torch.where(selections % 2 == 0,I2,I1) # Corners get different weights

    # Make new edges for the central selections
    central_edges = torch.clone(edge_index)
    central_selections = torch.zeros_like(selections)
    central_interps = I0

    # Make new edges for the last selection (the closer neighbour of the best one)
    pos_locs = torch.where(second_best == 1)[0]
    pos_edges = edge_index[:,pos_locs]
    pos_selections = selections[pos_locs] + 1
    pos_selections[pos_selections > 8] = 1 # Account for wrap around
    pos_interps = torch.where(pos_selections % 2 == 0,I2[pos_locs],I1[pos_locs])

    neg_locs = torch.where(second_best == 0)[0]
    neg_edges = edge_index[:,neg_locs]
    neg_selections = selections[neg_locs] - 1
    neg_selections[neg_selections < 1] = 8 # Account for wrap around
    neg_interps = torch.where(neg_selections % 2 == 0,I2[neg_locs],I1[neg_locs])

    # Account for the previously pruned same node edges; create them on the
    # same device/dtype as the tensors they are concatenated with
    same_selections = torch.zeros(same_edges.shape[1],dtype=selections.dtype,device=edge_index.device)
    same_interps = torch.ones(same_edges.shape[1],dtype=interps.dtype,device=interps.device)

    # Combine
    edge_index = torch.cat((edge_index,central_edges,pos_edges,neg_edges,same_edges),dim=1)
    selections = torch.cat((selections,central_selections,pos_selections,neg_selections,same_selections),dim=0)
    interps = torch.cat((interps,central_interps,pos_interps,neg_interps,same_interps),dim=0)

    return edge_index,selections,interps
def normalizeEdges(edge_index,selections,interps=None,kernel_norm=False):
    '''Given an edge_index and selections, normalize the edges for each node so that
    aggregation of edges with interps = 1. If interps is given, use a weighted average.
    if kernel_norm = True, account for missing selections by increasing weight on other selections.

    Args:
        edge_index: (2, E) edge list; weights are grouped by edge_index[0].
        selections: (E,) integer selection of every edge.
        interps: Optional (E,) edge weights; defaults to all ones.
        kernel_norm: If True, scale each weight by S / (total weight of its
            node), with S = max(selections) + 1. Otherwise divide each weight
            by the total weight of its (node, selection) pair.

    Returns:
        (E,) tensor of normalized weights. An empty graph returns the
        (empty) weights unchanged.
    '''
    if interps is None:
        interps = torch.ones(len(selections),dtype=torch.float,device=edge_index.device)
    elif not interps.is_floating_point():
        # Integer weights cannot hold the normalized fractions
        interps = interps.float()

    # torch.max is undefined on empty tensors; nothing to normalize anyway
    if edge_index.numel() == 0:
        return interps

    N = int(torch.max(edge_index)) + 1
    S = int(torch.max(selections)) + 1

    # Accumulate in the dtype of the weights so index_put_ accepts them
    total_weight = torch.zeros((N,S),dtype=interps.dtype,device=edge_index.device)

    # Aggregate all edges to determine normalizations per selection.
    # accumulate=True sums contributions of repeated (node, selection) pairs.
    nodes = edge_index[0]
    total_weight.index_put_((nodes,selections),interps,accumulate=True)

    # Reassign interps accordingly
    if kernel_norm:
        row_totals = torch.sum(total_weight,dim=1)
        return interps * S/row_totals[nodes]

    # Avoid divide by zero error
    norms = torch.clamp(total_weight[nodes,selections],min=1e-6)
    return interps/norms
def simplifyGraph(edge_index,selections,edge_lengths):
    """Keep only the shortest edge for every (start node, selection) pair.

    Args:
        edge_index: (2, E) edge list; edges are grouped by edge_index[0].
        selections: (E,) integer selection of every edge.
        edge_lengths: Length of every edge (E values).

    Returns:
        (edge_index, selections) restricted to the kept edges, in their
        original order. When several edges of a group share the minimum
        length, the one that appears first is kept.
    """
    num_edges = edge_index.shape[1]
    if num_edges == 0:
        return edge_index, selections

    lengths = edge_lengths.reshape(-1).to(edge_index.device)

    # One integer key per (start node, selection) group
    num_selections = int(torch.amax(selections)) + 1
    group = edge_index[0] * num_selections + selections

    # Two stable sorts order the edges by group, then by length, then by
    # original position, so the first edge of each group is the one to keep.
    # Lengths are compared in their own dtype (no integer truncation).
    by_length = torch.sort(lengths,stable=True).indices
    by_group = torch.sort(group[by_length],stable=True).indices
    order = by_length[by_group]
    sorted_group = group[order]

    first_in_group = torch.ones(num_edges,dtype=torch.bool,device=edge_index.device)
    first_in_group[1:] = sorted_group[1:] != sorted_group[:-1]

    keep_edges = torch.zeros(num_edges,dtype=torch.bool,device=edge_index.device)
    keep_edges[order[first_in_group]] = True

    return edge_index[:,keep_edges], selections[keep_edges]