Spaces:
Running
Running
| import numpy as np | |
| import torch | |
| from torch_geometric.nn import knn | |
| from torch_geometric.data import Data | |
| from torch_geometric.nn import radius_graph, knn_graph | |
| import graph_helpers as gh | |
| import sphere_helpers as sh | |
| import mesh_helpers as mh | |
| import clusters as cl | |
| import utils | |
| from torch_scatter import scatter | |
| import math | |
| from math import pi, sqrt | |
| from warnings import warn | |
def image2Graph(data, gt = None, mask = None, depth = 1, x_only = False, device = 'cpu'):
    """Turn an image tensor into a graph plus the metadata needed to map results back.

    Args:
        data: Tensor whose shape unpacks as (_, ch, rows, cols). It is reshaped
            to (ch, rows*cols), so the leading dimension has to be 1.
        gt: Optional ground truth; flattened to one entry per pixel.
        mask: Optional tensor; only pixels where the flattened mask is nonzero
            become nodes.
        depth: Forwarded to cl.makeImageClusters.
        x_only: When True, skip graph construction and return only the node
            features (plus targets if gt was given).
        device: Device the features/targets are moved to; also forwarded to
            cl.makeImageClusters.

    Returns:
        With x_only: x, or (x, y) when gt is given.
        Otherwise: (graph, metadata), both torch_geometric Data objects.
    """
    _, ch, rows, cols = data.shape
    has_mask = mask is not None
    has_gt = gt is not None

    # One row per pixel, one column per channel
    x = data.reshape(ch, rows * cols).transpose(0, 1).to(device)

    if has_mask:
        # Indices of the pixels that survive the mask
        node_mask = torch.where(mask.flatten())
        x = x[node_mask]

    if has_gt:
        y = gt.flatten().to(device)
        y = y[node_mask] if has_mask else y

    if x_only:
        return (x, y) if has_gt else x

    # Per-node image positions, restricted to the kept nodes
    im_pos = gh.getImPos(rows, cols)
    if has_mask:
        im_pos = im_pos[node_mask]

    # "Point cloud" version of the positions, used for edges and clustering
    pos2D = gh.convertImPos(im_pos, flip_y=False)

    # Initial graph: edges, per-edge direction vectors, then selections
    edge_index = gh.grid2Edges(pos2D)
    src, dst = edge_index[0], edge_index[1]
    directions = pos2D[dst] - pos2D[src]
    selections = gh.edges2Selections(edge_index, directions, interpolated=False, y_down=True)

    # Info for the downsampled versions of the graph
    clusters, edge_indexes, selections_list = cl.makeImageClusters(
        pos2D, cols, rows, edge_index, selections, depth=depth, device=device)

    # Final graph, plus what is needed to paint results back into an image
    graph = Data(x=x, clusters=clusters, edge_indexes=edge_indexes,
                 selections_list=selections_list, interps_list=None)
    metadata = Data(original=data, im_pos=im_pos.long(), rows=rows, cols=cols, ch=ch)

    if has_gt:
        graph.y = y

    return graph, metadata
def graph2Image(result,metadata,canvas=None):
    """Paint per-node results back onto an image canvas.

    Args:
        result: Per-node output; converted with utils.toNumpy(permute=False).
        metadata: Object exposing `im_pos` (two index columns, one row per
            node) and `original`, as produced by image2Graph.
        canvas: Optional array to paint onto. When None, one is created via
            utils.makeCanvas from the node values and metadata.original.

    Returns:
        The canvas, with the entries addressed by im_pos overwritten.
    """
    node_values = utils.toNumpy(result, permute=False)
    pixel_index = utils.toNumpy(metadata.im_pos, permute=False)

    if canvas is None:
        canvas = utils.makeCanvas(node_values, metadata.original)

    # Only node locations are overwritten; everything else keeps the canvas
    # content (necessary for masked images).
    first_axis = pixel_index[:, 0]
    second_axis = pixel_index[:, 1]
    canvas[first_axis, second_axis] = node_values
    return canvas
