from math import sqrt

import torch
import torch_geometric as tg
from torch_geometric.nn import knn_graph, radius_graph
from torch_geometric.utils import subgraph

import utils


# Selection layout used throughout this module (0 is the node itself):
#   4 3 2
#   5 0 1
#   6 7 8


def _float_dtype(t):
    """Return ``t``'s dtype if it is floating point, otherwise ``torch.float``."""
    return t.dtype if t.is_floating_point() else torch.float


def _selection_vectors(y_down=False, device=None, dtype=torch.float):
    """Return the (2, 8) matrix whose column k is the unit vector of selection k+1.

    With ``y_down=False`` selection 3 is ``(0, 1)``. With ``y_down=True`` the y
    component is negated, so selection 3 is ``(0, -1)`` (y grows downwards).
    """
    diag = sqrt(2) / 2
    vectors = [[1, 0], [diag, diag], [0, 1], [-diag, diag],
               [-1, 0], [-diag, -diag], [0, -1], [diag, -diag]]
    if y_down:
        vectors = [[vx, -vy] for vx, vy in vectors]
    return torch.tensor(vectors, dtype=dtype, device=device).transpose(1, 0)


def _neighbor_selection(selections, step):
    """Rotate selections 1..8 by ``step`` positions, wrapping 8 -> 1 and 1 -> 8."""
    # Tensor ``%`` follows the sign of the divisor, so negative steps wrap correctly.
    return (selections - 1 + step) % 8 + 1


def _corner_weight(selections, cardinal, diagonal):
    """Pick ``diagonal`` for even selections and ``cardinal`` for odd ones."""
    return torch.where(selections % 2 == 0, diagonal, cardinal)


def getImPos(rows, cols, start_row=0, start_col=0):
    """Return the (rows*cols, 2) integer grid of (row, col) coordinates.

    Coordinates are listed in row-major order, starting at
    ``(start_row, start_col)``.
    """
    row_space = torch.arange(start_row, rows + start_row)
    col_space = torch.arange(start_col, cols + start_col)
    # 'xy' indexing gives (rows, cols)-shaped grids: row_image[i, j] = row i.
    col_image, row_image = torch.meshgrid(col_space, row_space, indexing='xy')
    return torch.reshape(torch.stack((row_image, col_image), dim=-1), (rows * cols, 2))


def convertImPos(im_pos, flip_y=True):
    """Convert (row, col) image positions into float (x, y) positions.

    Args:
        im_pos: (N, >=2) tensor whose columns 0 and 1 are row and column.
        flip_y: if True, mirror y as ``max(y) - y`` so that y increases
            upwards (the mathematical convention assumed by edges2Selections
            with ``y_down=False``).

    Returns:
        A new float tensor; ``im_pos`` itself is never modified.
    """
    # Cast to float for clustering based methods. ``.float()`` returns
    # ``im_pos`` itself when it is already float32, so clone to avoid
    # mutating the caller's tensor below.
    pos2D = im_pos.float().clone()
    # Switch rows,cols to x,y
    pos2D[:, [1, 0]] = pos2D[:, [0, 1]]
    if flip_y and pos2D.shape[0] > 0:
        pos2D[:, 1] = torch.amax(pos2D[:, 1]) - pos2D[:, 1]
    return pos2D


def grid2Edges(locs):
    """Connect each grid point to its 8 neighbours and itself.

    Assumes ``locs`` are spaced 1 apart. The radius 1.44 is just above
    sqrt(2) ~= 1.414, so diagonal neighbours are included.
    """
    return radius_graph(locs, 1.44, loop=True)


def radius2Edges(locs, r=1.0):
    """Return the radius graph (with self loops) of ``locs`` for radius ``r``."""
    return radius_graph(locs, r, loop=True)


def knn2Edges(locs, knn=9):
    """Return the k-nearest-neighbour graph (with self loops) of ``locs``."""
    return knn_graph(locs, knn, loop=True)


