File size: 10,147 Bytes
5b557cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
import torch
from torch_scatter import scatter
from torch_geometric.nn.pool.consecutive import consecutive_cluster
from torch_geometric.utils import add_self_loops, add_remaining_self_loops, remove_self_loops
from torch_geometric.nn import fps, knn
from torch_sparse import coalesce
import graph_helpers as gh
import sphere_helpers as sh
import mesh_helpers as mh

import math
from math import pi,sqrt


from warnings import warn

def makeImageClusters(pos2D,Nx,Ny,edge_index,selections,depth=1,device='cpu',stride=2):
    """Build a `depth`-level grid-pooling hierarchy for a 2D image graph.

    At every level the grid resolution is divided by `stride`, points are
    binned into the coarser grid, and the edge/selection structure is pooled
    onto the resulting clusters.

    Args:
        pos2D: per-node 2D positions.
        Nx, Ny: grid resolution of the finest level.
        edge_index, selections: graph of the finest level (cloned as level 0).
        depth: number of coarser levels to build.
        device: device every returned tensor is copied to.
        stride: per-level reduction factor of the grid resolution.

    Returns:
        (clusters, edge_indexes, selections_list). `clusters` holds `depth`
        assignments; the other two lists hold `depth + 1` levels, finest first.
    """
    def _snapshot(tensor):
        # Independent copy on the target device, so later levels never alias it
        return torch.clone(tensor).to(device)

    clusters = []
    edge_indexes = [_snapshot(edge_index)]
    selections_list = [_snapshot(selections)]

    for _level in range(depth):
        Nx, Ny = Nx // stride, Ny // stride
        cx, cy = getGrid(pos2D, Nx, Ny)
        cluster, pos2D = gridCluster(pos2D, cx, cy, Nx)
        edge_index, selections = selectionAverage(cluster, edge_index, selections)

        clusters.append(_snapshot(cluster))
        edge_indexes.append(_snapshot(edge_index))
        selections_list.append(_snapshot(selections))

    return clusters, edge_indexes, selections_list
        
def makeSphereClusters(pos3D,edge_index,selections,interps,rows,cols,cluster_method="layering",stride=2,bary_d=None,depth=1,device='cpu',ratio=None):
    """Build a `depth`-level pooling hierarchy for a graph sampled on a sphere.

    At each level the sampling resolution (rows, cols) is divided by `stride`,
    a new set of centroids is produced by `cluster_method`, every current node
    is assigned to its nearest centroid, and the surface graph is rebuilt on
    the mean positions of the resulting clusters.

    Args:
        pos3D: node positions, one row per node (presumably points on a sphere
            centred at the origin -- the normals below assume this).
        edge_index, selections, interps: graph of the finest level; cloned as
            level 0 of the returned lists.
        rows, cols: sampling resolution of the finest level; divided by
            `stride` at every level.
        cluster_method: centroid generator, one of "equirec", "layering",
            "spiral", "icosphere", "random", "random_nodes", "fps".
        stride: per-level reduction factor for rows/cols; `bary_d` is
            multiplied by it at every level.
        bary_d: optional value forwarded to gh.edges2Selections.
        depth: number of coarser levels to build.
        device: device every returned tensor is copied to.
        ratio: fraction of the current nodes kept per level by the
            node-sampling methods "random_nodes" and "fps". Defaults to
            1/stride**2, roughly mirroring the rows//stride, cols//stride
            reduction used by the grid-based methods.

    Returns:
        (clusters, edge_indexes, selections_list, interps_list). `clusters`
        has `depth` entries; the other lists have `depth + 1`, finest first.

    Raises:
        ValueError: if `cluster_method` is not recognised.
    """
    # Previously "random_nodes"/"fps" referenced undefined N/ratio names.
    if ratio is None:
        ratio = 1.0 / (stride ** 2)

    clusters = []
    edge_indexes = [torch.clone(edge_index).to(device)]
    selections_list = [torch.clone(selections).to(device)]
    interps_list = [torch.clone(interps).to(device)]

    for _ in range(depth):

        rows = rows//stride
        cols = cols//stride

        if bary_d is not None:
            bary_d = bary_d*stride

        if cluster_method == "equirec":
            centroids, _ = sh.sampleSphere_Equirec(rows,cols)

        elif cluster_method == "layering":
            centroids, _ = sh.sampleSphere_Layering(rows)

        elif cluster_method == "spiral":
            centroids, _ = sh.sampleSphere_Spiral(rows,cols)

        elif cluster_method == "icosphere":
            centroids, _ = sh.sampleSphere_Icosphere(rows)

        elif cluster_method == "random":
            centroids, _ = sh.sampleSphere_Random(rows,cols)

        elif cluster_method == "random_nodes":
            # Desired number of clusters in the next level (at least one, so
            # multinomial is never asked for zero samples)
            N = max(1, int(len(pos3D) * ratio))
            index = torch.multinomial(torch.ones(len(pos3D)),N) # close equivalent to np.random.choice
            centroids = pos3D[index]

        elif cluster_method == "fps":
            # Farthest Point Search used in PointNet++
            index = fps(pos3D, ratio=ratio)
            centroids = pos3D[index]
        else:
            raise ValueError("Sphere cluster_method unknown")

        # Find closest centroid to each current point; centroids that attract
        # no point are dropped by consecutive_cluster
        cluster = knn(centroids,pos3D,1)[1]
        cluster, _ = consecutive_cluster(cluster)
        pos3D = scatter(pos3D, cluster, dim=0, reduce='mean')

        # Regenerate surface graph. Normals are the normalized positions; the
        # eps-guarded normalize avoids NaN should a mean position be ~zero.
        normals = torch.nn.functional.normalize(pos3D, dim=1)
        edge_index,directions = gh.surface2Edges(pos3D,normals)
        edge_index,selections,interps = gh.edges2Selections(edge_index,directions,interpolated=True,bary_d=bary_d)

        clusters.append(torch.clone(cluster).to(device))
        edge_indexes.append(torch.clone(edge_index).to(device))
        selections_list.append(torch.clone(selections).to(device))
        interps_list.append(torch.clone(interps).to(device))

    return clusters, edge_indexes, selections_list, interps_list

