Spaces:
Running
Running
| import torch | |
| from torch_scatter import scatter | |
| from torch_geometric.nn.pool.consecutive import consecutive_cluster | |
| from torch_geometric.utils import add_self_loops, add_remaining_self_loops, remove_self_loops | |
| from torch_geometric.nn import fps, knn | |
| from torch_sparse import coalesce | |
| import graph_helpers as gh | |
| import sphere_helpers as sh | |
| import mesh_helpers as mh | |
| import math | |
| from math import pi,sqrt | |
| from warnings import warn | |
def makeImageClusters(pos2D, Nx, Ny, edge_index, selections, depth=1, device='cpu', stride=2):
    """Build a ``depth``-level hierarchy of grid poolings over a 2D point set.

    At every level the grid resolution is divided by ``stride`` (integer
    division) along both axes. Points are binned into that grid with
    ``getGrid``/``gridCluster``, and the graph is pooled with
    ``selectionAverage``.

    Returns:
        clusters: one cluster-assignment tensor per level (``depth`` entries).
        edge_indexes: the input ``edge_index`` followed by one per level.
        selections_list: the input ``selections`` followed by one per level.
    Every returned tensor is a clone moved to ``device``.
    """
    def snapshot(tensor):
        # Independent copy on the target device.
        return torch.clone(tensor).to(device)

    clusters = []
    edge_indexes = [snapshot(edge_index)]
    selections_list = [snapshot(selections)]
    for _level in range(depth):
        Nx, Ny = Nx // stride, Ny // stride
        cx, cy = getGrid(pos2D, Nx, Ny)
        cluster, pos2D = gridCluster(pos2D, cx, cy, Nx)
        edge_index, selections = selectionAverage(cluster, edge_index, selections)
        clusters.append(snapshot(cluster))
        edge_indexes.append(snapshot(edge_index))
        selections_list.append(snapshot(selections))
    return clusters, edge_indexes, selections_list
def makeSphereClusters(pos3D, edge_index, selections, interps, rows, cols,
                       cluster_method="layering", stride=2, bary_d=None,
                       depth=1, device='cpu', ratio=None):
    """Build a ``depth``-level pooling hierarchy for points on a sphere.

    Each level does the following:
      * divides ``rows``/``cols`` by ``stride`` and multiplies ``bary_d`` by
        ``stride`` when it is given;
      * picks centroids with ``cluster_method``;
      * assigns every current point to its nearest centroid;
      * averages positions per cluster;
      * rebuilds the surface graph with ``gh.surface2Edges`` /
        ``gh.edges2Selections``.

    Args:
        cluster_method: "equirec", "layering", "spiral", "icosphere" or
            "random" (centroids generated by ``sphere_helpers``);
            "random_nodes" (random subset of the current points); or
            "fps" (farthest point sampling of the current points).
        ratio: fraction of the current points kept per level by
            "random_nodes" and "fps". Defaults to ``1 / stride**2``, which
            mirrors dividing both ``rows`` and ``cols`` by ``stride``.
            Ignored by the other methods.

    Returns:
        (clusters, edge_indexes, selections_list, interps_list). Every list
        except ``clusters`` starts with a copy of the corresponding input.
        All tensors are clones moved to ``device``.

    Raises:
        ValueError: if ``cluster_method`` is unknown (checked once per level,
            so only when ``depth >= 1``).
    """
    if ratio is None:
        ratio = 1.0 / (stride ** 2)
    clusters = []
    edge_indexes = [torch.clone(edge_index).to(device)]
    selections_list = [torch.clone(selections).to(device)]
    interps_list = [torch.clone(interps).to(device)]
    for _ in range(depth):
        rows = rows // stride
        cols = cols // stride
        if bary_d is not None:
            bary_d = bary_d * stride
        if cluster_method == "equirec":
            centroids, _ = sh.sampleSphere_Equirec(rows, cols)
        elif cluster_method == "layering":
            centroids, _ = sh.sampleSphere_Layering(rows)
        elif cluster_method == "spiral":
            centroids, _ = sh.sampleSphere_Spiral(rows, cols)
        elif cluster_method == "icosphere":
            centroids, _ = sh.sampleSphere_Icosphere(rows)
        elif cluster_method == "random":
            centroids, _ = sh.sampleSphere_Random(rows, cols)
        elif cluster_method == "random_nodes":
            # Desired number of clusters at this level; at least one so that
            # torch.multinomial has something to draw.
            N = max(1, int(len(pos3D) * ratio))
            index = torch.multinomial(torch.ones(len(pos3D)), N)  # close equivalent to np.random.choice
            centroids = pos3D[index]
        elif cluster_method == "fps":
            # Farthest Point Sampling, as used in PointNet++.
            index = fps(pos3D, ratio=ratio)
            centroids = pos3D[index]
        else:
            raise ValueError("Sphere cluster_method unknown")
        # Find the closest centroid to each current point.
        cluster = knn(centroids, pos3D, 1)[1]
        cluster, _ = consecutive_cluster(cluster)
        pos3D = scatter(pos3D, cluster, dim=0, reduce='mean')
        # Regenerate the surface graph. The normalized position (radial
        # direction) is used as the surface normal.
        normals = pos3D / torch.linalg.norm(pos3D, dim=1, keepdim=True)
        edge_index, directions = gh.surface2Edges(pos3D, normals)
        edge_index, selections, interps = gh.edges2Selections(edge_index, directions, interpolated=True, bary_d=bary_d)
        clusters.append(torch.clone(cluster).to(device))
        edge_indexes.append(torch.clone(edge_index).to(device))
        selections_list.append(torch.clone(selections).to(device))
        interps_list.append(torch.clone(interps).to(device))
    return clusters, edge_indexes, selections_list, interps_list