| ### Begin Interpolated Methods ### | |
def sphere2Graph(data, structure="layering", cluster_method="layering", scale=1.0, stride=2, interpolation_mode = "angle", gt = None, mask = None, depth = 1, x_only = False, device = 'cpu'):
    """Sample an equirectangular image at points on a sphere and build a graph from them.

    Args:
        data: Tensor whose shape unpacks as (_, ch, rows, cols).
        structure: Sphere sampling scheme; one of "equirec", "layering",
            "spiral", "icosphere" or "random".
        cluster_method: Forwarded to cl.makeSphereClusters.
        scale: Multiplier applied to rows/cols before they are handed to the
            sampler and to cl.makeSphereClusters.
        stride: Forwarded to cl.makeSphereClusters.
        interpolation_mode: When "bary", bary_d = pi/(scale*rows) is passed to
            gh.edges2Selections and cl.makeSphereClusters; otherwise None is.
        gt: Optional ground truth, sampled at the same locations as data.
        mask: Optional mask; sample points are filtered with gh.maskPoints.
        depth: Forwarded to cl.makeSphereClusters.
        x_only: When True, return only the sampled features (plus sampled
            ground truth if gt was given) without building the graph.
        device: Device for the sampled features; also forwarded to
            cl.makeSphereClusters.

    Returns:
        With x_only: features, or (features, features_y) when gt is given.
        Otherwise: (graph, metadata), both torch_geometric Data objects.

    Raises:
        ValueError: If `structure` is not one of the known schemes.
    """
    _, ch, rows, cols = data.shape
    sphere_rows = scale * rows
    sphere_cols = scale * cols

    # Lazy dispatch table: only the selected sampler is ever invoked
    samplers = {
        "equirec": lambda: sh.sampleSphere_Equirec(sphere_rows, sphere_cols),
        "layering": lambda: sh.sampleSphere_Layering(sphere_rows),
        "spiral": lambda: sh.sampleSphere_Spiral(sphere_rows, sphere_cols),
        "icosphere": lambda: sh.sampleSphere_Icosphere(sphere_rows),
        "random": lambda: sh.sampleSphere_Random(sphere_rows, sphere_cols),
    }
    if structure not in samplers:
        raise ValueError("Sphere structure unknown")
    cartesian, spherical = samplers[structure]()

    bary_d = pi / sphere_rows if interpolation_mode == "bary" else None

    # Landing point of every node in the equirectangular image
    sample_x, sample_y = sh.spherical2equirec(spherical[:, 0], spherical[:, 1], rows, cols)

    if mask is not None:
        node_mask = gh.maskPoints(mask, sample_x, sample_y)
        sample_x, sample_y = sample_x[node_mask], sample_y[node_mask]
        spherical = spherical[node_mask]
        cartesian = cartesian[node_mask]

    has_gt = gt is not None
    features = utils.bilinear_interpolate(data, sample_x, sample_y).to(device)
    if has_gt:
        features_y = utils.bilinear_interpolate(gt.unsqueeze(0), sample_x, sample_y).to(device)

    if x_only:
        return (features, features_y) if has_gt else features

    # Initial graph. The 3D positions are passed for both arguments of
    # surface2Edges (the mesh variants pass normals as the second one).
    edge_index, directions = gh.surface2Edges(cartesian, cartesian)
    edge_index, selections, interps = gh.edges2Selections(
        edge_index, directions, interpolated=True, bary_d=bary_d)

    # Info for the downsampled versions of the graph
    clusters, edge_indexes, selections_list, interps_list = cl.makeSphereClusters(
        cartesian, edge_index, selections, interps, sphere_rows, sphere_cols,
        cluster_method, stride=stride, bary_d=bary_d, depth=depth, device=device)

    # Final graph, plus what is needed to map network output back to an image
    graph = Data(x=features, clusters=clusters, edge_indexes=edge_indexes,
                 selections_list=selections_list, interps_list=interps_list)
    metadata = Data(original=data, pos3D=cartesian, mask=mask, rows=rows, cols=cols, ch=ch)

    if has_gt:
        graph.y = features_y

    return graph, metadata
def graph2Sphere(features,metadata):
    """Resample per-node features onto the equirectangular pixel grid.

    Every pixel of a (metadata.rows, metadata.cols) grid is mapped to a 3D
    point, its three nearest graph nodes are looked up, and their features are
    blended with inverse-distance weights.

    Args:
        features: Per-node feature tensor, indexed by node along dim 0; its
            second dimension becomes the channel axis of the output.
        metadata: Object exposing `rows`, `cols`, `pos3D` and `original`, and
            optionally `mask`, as produced by sphere2Graph.

    Returns:
        A numpy array of shape (rows, cols, features.shape[1]). When a mask is
        present, a canvas from utils.makeCanvas is returned instead, with only
        the masked-in pixels overwritten by the interpolated values.
    """
    # Equirectangular pixel locations and their 3D positions
    theta, phi = sh.equirec2spherical(metadata.rows, metadata.cols)
    x, y, z = sh.spherical2xyz(theta, phi)
    v = torch.stack((x, y, z), dim=1)

    # Three closest graph nodes for every pixel
    nearest = torch.reshape(knn(metadata.pos3D, v, 3)[1], (len(v), 3))

    def _inverse_distance(k):
        """Weight of the k-th nearest node for every pixel, capped at 1e6."""
        dist = torch.linalg.norm(v - metadata.pos3D[nearest[:, k]], dim=1, keepdim=True)
        w = 1 / dist.to(features.device)
        # A pixel that coincides with a node has distance 0; cap the resulting
        # weight so the division by `total` below stays finite.
        w = torch.nan_to_num(w, nan=1e6)
        return torch.clamp(w, 0, 1e6)

    w0, w1, w2 = (_inverse_distance(k) for k in range(3))
    total = w0 + w1 + w2

    result = (w0*features[nearest[:, 0]] + w1*features[nearest[:, 1]] + w2*features[nearest[:, 2]])/total

    image = np.reshape(result.detach().cpu().numpy(),
                       (metadata.rows, metadata.cols, features.shape[1]))

    # sphere2Graph always passes mask=mask, which is None for unmasked input.
    # Whether the container keeps or drops a None attribute, treat it as
    # "no mask" instead of calling .squeeze() on None.
    mask = getattr(metadata, "mask", None)
    if mask is None:
        return image

    mask = utils.toNumpy(mask.squeeze(), permute=False)
    canvas = utils.makeCanvas(result, metadata.original)
    keep = np.where(mask)
    canvas[keep] = image[keep]
    return canvas