def makeSurfaceClusters(pos3D,normals,edge_index,selections,interps,cluster_method="random",ratio=.25,up_vector=None,depth=1,device='cpu'):
    """Build a `depth`-level pooling hierarchy for an oriented point cloud.

    At each level roughly `ratio` of the current points are chosen as
    centroids, every point is assigned to its nearest centroid, positions and
    normals are averaged per cluster, and the surface graph is rebuilt.

    Args:
        pos3D: node positions, one row per node.
        normals: per-node normals, same row count as `pos3D`.
        edge_index, selections, interps: graph of the finest level; cloned as
            level 0 of the returned lists.
        cluster_method: "random" (uniform sampling without replacement) or
            "fps" (farthest point sampling).
        ratio: fraction of the current points kept as centroids per level.
        up_vector: forwarded to gh.surface2Edges.
        depth: number of coarser levels to build.
        device: device every returned tensor is copied to.

    Returns:
        (clusters, edge_indexes, selections_list, interps_list). `clusters`
        has `depth` entries; the other lists have `depth + 1`, finest first.

    Raises:
        ValueError: if `cluster_method` is not recognised.
    """
    clusters = []
    edge_indexes = [torch.clone(edge_index).to(device)]
    selections_list = [torch.clone(selections).to(device)]
    interps_list = [torch.clone(interps).to(device)]

    for _ in range(depth):

        # Desired number of clusters in the next level (at least one, so
        # multinomial is never asked for zero samples)
        N = max(1, int(len(pos3D) * ratio))

        if cluster_method == "random":
            index = torch.multinomial(torch.ones(len(pos3D)),N) # close equivalent to np.random.choice
            centroids = pos3D[index]

        elif cluster_method == "fps":
            # Farthest Point Search used in PointNet++
            index = fps(pos3D, ratio=ratio)
            centroids = pos3D[index]

        else:
            # Without this, `centroids` would be unbound below
            raise ValueError("Surface cluster_method unknown")

        # Find closest centroid to each current point
        cluster = knn(centroids,pos3D,1)[1]
        cluster, _ = consecutive_cluster(cluster)
        pos3D = scatter(pos3D, cluster, dim=0, reduce='mean')
        normals = scatter(normals, cluster, dim=0, reduce='mean')

        # Regenerate surface graph. Averaged normals can cancel out, so use the
        # eps-guarded normalize instead of a raw division (which gives NaN).
        normals = torch.nn.functional.normalize(normals, dim=1)
        edge_index,directions = gh.surface2Edges(pos3D,normals,up_vector,k_neighbors=16)
        edge_index,selections,interps = gh.edges2Selections(edge_index,directions,interpolated=True)

        clusters.append(torch.clone(cluster).to(device))
        edge_indexes.append(torch.clone(edge_index).to(device))
        selections_list.append(torch.clone(selections).to(device))
        interps_list.append(torch.clone(interps).to(device))

    return clusters, edge_indexes, selections_list, interps_list