def makeSurfaceClusters(pos3D, normals, edge_index, selections, interps,
                        cluster_method="random", ratio=.25, up_vector=None,
                        depth=1, device='cpu'):
    """Build a ``depth``-level pooling hierarchy for a point cloud with normals.

    Each level does the following:
      * picks centroids from the current points ("random": uniform subset of
        size ``max(1, int(len(pos3D) * ratio))``; "fps": farthest point
        sampling with ``ratio``);
      * assigns every point to its nearest centroid;
      * averages positions and normals per cluster, then re-normalizes the
        normals;
      * rebuilds the surface graph with ``gh.surface2Edges``
        (``k_neighbors=16``) and ``gh.edges2Selections``.

    Returns:
        (clusters, edge_indexes, selections_list, interps_list). Every list
        except ``clusters`` starts with a copy of the corresponding input.
        All tensors are clones moved to ``device``.

    Raises:
        ValueError: if ``cluster_method`` is unknown (checked once per level,
            so only when ``depth >= 1``).
    """
    clusters = []
    edge_indexes = [torch.clone(edge_index).to(device)]
    selections_list = [torch.clone(selections).to(device)]
    interps_list = [torch.clone(interps).to(device)]
    for _ in range(depth):
        if cluster_method == "random":
            # Desired number of clusters in the next level; at least one so
            # that torch.multinomial has something to draw.
            N = max(1, int(len(pos3D) * ratio))
            index = torch.multinomial(torch.ones(len(pos3D)), N)  # close equivalent to np.random.choice
        elif cluster_method == "fps":
            # Farthest Point Sampling, as used in PointNet++.
            index = fps(pos3D, ratio=ratio)
        else:
            raise ValueError("Surface cluster_method unknown")
        centroids = pos3D[index]
        # Find the closest centroid to each current point.
        cluster = knn(centroids, pos3D, 1)[1]
        cluster, _ = consecutive_cluster(cluster)
        pos3D = scatter(pos3D, cluster, dim=0, reduce='mean')
        normals = scatter(normals, cluster, dim=0, reduce='mean')
        # Averaging shortens the normals; rescale them back to unit length.
        normals = normals / torch.linalg.norm(normals, dim=1, keepdim=True)
        # Regenerate the surface graph.
        edge_index, directions = gh.surface2Edges(pos3D, normals, up_vector, k_neighbors=16)
        edge_index, selections, interps = gh.edges2Selections(edge_index, directions, interpolated=True)
        clusters.append(torch.clone(cluster).to(device))
        edge_indexes.append(torch.clone(edge_index).to(device))
        selections_list.append(torch.clone(selections).to(device))
        interps_list.append(torch.clone(interps).to(device))
    return clusters, edge_indexes, selections_list, interps_list