def surface2Edges(pos3D, normals, up_vector=None, k_neighbors=9):
    """Build a kNN graph on a 3D surface and project edge offsets to 2D.

    Edges whose endpoint normals have a non-positive dot product are removed.
    Each remaining edge offset ``pos3D[edge_index[1]] - pos3D[edge_index[0]]``
    is expressed in a tangent frame at ``edge_index[0]``:
    ``x = up x n`` and ``y = n x x``, both normalised. The component along
    the normal is dropped.

    Args:
        pos3D: (N, 3) point positions.
        normals: (N, 3) per-point normals.
        up_vector: (1, 3) reference "up" direction. Defaults to +y.
        k_neighbors: neighbours per node in the kNN graph.

    Returns:
        ``(edge_index, directions)`` with ``directions`` of shape (E, 2).
    """
    if up_vector is None:
        up_vector = torch.tensor([[0.0, 1.0, 0.0]], dtype=normals.dtype, device=pos3D.device)

    edge_index = knn_graph(pos3D, k_neighbors, loop=True)

    # Cull neighbours whose normals point away from each other.
    facing = torch.sum(normals[edge_index[1]] * normals[edge_index[0]], dim=1)
    edge_index = edge_index[:, torch.where(facing > 0)[0]]

    # Tangent frame at each edge's first node.
    # NOTE(review): a normal parallel to up_vector yields a zero x_dir and NaNs.
    norms = normals[edge_index[0]]
    # utils.cross is used because torch.cross doesn't broadcast properly in
    # some versions of torch.
    x_dir = utils.cross(up_vector, norms)
    x_dir = x_dir / torch.linalg.norm(x_dir, dim=1, keepdim=True)
    y_dir = utils.cross(norms, x_dir)
    y_dir = y_dir / torch.linalg.norm(y_dir, dim=1, keepdim=True)

    offsets = pos3D[edge_index[1]] - pos3D[edge_index[0]]
    # Rotate into the frame, keeping only the tangent-plane (x, y) components.
    directions = torch.stack(
        (torch.sum(offsets * x_dir, dim=1), torch.sum(offsets * y_dir, dim=1)), dim=1)
    return edge_index, directions


def edges2Selections(edge_index, directions, interpolated=True, bary_d=None, y_down=False):
    """Assign each edge a selection (0-8) from its 2D direction.

    Selection layout (0 is a same-cell edge)::

        4 3 2
        5 0 1
        6 7 8

    Args:
        edge_index: (2, E) edge list.
        directions: (E, 2) offset of each edge.
        interpolated: if True, split every edge between its two nearest
            selections and normalise the weights with normalizeEdges.
        bary_d: grid spacing for barycentric interpolation. If None, use
            angular interpolation (only used when ``interpolated``).
        y_down: if True, positive y maps to selection 7 instead of 3.

    Returns:
        ``(edge_index, selections, interps)`` when ``interpolated``.
        Otherwise only ``selections``: ``argmax + 1`` over the 8 directions,
        with 0 for all-zero offsets.
    """
    # Build the table on the same device as the data so matmul works on GPU.
    vectorList = _selection_vectors(y_down=y_down, device=directions.device,
                                    dtype=_float_dtype(directions))
    if interpolated:
        if bary_d is None:
            edge_index, selections, interps = interpolateSelections(edge_index, directions, vectorList)
        else:
            edge_index, selections, interps = interpolateSelections_barycentric(
                edge_index, directions, bary_d, vectorList)
        interps = normalizeEdges(edge_index, selections, interps)
        return edge_index, selections, interps

    selections = torch.argmax(torch.matmul(directions, vectorList), dim=1) + 1
    selections[torch.sum(torch.abs(directions), dim=1) == 0] = 0  # Same cell selection
    return selections


def makeEdges(prev_sources, prev_targets, prev_selections, sources, targets, selection, reverse=True):
    """Append ``sources[i] -> targets[i]`` edges labelled ``selection``.

    The three ``prev_*`` Python lists are extended in place with plain ints
    and also returned. If ``reverse`` is True, the edges ``targets -> sources``
    are appended too, labelled ``utils.reverse_selection(selection)``
    (presumably the opposite direction; see utils).
    """
    source_list = sources.flatten().tolist()
    target_list = targets.flatten().tolist()
    prev_sources += source_list
    prev_targets += target_list
    prev_selections += len(source_list) * [selection]
    if reverse:
        # Use the converted lists so the reverse edges are ints, not 0-d tensors.
        prev_sources += target_list
        prev_targets += source_list
        prev_selections += len(source_list) * [utils.reverse_selection(selection)]
    return prev_sources, prev_targets, prev_selections