def makeMeshClusters(pos3D,mesh,edge_index,selections,interps,ratio=.25,up_vector=None,depth=1,device='cpu'):
    """Build a `depth`-level pooling hierarchy using points re-sampled from a mesh.

    At each level about `ratio` of the current node count is sampled from
    `mesh` via mh.sampleSurface, every current node is assigned to its nearest
    sample, positions are averaged per cluster, and the surface graph is
    rebuilt using the samples' normals.

    Args:
        pos3D: node positions, one row per node.
        mesh: mesh object forwarded to mh.sampleSurface.
        edge_index, selections, interps: graph of the finest level; cloned as
            level 0 of the returned lists.
        ratio: fraction of the current node count to sample per level.
        up_vector: forwarded to gh.surface2Edges.
        depth: number of coarser levels to build.
        device: device every returned tensor is copied to.

    Returns:
        (clusters, edge_indexes, selections_list, interps_list). `clusters`
        has `depth` entries; the other lists have `depth + 1`, finest first.
    """
    clusters = []
    edge_indexes = [torch.clone(edge_index).to(device)]
    selections_list = [torch.clone(selections).to(device)]
    interps_list = [torch.clone(interps).to(device)]

    for _ in range(depth):

        # Desired number of clusters in the next level (at least one)
        N = max(1, int(len(pos3D) * ratio))

        # Generate new point cloud from downsampled version of texture map.
        # NOTE(review): assumes sampleSurface returns torch tensors with one
        # row per sample for both points and normals -- confirm in mesh_helpers.
        centroids, normals = mh.sampleSurface(mesh,N,return_x=False)

        # Find closest sample to each current point
        nearest = knn(centroids,pos3D,1)[1]
        cluster, perm = consecutive_cluster(nearest)
        pos3D = scatter(pos3D, cluster, dim=0, reduce='mean')

        # Samples that attracted no point are dropped by consecutive_cluster,
        # so pos3D can have fewer rows than `normals`. Coarse node j came from
        # sample nearest[perm[j]]; re-index to keep normals row-aligned.
        normals = normals[nearest[perm]]

        # Regenerate surface graph (sampled normals are used as-is)
        edge_index,directions = gh.surface2Edges(pos3D,normals,up_vector)
        edge_index,selections,interps = gh.edges2Selections(edge_index,directions,interpolated=True)

        clusters.append(torch.clone(cluster).to(device))
        edge_indexes.append(torch.clone(edge_index).to(device))
        selections_list.append(torch.clone(selections).to(device))
        interps_list.append(torch.clone(interps).to(device))

    return clusters, edge_indexes, selections_list, interps_list
    
    
    
def getGrid(pos,Nx,Ny,xrange=None,yrange=None):
    """Assign each 2D point to a cell of an Nx-by-Ny regular grid.

    Args:
        pos: tensor whose first two columns are the x and y coordinates.
        Nx, Ny: number of grid cells along x and y.
        xrange, yrange: optional (min, max) bounds per axis; the data's own
            min/max is used when omitted.

    Returns:
        (cx, cy): per-point cell indices along x and y, clamped to
        [0, Nx-1] and [0, Ny-1]. They are floating-point tensors (produced by
        floor of a true division); callers cast as needed.

    A zero-width range (all coordinates equal, or min == max supplied) would
    divide 0/0 and produce NaN indices; it is treated as unit width instead,
    so every point still lands in a valid cell.
    """
    def _cells(coords, n, bounds):
        # Bin one coordinate axis into n equal-width cells
        lo = torch.min(coords) if bounds is None else bounds[0]
        hi = torch.max(coords) if bounds is None else bounds[1]
        span = hi - lo
        if span == 0:
            span = 1
        return torch.clamp(torch.floor((coords - lo) / span * n), 0, n - 1)

    return _cells(pos[:,0], Nx, xrange), _cells(pos[:,1], Ny, yrange)
    
