# FastSplatStyler / clusters.py
# Uploaded by: incrl
# Commit: 5b557cf "Initial Upload (attempt 2)" (verified)
import torch
from torch_scatter import scatter
from torch_geometric.nn.pool.consecutive import consecutive_cluster
from torch_geometric.utils import add_self_loops, add_remaining_self_loops, remove_self_loops
from torch_geometric.nn import fps, knn
from torch_sparse import coalesce
import graph_helpers as gh
import sphere_helpers as sh
import mesh_helpers as mh
import math
from math import pi,sqrt
from warnings import warn
def makeImageClusters(pos2D,Nx,Ny,edge_index,selections,depth=1,device='cpu',stride=2):
    """Build a multi-level hierarchy of grid-pooled image graphs.

    Every level floor-divides the grid size (Nx, Ny) by ``stride``, bins the
    current 2D positions into that grid with ``getGrid``, pools them with
    ``gridCluster`` and merges edges/selections with ``selectionAverage``.

    Returns:
        (clusters, edge_indexes, selections_list). ``clusters`` holds one
        assignment tensor per level (``depth`` entries); the other two lists
        start with a copy of the input graph, so they have ``depth + 1`` entries.
        Every returned tensor is a clone moved to ``device``.
    """
    def snapshot(tensor):
        # Detach the stored copy from later reassignment and place it on `device`.
        return torch.clone(tensor).to(device)

    clusters = []
    edge_indexes = [snapshot(edge_index)]
    selections_list = [snapshot(selections)]

    for _level in range(depth):
        Nx, Ny = Nx // stride, Ny // stride
        cell_x, cell_y = getGrid(pos2D, Nx, Ny)
        cluster, pos2D = gridCluster(pos2D, cell_x, cell_y, Nx)
        edge_index, selections = selectionAverage(cluster, edge_index, selections)

        clusters.append(snapshot(cluster))
        edge_indexes.append(snapshot(edge_index))
        selections_list.append(snapshot(selections))

    return clusters, edge_indexes, selections_list
def makeSphereClusters(pos3D,edge_index,selections,interps,rows,cols,cluster_method="layering",stride=2,bary_d=None,depth=1,device='cpu',ratio=None):
    """Build a multi-level hierarchy of coarsened graphs for points on a sphere.

    At each level the resolution (``rows``, ``cols``) is floor-divided by
    ``stride``. A new set of centroids is produced according to
    ``cluster_method``, every current point is assigned to its nearest
    centroid (``knn`` with k=1), and the points of each cluster are averaged.
    The surface graph is then rebuilt from the pooled points with
    ``gh.surface2Edges`` / ``gh.edges2Selections``, using the normalized
    positions as normals.

    Args:
        pos3D: point positions; assumed to be (N, 3) points on a sphere
            centred at the origin (the normals are pos3D / |pos3D|) - TODO confirm.
        edge_index, selections, interps: graph of the finest level; cloned
            into the first entry of the returned lists.
        rows, cols: resolution arguments forwarded to the ``sh.sampleSphere_*``
            samplers after being divided by ``stride``.
        cluster_method: "equirec", "layering", "spiral", "icosphere", "random",
            "random_nodes" or "fps".
        stride: per-level resolution divisor.
        bary_d: optional value forwarded to ``gh.edges2Selections``; multiplied
            by ``stride`` at every level.
        depth: number of coarsening levels.
        device: device the returned tensors are moved to.
        ratio: fraction of the current points kept as centroids by the
            "random_nodes" and "fps" methods. Defaults to ``1 / stride**2``,
            the reduction a rows x cols grid undergoes when both are divided
            by ``stride``.

    Returns:
        (clusters, edge_indexes, selections_list, interps_list); ``clusters``
        has ``depth`` entries, the other lists ``depth + 1``.

    Raises:
        ValueError: if ``cluster_method`` is not recognised.
    """
    clusters = []
    edge_indexes = [torch.clone(edge_index).to(device)]
    selections_list = [torch.clone(selections).to(device)]
    interps_list = [torch.clone(interps).to(device)]
    for _ in range(depth):
        rows = rows//stride
        cols = cols//stride
        if bary_d is not None:
            bary_d = bary_d*stride
        # Fraction of points kept by the point-subsampling methods below.
        keep_ratio = ratio if ratio is not None else 1.0/(stride*stride)
        if cluster_method == "equirec":
            centroids, _ = sh.sampleSphere_Equirec(rows,cols)
        elif cluster_method == "layering":
            centroids, _ = sh.sampleSphere_Layering(rows)
        elif cluster_method == "spiral":
            centroids, _ = sh.sampleSphere_Spiral(rows,cols)
        elif cluster_method == "icosphere":
            centroids, _ = sh.sampleSphere_Icosphere(rows)
        elif cluster_method == "random":
            centroids, _ = sh.sampleSphere_Random(rows,cols)
        elif cluster_method == "random_nodes":
            num_points = len(pos3D)
            # multinomial without replacement needs 1 <= N <= num_points.
            N = min(num_points, max(1, int(num_points*keep_ratio)))
            index = torch.multinomial(torch.ones(num_points),N) # close equivalent to np.random.choice
            centroids = pos3D[index]
        elif cluster_method == "fps":
            # Farthest Point Sampling, as used in PointNet++
            index = fps(pos3D, ratio=keep_ratio)
            centroids = pos3D[index]
        else:
            raise ValueError("Sphere cluster_method unknown")
        # Find the closest centroid to each current point
        cluster = knn(centroids,pos3D,1)[1]
        # Relabel to 0..K-1 so centroids that received no point leave no gap.
        cluster, _ = consecutive_cluster(cluster)
        pos3D = scatter(pos3D, cluster, dim=0, reduce='mean')
        # Regenerate surface graph
        normals = pos3D/torch.linalg.norm(pos3D,dim=1,keepdim=True) # Make sure normals are unit vectors
        edge_index,directions = gh.surface2Edges(pos3D,normals)
        edge_index,selections,interps = gh.edges2Selections(edge_index,directions,interpolated=True,bary_d=bary_d)
        clusters.append(torch.clone(cluster).to(device))
        edge_indexes.append(torch.clone(edge_index).to(device))
        selections_list.append(torch.clone(selections).to(device))
        interps_list.append(torch.clone(interps).to(device))
    return clusters, edge_indexes, selections_list, interps_list