def makeMeshClusters(pos3D, mesh, edge_index, selections, interps, ratio=.25,
                     up_vector=None, depth=1, device='cpu'):
    """Build a ``depth``-level pooling hierarchy by resampling a mesh surface.

    Each level does the following:
      * samples ``max(1, int(len(pos3D) * ratio))`` new points and their
        normals from ``mesh`` via ``mh.sampleSurface``;
      * assigns every current point to its nearest sample;
      * averages positions per cluster;
      * rebuilds the surface graph with ``gh.surface2Edges`` /
        ``gh.edges2Selections``.

    Returns:
        (clusters, edge_indexes, selections_list, interps_list). Every list
        except ``clusters`` starts with a copy of the corresponding input.
        All tensors are clones moved to ``device``.
    """
    clusters = []
    edge_indexes = [torch.clone(edge_index).to(device)]
    selections_list = [torch.clone(selections).to(device)]
    interps_list = [torch.clone(interps).to(device)]
    for _ in range(depth):
        # Desired number of clusters in the next level (at least one).
        N = max(1, int(len(pos3D) * ratio))
        # Sample a new point cloud (with per-sample normals) from the mesh.
        centroids, normals = mh.sampleSurface(mesh, N, return_x=False)
        # Index of the closest sample for every current point.
        nearest = knn(centroids, pos3D, 1)[1]
        cluster, _ = consecutive_cluster(nearest)
        pos3D = scatter(pos3D, cluster, dim=0, reduce='mean')
        # consecutive_cluster drops samples that no point chose and numbers
        # the survivors in ascending order of their sample index. Keep only
        # the matching normals so they stay row-aligned with the pooled pos3D.
        normals = normals[torch.unique(nearest)]
        # NOTE(review): normals are used as returned by mh.sampleSurface;
        # whether they are unit length is not checked here.
        edge_index, directions = gh.surface2Edges(pos3D, normals, up_vector)
        edge_index, selections, interps = gh.edges2Selections(edge_index, directions, interpolated=True)
        clusters.append(torch.clone(cluster).to(device))
        edge_indexes.append(torch.clone(edge_index).to(device))
        selections_list.append(torch.clone(selections).to(device))
        interps_list.append(torch.clone(interps).to(device))
    return clusters, edge_indexes, selections_list, interps_list
def _gridAxisCells(coord, n, lo, hi):
    """Bin ``coord`` into ``n`` equal cells spanning [lo, hi]; return float cell indices."""
    span = hi - lo
    if span == 0:
        # Degenerate extent: put every point in cell 0 instead of producing
        # NaN from 0/0 (which later casts to garbage integer cluster ids).
        scaled = torch.zeros_like(coord, dtype=torch.result_type(coord, 1.0))
    else:
        scaled = (coord - lo) / span
    # Points exactly on the upper bound land in cell n; clamp them into n-1.
    return torch.clamp(torch.floor(scaled * n), 0, n - 1)


def getGrid(pos, Nx, Ny, xrange=None, yrange=None):
    """Assign each 2D point to a cell of an ``Nx`` x ``Ny`` grid.

    Args:
        pos: tensor whose columns 0 and 1 hold the x and y coordinates.
        Nx, Ny: number of grid cells along x and y.
        xrange, yrange: optional ``(min, max)`` extents. Each defaults to the
            min/max of the corresponding column of ``pos``.

    Returns:
        (cx, cy): floating-point tensors of cell indices, clamped into
        [0, Nx-1] and [0, Ny-1]. An axis with zero extent maps every point
        to cell 0.
    """
    xmin = torch.min(pos[:, 0]) if xrange is None else xrange[0]
    ymin = torch.min(pos[:, 1]) if yrange is None else yrange[0]
    xmax = torch.max(pos[:, 0]) if xrange is None else xrange[1]
    ymax = torch.max(pos[:, 1]) if yrange is None else yrange[1]
    cx = _gridAxisCells(pos[:, 0], Nx, xmin, xmax)
    cy = _gridAxisCells(pos[:, 1], Ny, ymin, ymax)
    return cx, cy