def maskNodes(mask, x):
    """Return the rows of ``x`` at the nonzero positions of ``mask``."""
    return x[torch.where(mask)]


def maskPoints(mask, x, y):
    """Return indices of points whose four surrounding pixels are all set.

    For each point, the pixels ``(floor(y), floor(x))`` and their +1
    neighbours (clipped to the image) are looked up as ``mask[row, col]``.
    The point is kept only if all four are truthy.
    """
    mask = torch.squeeze(mask)
    max_row, max_col = mask.shape[0] - 1, mask.shape[1] - 1
    col0 = torch.floor(x).long()
    row0 = torch.floor(y).long()
    col1 = torch.clip(col0 + 1, 0, max_col)
    row1 = torch.clip(row0 + 1, 0, max_row)
    col0 = torch.clip(col0, 0, max_col)
    row0 = torch.clip(row0, 0, max_row)
    inside = torch.logical_and(
        torch.logical_and(mask[row0, col0], mask[row1, col0]),
        torch.logical_and(mask[row0, col1], mask[row1, col1]))
    return torch.where(inside)[0]


def maskGraph(mask, edge_index, selections, interps=None):
    """Restrict a graph to the nodes in ``mask`` and relabel them.

    Uses ``torch_geometric.utils.subgraph`` with ``relabel_nodes=True``.

    Returns:
        ``(edge_index, selections)``, or ``(edge_index, selections, interps)``
        when ``interps`` is given.
    """
    edge_index, _, edge_mask = subgraph(mask, edge_index, relabel_nodes=True, return_edge_mask=True)
    selections = selections[edge_mask]
    # Compare with None: truth-testing a multi-element tensor raises.
    if interps is None:
        return edge_index, selections
    return edge_index, selections, interps[edge_mask]


def interpolateSelections(edge_index, directions, vectorList=None):
    """Split each edge between its two nearest selections by angle.

    For each edge, the best-aligned selection receives weight ``1 - |a|``,
    where ``a = angle_best / (angle_best + angle_second)``. When
    ``|a| > 1e-2``, a duplicate edge is appended for the adjacent selection
    (counter-clockwise if ``a > 0``, clockwise if ``a < 0``) with weight
    ``|a|``. Self-loops get selection 0 and weight 1.

    Args:
        edge_index: (2, E) edge list.
        directions: (E, 2) edge offsets.
        vectorList: (2, 8) selection directions. Defaults to the y-up table.

    Returns:
        ``(edge_index, selections, interps)`` including the appended edges.
    """
    if vectorList is None:
        vectorList = _selection_vectors(y_down=False, device=directions.device,
                                        dtype=_float_dtype(directions))

    unit = directions / torch.linalg.norm(directions, dim=1, keepdim=True)
    values = torch.matmul(unit, vectorList)
    best = torch.argmax(values, dim=1, keepdim=True)
    best_val = torch.take_along_dim(values, best, dim=1)[:, 0]

    # Of the two neighbouring directions, find the closer one (0: lower, 1: upper).
    lower_val = torch.take_along_dim(values, (best - 1) % 8, dim=1)
    upper_val = torch.take_along_dim(values, (best + 1) % 8, dim=1)
    comp_vals = torch.cat((lower_val, upper_val), dim=1)
    second_best_vals = torch.amax(comp_vals, dim=1)
    second_best = torch.argmax(comp_vals, dim=1)

    # Angular interpolation fraction. Clamp so rounding can't push arccos out of range.
    angle_best = torch.arccos(torch.clamp(best_val, max=1))
    angle_second_best = torch.arccos(second_best_vals)
    angle_vals = angle_best / (angle_second_best + angle_best)
    # Negative values mark interpolation towards the clockwise neighbour.
    angle_vals = torch.where(second_best == 0, -angle_vals, angle_vals)
    # Zero-length directions produce NaNs; treat them as no interpolation.
    angle_vals = torch.nan_to_num(angle_vals)

    selections = best[:, 0] + 1
    is_self = edge_index[0] == edge_index[1]
    selections[is_self] = 0
    angle_vals[is_self] = 0
    interps = 1 - torch.abs(angle_vals)

    pos_locs = torch.where(angle_vals > 1e-2)[0]
    neg_locs = torch.where(angle_vals < -1e-2)[0]
    edge_index = torch.cat((edge_index, edge_index[:, pos_locs], edge_index[:, neg_locs]), dim=1)
    selections = torch.cat((selections,
                            _neighbor_selection(selections[pos_locs], 1),
                            _neighbor_selection(selections[neg_locs], -1)), dim=0)
    interps = torch.cat((interps, angle_vals[pos_locs], torch.abs(angle_vals[neg_locs])), dim=0)
    return edge_index, selections, interps