def makeSurfaceClusters(pos3D,normals,edge_index,selections,interps,cluster_method="random",ratio=.25,up_vector=None,depth=1,device='cpu'):
    """Build a multi-level hierarchy of coarsened graphs for a point cloud.

    At each level, centroids are chosen among the current points
    ("random": uniform sampling without replacement; "fps": farthest point
    sampling). Every point is assigned to its nearest centroid, positions and
    normals are averaged per cluster, the normals are re-normalized, and the
    surface graph is rebuilt with ``gh.surface2Edges`` (k_neighbors=16) and
    ``gh.edges2Selections``.

    Args:
        pos3D, normals: per-point positions and normals; assumed (N, 3) -
            TODO confirm.
        edge_index, selections, interps: graph of the finest level; cloned
            into the first entry of the returned lists.
        cluster_method: "random" or "fps".
        ratio: fraction of the current points kept as centroids per level.
        up_vector: forwarded to ``gh.surface2Edges``.
        depth: number of coarsening levels.
        device: device the returned tensors are moved to.

    Returns:
        (clusters, edge_indexes, selections_list, interps_list); ``clusters``
        has ``depth`` entries, the other lists ``depth + 1``.

    Raises:
        ValueError: if ``cluster_method`` is not recognised.
    """
    clusters = []
    edge_indexes = [torch.clone(edge_index).to(device)]
    selections_list = [torch.clone(selections).to(device)]
    interps_list = [torch.clone(interps).to(device)]
    for _ in range(depth):
        num_points = len(pos3D)
        if cluster_method == "random":
            # Desired number of clusters in the next level. multinomial without
            # replacement needs 1 <= N <= num_points, so clamp tiny/large ratios.
            N = min(num_points, max(1, int(num_points * ratio)))
            index = torch.multinomial(torch.ones(num_points),N) # close equivalent to np.random.choice
        elif cluster_method == "fps":
            # Farthest Point Sampling, as used in PointNet++
            index = fps(pos3D, ratio=ratio)
        else:
            # Previously this fell through with `centroids` unbound.
            raise ValueError("Surface cluster_method unknown")
        centroids = pos3D[index]
        # Find the closest centroid to each current point
        cluster = knn(centroids,pos3D,1)[1]
        cluster, _ = consecutive_cluster(cluster)
        pos3D = scatter(pos3D, cluster, dim=0, reduce='mean')
        normals = scatter(normals, cluster, dim=0, reduce='mean')
        # Regenerate surface graph; averaged normals are no longer unit length.
        normals = normals/torch.linalg.norm(normals,dim=1,keepdim=True)
        edge_index,directions = gh.surface2Edges(pos3D,normals,up_vector,k_neighbors=16)
        edge_index,selections,interps = gh.edges2Selections(edge_index,directions,interpolated=True)
        clusters.append(torch.clone(cluster).to(device))
        edge_indexes.append(torch.clone(edge_index).to(device))
        selections_list.append(torch.clone(selections).to(device))
        interps_list.append(torch.clone(interps).to(device))
    return clusters, edge_indexes, selections_list, interps_list