def splat2Graph(data, mesh, up_vector = None, N = 100000, ratio=.25, depth = 1, device = 'cpu'):
    """Sample mesh faces to determine graph.

    Args:
        data: Stored unchanged as `original` in the returned metadata.
        mesh: Mesh handed to mh.sampleSurface; also stored in the metadata.
        up_vector: Optional tensor, normalized along dim 1 and forwarded to
            gh.surface2Edges and cl.makeSurfaceClusters. Defaults to
            [[1, 1, 1]] (before normalization).
        N: Sample count passed to mh.sampleSurface.
        ratio: Forwarded to cl.makeSurfaceClusters.
        depth: Forwarded to cl.makeSurfaceClusters.
        device: Forwarded to cl.makeSurfaceClusters.

    Returns:
        (graph, metadata), both torch_geometric Data objects. The graph holds
        only structure (no `x`), and the metadata holds `original`, `pos3D`
        and `mesh`.
    """
    # Identity check: `== None` on a tensor/array is an elementwise (or
    # ambiguous) comparison rather than a test for a missing argument.
    if up_vector is None:
        up_vector = torch.tensor([[1, 1, 1]], dtype=torch.float)
    up_vector = up_vector/torch.linalg.norm(up_vector, dim=1)

    # Sampled surface positions and their normal vectors
    pos3D, normals = mh.sampleSurface(mesh, N)

    # Build initial graph: edge_index holds each point's neighbors,
    # directions the direction from the point to each neighbor
    edge_index, directions = gh.surface2Edges(pos3D, normals, up_vector, k_neighbors=16)

    # Turn the directions into selections (and interpolation values)
    edge_index, selections, interps = gh.edges2Selections(edge_index, directions, interpolated=True)

    # Generate info for downsampled versions of the graph
    clusters, edge_indexes, selections_list, interps_list = cl.makeSurfaceClusters(
        pos3D, normals, edge_index, selections, interps,
        ratio=ratio, up_vector=up_vector, depth=depth, device=device)

    # Make final graph and metadata needed for mapping the result after going through the network
    graph = Data(clusters=clusters, edge_indexes=edge_indexes,
                 selections_list=selections_list, interps_list=interps_list)
    metadata = Data(original=data, pos3D=pos3D, mesh=mesh)

    return graph, metadata