def gridCluster(pos,cx,cy,xmax):
    """Group points that share a grid cell and average their positions.

    Args:
        pos: per-point positions; rows are pooled per cluster.
        cx, cy: per-point cell indices along x and y (e.g. from `getGrid`).
        xmax: number of cells along x, used to flatten (cx, cy) row-major.

    Returns:
        (cluster, pos): a consecutive cluster id for every input point, and
        the mean position of each cluster.
    """
    flat_cell = (cx + cy * xmax).long()  # integer row-major cell id
    cluster = consecutive_cluster(flat_cell)[0]
    pooled = scatter(pos, cluster, dim=0, reduce='mean')
    return cluster, pooled

def _rotateSelections(selections):
    """Shift labels 1..8 two steps around their ring (7->1, 8->2, others +2)."""
    shifted = selections + 2
    return shifted % 9 + torch.div(shifted, 9, rounding_mode="floor")


def _roundedMeanSelections(edge_index, selections, num_nodes):
    """Merge parallel edges, averaging their labels and rounding to the nearest integer."""
    merged_index, mean = coalesce(edge_index, selections.type(torch.float), num_nodes, num_nodes, op="mean")
    return merged_index, torch.round(mean).type(torch.long)


def selectionAverage(cluster, edge_index, selections):
    """Pool a selection-labelled edge list onto a coarser graph.

    Every edge endpoint is mapped through `cluster` to its coarse node id,
    coarse self-loops are removed, and parallel coarse edges are merged; the
    merged edge's label is derived from the mean of the labels it absorbed.

    Labels 1..8 form a ring (see `_rotateSelections`), so a plain mean is
    meaningless across the 8/1 seam. Averaging is therefore done in four
    passes, each rotating the ring by two so a different pair of neighbours
    sits at 4/5, mid-ring. A merged edge whose rounded mean lands on 4 or 5 in
    a pass receives the original label that pass rotated there
    (pass 0: 4/5, pass 1: 2/3, pass 2: 8/1, pass 3: 6/7). Later passes
    overwrite earlier ones; edges never landing on 4 or 5 keep label 0.

    Finally every coarse node gets a self-loop labelled 0 (if missing).

    Args:
        cluster: long tensor mapping each fine node to its coarse node id
            (presumably consecutive, as produced by `consecutive_cluster`).
        edge_index: (2, E) long tensor of fine edges.
        selections: (E,) tensor of per-edge labels (assumed 1..8 once
            self-loops are removed -- TODO confirm against gh.edges2Selections).

    Returns:
        (edge_index, selections) of the coarse graph.
    """
    # Upper bound on coarse ids, used as coalesce's matrix size
    num_nodes = cluster.size(0)
    # Exact coarse node count, passed to add_remaining_self_loops so trailing
    # isolated coarse nodes (or all of them, when no edges survive) still get
    # a self-loop; otherwise it would infer the count from the edges alone.
    num_clusters = int(cluster.max()) + 1 if cluster.numel() > 0 else 0
    loop_label = torch.tensor(0, dtype=torch.long)

    edge_index = cluster[edge_index.contiguous().view(1, -1)].view(2, -1)
    edge_index, selections = remove_self_loops(edge_index, selections)

    if edge_index.numel() == 0:
        warn("Edge Pool found no edges")
        edge_index, selections = add_remaining_self_loops(edge_index, selections, fill_value=loop_label, num_nodes=num_clusters)
        return edge_index, selections

    # (label for a rounded mean of 4, label for a rounded mean of 5) per pass
    pass_targets = ((4, 5), (2, 3), (8, 1), (6, 7))
    final_edge_index = final_selections = None
    for step, (label_at_4, label_at_5) in enumerate(pass_targets):
        if step > 0:
            selections = _rotateSelections(selections)
        merged_index, rounded = _roundedMeanSelections(edge_index, selections, num_nodes)
        if final_selections is None:
            # Every pass coalesces the same edge_index, so the merged edge
            # order is expected to match across passes.
            final_edge_index = merged_index
            final_selections = torch.zeros_like(rounded)
        final_selections[rounded == 4] = label_at_4
        final_selections[rounded == 5] = label_at_5

    edge_index, selections = add_remaining_self_loops(final_edge_index, final_selections, fill_value=loop_label, num_nodes=num_clusters)
    return edge_index, selections