def interpolateSelections_barycentric(edge_index, directions, d, vectorList=None):
    """Split each edge over the centre and two nearest selections barycentrically.

    The offset is scaled by ``d`` and folded into the triangle
    (centre, cardinal, diagonal) as ``u = max(|dx|, |dy|) / d`` and
    ``v = min(|dx|, |dy|) / d``. Points with ``u > 1`` are projected onto the
    triangle's outer edge. The barycentric weights are:

    - ``I0 = 1 - u`` for the centre (selection 0),
    - ``I1 = u - v`` for the cardinal (odd) selection,
    - ``I2 = v`` for the diagonal (even) selection.

    Args:
        edge_index: (2, E) edge list.
        directions: (E, 2) edge offsets.
        d: spacing used to scale ``directions``.
        vectorList: (2, 8) selection directions. Defaults to the y-down table.

    Returns:
        ``(edge_index, selections, interps)``. Rows are the best selection,
        the centre copy, the neighbouring selection, and finally the
        original self-loops (selection 0, weight 1).
    """
    if vectorList is None:
        vectorList = _selection_vectors(y_down=True, device=directions.device,
                                        dtype=_float_dtype(directions))

    # Set self-loops aside; they are re-appended at the end.
    is_self = edge_index[0] == edge_index[1]
    same_edges = edge_index[:, is_self]
    edge_index = edge_index[:, ~is_self]
    directions = directions[~is_self]

    unit_directions = directions / torch.linalg.norm(directions, dim=1, keepdim=True)
    values = torch.matmul(unit_directions, vectorList)
    best = torch.argmax(values, dim=1, keepdim=True)
    lower_val = torch.take_along_dim(values, (best - 1) % 8, dim=1)
    upper_val = torch.take_along_dim(values, (best + 1) % 8, dim=1)
    second_best = torch.argmax(torch.cat((lower_val, upper_val), dim=1), dim=1)

    # uv coordinates in the triangle
    #    /|
    #   / |v
    #  /__|
    #    u
    scaled_directions = torch.abs(directions / d)
    u = torch.amax(scaled_directions, dim=1)
    v = torch.amin(scaled_directions, dim=1)
    # u is already divided by d, so the triangle boundary is u == 1.
    outside = u > 1
    v[outside] = v[outside] / u[outside]
    u[outside] = 1.0

    I0 = 1 - u
    I1 = u - v
    I2 = v

    # Best selection. Build a new tensor rather than aliasing I1, so I1 stays
    # intact for the neighbouring-selection weights below.
    selections = best[:, 0] + 1
    interps = _corner_weight(selections, I1, I2)

    # Centre copy of every edge.
    central_edges = edge_index.clone()
    central_selections = torch.zeros_like(selections)
    central_interps = I0

    # Neighbouring selection (counter-clockwise if second_best == 1, else clockwise).
    pos_locs = torch.where(second_best == 1)[0]
    pos_edges = edge_index[:, pos_locs]
    pos_selections = _neighbor_selection(selections[pos_locs], 1)
    pos_interps = _corner_weight(pos_selections, I1[pos_locs], I2[pos_locs])

    neg_locs = torch.where(second_best == 0)[0]
    neg_edges = edge_index[:, neg_locs]
    neg_selections = _neighbor_selection(selections[neg_locs], -1)
    neg_interps = _corner_weight(neg_selections, I1[neg_locs], I2[neg_locs])

    # Previously pruned self-loops, created on the same device as the rest.
    num_same = same_edges.shape[1]
    same_selections = torch.zeros(num_same, dtype=selections.dtype, device=edge_index.device)
    same_interps = torch.ones(num_same, dtype=interps.dtype, device=edge_index.device)

    edge_index = torch.cat((edge_index, central_edges, pos_edges, neg_edges, same_edges), dim=1)
    selections = torch.cat((selections, central_selections, pos_selections,
                            neg_selections, same_selections), dim=0)
    interps = torch.cat((interps, central_interps, pos_interps, neg_interps, same_interps), dim=0)
    return edge_index, selections, interps