def makeMeshClusters(pos3D,mesh,edge_index,selections,interps,ratio=.25,up_vector=None,depth=1,device='cpu'):
    """Build a multi-level hierarchy of coarsened graphs guided by a mesh.

    At each level, ``N = len(pos3D) * ratio`` centroids (at least 1) and their
    normals are drawn from the mesh with ``mh.sampleSurface``. Every current
    point is assigned to its nearest centroid and positions are averaged per
    cluster. The surface graph is rebuilt with ``gh.surface2Edges`` /
    ``gh.edges2Selections`` using the sampled normals of the centroids that
    actually received points.

    Args:
        pos3D: current point positions; assumed (N, 3) - TODO confirm.
        mesh: mesh object passed to ``mh.sampleSurface``.
        edge_index, selections, interps: graph of the finest level; cloned
            into the first entry of the returned lists.
        ratio: fraction of the current point count used as the number of
            centroids per level.
        up_vector: forwarded to ``gh.surface2Edges``.
        depth: number of coarsening levels.
        device: device the returned tensors are moved to.

    Returns:
        (clusters, edge_indexes, selections_list, interps_list); ``clusters``
        has ``depth`` entries, the other lists ``depth + 1``.
    """
    clusters = []
    edge_indexes = [torch.clone(edge_index).to(device)]
    selections_list = [torch.clone(selections).to(device)]
    interps_list = [torch.clone(interps).to(device)]
    for _ in range(depth):
        # Desired number of clusters in the next level (never zero).
        N = max(1, int(len(pos3D) * ratio))
        # Generate a new point cloud by sampling the mesh surface.
        # NOTE(review): assumes normals[i] belongs to centroids[i] and is
        # already unit length - confirm against mh.sampleSurface.
        centroids, normals = mh.sampleSurface(mesh,N,return_x=False)
        # Find the closest centroid to each current point
        nearest = knn(centroids,pos3D,1)[1]
        # perm[k] is the index of a point whose label is k, so nearest[perm]
        # lists, in new-label order, the centroid behind each pooled row.
        cluster, perm = consecutive_cluster(nearest)
        pos3D = scatter(pos3D, cluster, dim=0, reduce='mean')
        # Centroids that received no points are dropped from pos3D; drop their
        # normals too so rows of pos3D and normals stay aligned.
        normals = normals[nearest[perm]]
        # Regenerate surface graph
        edge_index,directions = gh.surface2Edges(pos3D,normals,up_vector)
        edge_index,selections,interps = gh.edges2Selections(edge_index,directions,interpolated=True)
        clusters.append(torch.clone(cluster).to(device))
        edge_indexes.append(torch.clone(edge_index).to(device))
        selections_list.append(torch.clone(selections).to(device))
        interps_list.append(torch.clone(interps).to(device))
    return clusters, edge_indexes, selections_list, interps_list
def getGrid(pos,Nx,Ny,xrange=None,yrange=None):
    """Assign each 2D point to a cell of an Nx-by-Ny grid.

    Cell indices are ``floor((coord - min) / (max - min) * n)`` clamped to
    ``[0, n - 1]``, using columns 0 (x) and 1 (y) of ``pos``.

    Args:
        pos: tensor whose columns 0 and 1 hold the x and y coordinates.
        Nx, Ny: number of grid cells along x and y.
        xrange, yrange: optional ``(min, max)`` bounds; when omitted, the
            extent of ``pos`` along that axis is used.

    Returns:
        (cx, cy): floating-point tensors of per-point cell indices. If every
        point shares the same coordinate on an axis (zero span), all points go
        to cell 0 on that axis instead of producing NaN.
    """
    def _axisCells(coords, lo, hi, n):
        span = hi - lo
        if span == 0:
            # Degenerate axis: avoid 0/0 = NaN, which later corrupts the
            # integer cell ids built from these values.
            span = 1
        return torch.clamp(torch.floor((coords - lo) / span * n), 0, n - 1)

    xmin = torch.min(pos[:,0]) if xrange is None else xrange[0]
    ymin = torch.min(pos[:,1]) if yrange is None else yrange[0]
    xmax = torch.max(pos[:,0]) if xrange is None else xrange[1]
    ymax = torch.max(pos[:,1]) if yrange is None else yrange[1]
    cx = _axisCells(pos[:,0], xmin, xmax, Nx)
    cy = _axisCells(pos[:,1], ymin, ymax, Ny)
    return cx, cy