def gridCluster(pos, cx, cy, xmax):
    """Pool points that share a grid cell.

    ``cx``/``cy`` are per-point cell coordinates and ``xmax`` is the number
    of cells per row, so ``cx + cy * xmax`` is a row-major flat cell id.
    Occupied cells are renumbered consecutively with ``consecutive_cluster``.

    Returns:
        (cluster, pos): per-point cluster ids and the mean position of each
        cluster.
    """
    flat_cell = (cx + cy * xmax).type(torch.long)
    cluster = consecutive_cluster(flat_cell)[0]
    pooled = scatter(pos, cluster, dim=0, reduce='mean')
    return cluster, pooled
def _rotateSelectionRing(selections):
    """Shift selection labels by two steps around the 1..8 ring (9 -> 1, 10 -> 2)."""
    shifted = selections + 2
    return shifted % 9 + torch.div(shifted, 9, rounding_mode="floor")


def _coalesceMeanLabel(edge_index, labels, num_nodes):
    """Merge parallel edges, averaging ``labels``; return (edge_index, rounded long labels)."""
    merged_index, mean = coalesce(edge_index, labels.type(torch.float), num_nodes, num_nodes, op="mean")
    return merged_index, torch.round(mean).type(torch.long)


def selectionAverage(cluster, edge_index, selections):
    """Pool edges and their selection labels onto clusters.

    Steps:
      1. Replace every edge endpoint with its cluster id.
      2. Drop self-loops.
      3. Merge parallel edges between the same pair of clusters, averaging
         their selection labels.
      4. Add a self-loop labelled 0 to every cluster that lacks one.

    Averaging a label ring directly would mix labels across the 8 -> 1
    discontinuity. To avoid that, the mean is taken in four frames, each
    rotated by two more ring steps. In each frame only means that round to
    4 or 5 (far from the wrap) are written back, mapped to the label they
    represent in the unrotated frame. Later frames overwrite earlier ones.
    Merged edges never rounding to 4 or 5 keep label 0.

    Args:
        cluster: per-node cluster id. Assumed consecutive (0..K-1), as
            produced by consecutive_cluster — TODO confirm for other callers.
        edge_index: (2, E) tensor of node indices.
        selections: per-edge selection labels.

    Returns:
        (edge_index, selections) of the pooled graph.
    """
    # Number of pooled nodes. Passed explicitly so clusters with no incident
    # edges (or every cluster when no edges survive) still get a self-loop.
    num_clusters = int(cluster.max()) + 1 if cluster.numel() > 0 else 0
    self_loop_label = torch.tensor(0, dtype=torch.long)

    edge_index = cluster[edge_index.contiguous().view(1, -1)].view(2, -1)
    edge_index, selections = remove_self_loops(edge_index, selections)

    if edge_index.numel() == 0:
        warn("Edge Pool found no edges")
        return add_remaining_self_loops(edge_index, selections, fill_value=self_loop_label,
                                        num_nodes=num_clusters)

    # (label meant by a rounded mean of 4, label meant by a rounded mean of 5)
    # in each successive rotated frame.
    frame_targets = ((4, 5), (2, 3), (8, 1), (6, 7))
    frame = selections
    merged_edge_index, rounded = _coalesceMeanLabel(edge_index, frame, num_clusters)
    merged = torch.zeros_like(rounded).to(selections.device)
    for step, (label_at_4, label_at_5) in enumerate(frame_targets):
        if step > 0:
            frame = _rotateSelectionRing(frame)
            _, rounded = _coalesceMeanLabel(edge_index, frame, num_clusters)
        merged[rounded == 4] = label_at_4
        merged[rounded == 5] = label_at_5

    return add_remaining_self_loops(merged_edge_index, merged, fill_value=self_loop_label,
                                    num_nodes=num_clusters)