def mesh2Graph(data, mesh, up_vector = None, N = 100000, ratio=.25, mask = None, depth = 1, x_only = False, device = 'cpu'):
    """Sample mesh faces to determine graph.

    Args:
        data: Stored unchanged as `original` in the returned metadata.
        mesh: Mesh handed to mh.sampleSurface; also stored in the metadata.
        up_vector: Optional tensor, normalized along dim 1 and forwarded to
            gh.surface2Edges and cl.makeSurfaceClusters. Defaults to
            [[1, 1, 1]] (before normalization).
        N: Sample count passed to mh.sampleSurface.
        ratio: Forwarded to cl.makeSurfaceClusters.
        mask: Not supported; a warning is issued if one is given.
        depth: Forwarded to cl.makeSurfaceClusters.
        x_only: When True, return only the sampled node features. A warning
            is issued because the samples are randomly selected and so do not
            line up with a previously built graph.
        device: Device for the node features; also forwarded to
            cl.makeSurfaceClusters.

    Returns:
        With x_only: the node feature tensor x.
        Otherwise: (graph, metadata), both torch_geometric Data objects. The
        metadata holds `original`, `pos3D`, `uvs` and `mesh`.
    """
    # Identity check: `== None` on a tensor/array is an elementwise (or
    # ambiguous) comparison rather than a test for a missing argument.
    if up_vector is None:
        up_vector = torch.tensor([[1, 1, 1]], dtype=torch.float)
    up_vector = up_vector/torch.linalg.norm(up_vector, dim=1)

    if mask is not None:
        warn("Masks are not currently implemented for mesh graphs")

    # Position, normal vector, uv coordinates in the texture map; x is color
    pos3D, normals, uvs, x = mh.sampleSurface(mesh, N, return_x=True)
    x = x.to(device)

    if x_only:
        warn("x_only returns randomly selected points for mesh2Graph. Do not use with previous graph structures")
        return x

    # Build initial graph: edge_index holds each point's neighbors,
    # directions the direction from the point to each neighbor
    edge_index, directions = gh.surface2Edges(pos3D, normals, up_vector, k_neighbors=16)

    # Turn the directions into selections (and interpolation values)
    edge_index, selections, interps = gh.edges2Selections(edge_index, directions, interpolated=True)

    # Generate info for downsampled versions of the graph
    clusters, edge_indexes, selections_list, interps_list = cl.makeSurfaceClusters(
        pos3D, normals, edge_index, selections, interps,
        ratio=ratio, up_vector=up_vector, depth=depth, device=device)

    # Make final graph and metadata needed for mapping the result after going through the network
    graph = Data(x=x, clusters=clusters, edge_indexes=edge_indexes,
                 selections_list=selections_list, interps_list=interps_list)
    metadata = Data(original=data, pos3D=pos3D, uvs=uvs, mesh=mesh)

    return graph, metadata
def graph2Splat(features,metadata,view3D=False):
    """Interpolate per-node features into a texture image using the nodes' UV coordinates.

    Args:
        features: Per-node feature tensor.
        metadata: Object exposing `original` (defines the output rows/cols),
            `uvs` (column 0 scaled by cols, column 1 flipped and scaled by
            rows) and, when view3D is set, `mesh`.
            NOTE(review): splat2Graph in this file does not store `uvs` in
            its metadata, so the metadata must come from a producer that does
            (mesh2Graph does) -- confirm the intended pairing.
        view3D: When True, apply the result as the mesh texture via
            mh.setTexture and call show() on the returned mesh.

    Returns:
        The interpolated image from utils.interpolatePointCloud2D, clipped to
        [0, 1].
    """
    features = features.detach().cpu().numpy()
    rows, cols, _ = utils.toNumpy(metadata.original).shape

    # Work on a copy: Tensor.numpy() shares memory with a CPU tensor, so
    # scaling in place would silently rewrite metadata.uvs and break any
    # later call that reuses the same metadata.
    pos2D = metadata.uvs.detach().cpu().numpy().copy()
    pos2D[:, 0] = pos2D[:, 0]*cols
    pos2D[:, 1] = (1 - pos2D[:, 1])*rows  # UV puts y=0 at the bottom

    # Pixel grid at which the point cloud is evaluated
    col_image, row_image = np.meshgrid(np.arange(cols), np.arange(rows))

    canvas = utils.interpolatePointCloud2D(pos2D, features, col_image, row_image)
    canvas = np.clip(canvas, 0, 1)

    if view3D:
        mesh = mh.setTexture(metadata.mesh, canvas)
        mesh.show()

    return canvas
def graph2Mesh(features,metadata,view3D=False):
    """Interpolate per-node features into a texture image using the nodes' UV coordinates.

    Args:
        features: Per-node feature tensor.
        metadata: Object exposing `original` (defines the output rows/cols),
            `uvs` (column 0 scaled by cols, column 1 flipped and scaled by
            rows) and, when view3D is set, `mesh` -- as produced by
            mesh2Graph.
        view3D: When True, apply the result as the mesh texture via
            mh.setTexture and call show() on the returned mesh.

    Returns:
        The interpolated image from utils.interpolatePointCloud2D, clipped to
        [0, 1].
    """
    features = features.detach().cpu().numpy()
    rows, cols, _ = utils.toNumpy(metadata.original).shape

    # Work on a copy: Tensor.numpy() shares memory with a CPU tensor, so
    # scaling in place would silently rewrite metadata.uvs and break any
    # later call that reuses the same metadata.
    pos2D = metadata.uvs.detach().cpu().numpy().copy()
    pos2D[:, 0] = pos2D[:, 0]*cols
    pos2D[:, 1] = (1 - pos2D[:, 1])*rows  # UV puts y=0 at the bottom

    # Pixel grid at which the point cloud is evaluated
    col_image, row_image = np.meshgrid(np.arange(cols), np.arange(rows))

    canvas = utils.interpolatePointCloud2D(pos2D, features, col_image, row_image)
    canvas = np.clip(canvas, 0, 1)

    if view3D:
        mesh = mh.setTexture(metadata.mesh, canvas)
        mesh.show()

    return canvas