def gridCluster(pos,cx,cy,xmax):
    """Pool points that fall into the same grid cell.

    The flat cell id of a point is ``cx + cy * xmax`` (row-major with ``xmax``
    columns). Ids are relabelled to consecutive integers with
    ``consecutive_cluster``, and ``pos`` is averaged over each label.

    Returns:
        (cluster, pos): per-point consecutive labels and the per-cluster mean
        positions.
    """
    flat_ids = (cx + cy * xmax).type(torch.long)
    labels, _ = consecutive_cluster(flat_ids)
    pooled = scatter(pos, labels, dim=0, reduce='mean')
    return labels, pooled
def selectionAverage(cluster, edge_index, selections):
    """Pool a graph's edges and their direction selections onto a clustering.

    Both endpoints of every edge are replaced by their cluster id. Edges that
    become self-loops are dropped, and parallel edges are merged with
    ``coalesce``. The merged edge's selection is the rounded mean of the
    merged selections.

    The mean is computed under four cyclic rotations of the labels (each
    rotation adds 2; 9 wraps to 1 and 10 to 2). A rotated mean that rounds to
    4 or 5 lies away from the 8 -> 1 wraparound, so in that rotation the
    average is taken as meaningful. The label is mapped back through the
    inverse rotation, and later rotations overwrite earlier ones. A merged
    edge whose mean never rounds to 4 or 5 keeps selection 0.

    Finally, every cluster gets a self-loop with selection 0.

    Args:
        cluster: per-node cluster id (as produced by ``gridCluster``).
        edge_index: (2, E) edge list over the fine nodes.
        selections: per-edge selection labels; labels 1..8 are rotated
            cyclically, which suggests 8 neighbour directions - TODO confirm.

    Returns:
        (edge_index, selections) of the pooled graph.
    """
    # Number of coarse nodes. coalesce needs an upper bound on the ids;
    # add_remaining_self_loops would otherwise infer the count from
    # edge_index, which drops loops for edge-less clusters (all of them when
    # no edges remain).
    num_clusters = int(cluster.max()) + 1 if cluster.numel() > 0 else 0
    self_loop_fill = torch.tensor(0, dtype=torch.long)

    # Relabel both endpoints of every edge with their cluster id.
    edge_index = cluster[edge_index.contiguous().view(1, -1)].view(2, -1)
    # Intra-cluster edges are re-added below as self-loops with selection 0.
    edge_index, selections = remove_self_loops(edge_index, selections)

    if edge_index.numel() == 0:
        warn("Edge Pool found no edges", stacklevel=2)
        return add_remaining_self_loops(edge_index, selections, fill_value=self_loop_fill, num_nodes=num_clusters)

    # For rotation k, the original labels that now read (4, 5): the inverse
    # rotation of (4, 5) by 2*k steps.
    rotation_targets = ((4, 5), (2, 3), (8, 1), (6, 7))
    rotated = selections
    final_edge_index = None
    final_selections = None
    # To avoid means over the 8 -> 1 discontinuity, only trust a mean that
    # lands mid-range (4 or 5) in some rotation.
    for step, (from_four, from_five) in enumerate(rotation_targets):
        if step > 0:
            # Rotate selection kernel: 1..8 -> 3..8,1,2
            rotated = rotated + 2
            rotated = rotated % 9 + torch.div(rotated, 9, rounding_mode="floor")
        merged_index, mean_sel = coalesce(edge_index, rotated.type(torch.float), num_clusters, num_clusters, op="mean")
        mean_sel = torch.round(mean_sel).type(torch.long)
        if final_selections is None:
            # coalesce output order is identical for every rotation, so the
            # first call's index lines up with all later means.
            final_edge_index = merged_index
            final_selections = torch.zeros_like(mean_sel)
        final_selections[mean_sel == 4] = from_four
        final_selections[mean_sel == 5] = from_five

    return add_remaining_self_loops(final_edge_index, final_selections, fill_value=self_loop_fill, num_nodes=num_clusters)