def normalizeEdges(edge_index, selections, interps=None, kernel_norm=False):
    '''Normalize edge weights per source node (``edge_index[0]``).

    By default, each weight (``interps``, or 1 when ``interps`` is None) is
    divided by the total weight of all edges with the same source node and
    selection. Each (node, selection) group therefore sums to 1. Totals below
    1e-6 are clamped to avoid division by zero.

    If ``kernel_norm`` is True, weights are instead scaled by
    ``S / total weight of the node over all selections``, where
    ``S = max(selections) + 1``. This increases the weight on the selections
    that are present when others are missing.

    Returns the new weights (an empty input is returned unchanged).
    '''
    device = edge_index.device
    if interps is None:
        interps = torch.ones(len(selections), dtype=torch.float, device=device)
    elif not interps.is_floating_point():
        interps = interps.float()
    if selections.numel() == 0:
        return interps

    num_nodes = int(torch.max(edge_index)) + 1
    num_selections = int(torch.max(selections)) + 1
    # Accumulate in the weights' own dtype; index_put_ requires matching dtypes.
    total_weight = torch.zeros((num_nodes, num_selections), dtype=interps.dtype, device=device)
    nodes = edge_index[0]
    total_weight.index_put_((nodes, selections), interps, accumulate=True)

    if kernel_norm:
        row_totals = torch.sum(total_weight, dim=1)
        return interps * num_selections / row_totals[nodes]
    norms = torch.clamp(total_weight[nodes, selections], min=1e-6)  # Avoid divide by zero
    return interps / norms


def simplifyGraph(edge_index, selections, edge_lengths):
    """Keep only the shortest edge for each (source node, selection) pair.

    Ties keep the earliest edge, and NaN lengths are never kept. Surviving
    edges keep their original relative order.

    Args:
        edge_index: (2, E) edge list.
        selections: (E,) integer selection of each edge.
        edge_lengths: E lengths (shape (E,) or (E, 1)).

    Returns:
        ``(edge_index, selections)`` restricted to the kept edges.
    """
    num_edges = edge_index.shape[1]
    if num_edges == 0:
        return edge_index, selections

    device = edge_index.device
    # Keep lengths in their own (float) dtype; an integer buffer truncates
    # them and picks the wrong edge.
    lengths = torch.as_tensor(edge_lengths, device=device).reshape(-1)
    num_selections = int(torch.amax(selections)) + 1
    group = edge_index[0] * num_selections + selections

    # Two stable sorts: by length, then by group. Within each group, edges end
    # up ordered by ascending length, with ties in original index order.
    _, by_length = torch.sort(lengths, stable=True)
    _, by_group = torch.sort(group[by_length], stable=True)
    order = by_length[by_group]
    sorted_group = group[order]

    is_first = torch.ones(num_edges, dtype=torch.bool, device=device)
    is_first[1:] = sorted_group[1:] != sorted_group[:-1]

    keep_edges = torch.zeros(num_edges, dtype=torch.bool, device=device)
    keep_edges[order[is_first]] = True
    # NaN sorts last, so it is first only when a group has nothing else; drop it.
    keep_edges &= ~torch.isnan(lengths)

    return edge_index[:, keep_edges], selections[keep_edges]