| import torch |
| import torch.nn as nn |
| import torch.nn.parallel |
| import torch.utils.data |
| import numpy as np |
| import torch.nn.functional as F |
| from torch.nn import Parameter |
|
|
| from torch_geometric.nn.dense.linear import Linear |
| from torch_geometric.nn.conv import MessagePassing |
| from torch_geometric.utils import softmax |
| |
| from torch_geometric.nn.inits import glorot, zeros |
|
|
| from torch_scatter import scatter |
| from utils.utils import triplets,get_angle,GaussianSmearing |
| from torch.nn import ModuleList |
| from math import pi as PI |
| import math |
|
|
| """ |
| The theory based Grid cell spatial relation encoder, |
| See https://openreview.net/forum?id=Syx0Mh05YQ |
| Learning Grid Cells as Vector Representation of Self-Position Coupled with Matrix Representation of Self-Motion |
| """ |
| def _cal_freq_list(freq_init, frequency_num, max_radius, min_radius): |
| if freq_init == "random": |
| |
| |
| freq_list = np.random.random(size=[frequency_num]) * max_radius |
| elif freq_init == "geometric": |
| |
| |
| |
| |
|
|
| |
|
|
| log_timescale_increment = (math.log(float(max_radius) / float(min_radius)) / |
| (frequency_num*1.0 - 1)) |
|
|
| timescales = min_radius * np.exp( |
| np.arange(frequency_num).astype(float) * log_timescale_increment) |
|
|
| freq_list = 1.0/timescales |
|
|
| return freq_list |
class TheoryGridCellSpatialRelationEncoder(nn.Module):
    """Theory-based grid-cell positional encoder for (deltaX, deltaY) offsets.

    Each offset is projected onto three unit vectors 120 degrees apart and
    expanded with sin/cos at multiple frequencies.
    """

    def __init__(self, spa_embed_dim, coord_dim=2, frequency_num=16,
                 max_radius=10000, min_radius=1000, freq_init="geometric", ffn=None):
        """
        Args:
            spa_embed_dim: the output spatial relation embedding dimension
            coord_dim: the dimension of space, 2D, 3D, or other
            frequency_num: number of sinusoids with distinct frequencies/wavelengths
            max_radius: the largest context radius this model can handle
            min_radius: the smallest context radius (geometric init)
            freq_init: "geometric" or "random" frequency initialization
            ffn: optional feed-forward head applied to the raw embedding
        """
        super(TheoryGridCellSpatialRelationEncoder, self).__init__()
        self.frequency_num = frequency_num
        self.coord_dim = coord_dim
        self.max_radius = max_radius
        self.min_radius = min_radius
        self.spa_embed_dim = spa_embed_dim
        self.freq_init = freq_init

        self.cal_freq_list()
        self.cal_freq_mat()

        # Three unit vectors at 120-degree angles (hexagonal grid directions).
        self.unit_vec1 = np.asarray([1.0, 0.0])
        self.unit_vec2 = np.asarray([-1.0 / 2.0, math.sqrt(3) / 2.0])
        self.unit_vec3 = np.asarray([-1.0 / 2.0, -math.sqrt(3) / 2.0])

        self.input_embed_dim = self.cal_input_dim()
        self.ffn = ffn

    def cal_freq_list(self):
        # Delegate frequency construction to the shared module-level helper.
        self.freq_list = _cal_freq_list(
            self.freq_init, self.frequency_num, self.max_radius, self.min_radius)

    def cal_freq_mat(self):
        # Shape (frequency_num, 6): one column per (unit vector, sin/cos) slot.
        self.freq_mat = np.tile(self.freq_list[:, np.newaxis], (1, 6))

    def cal_input_dim(self):
        # 3 unit vectors x (sin, cos) per frequency.
        return int(6 * self.frequency_num)

    def make_input_embeds(self, coords):
        """Turn coords of shape (batch, num_context_pt, coord_dim) into the raw
        sinusoidal embedding of shape (batch, num_context_pt, 6 * frequency_num)."""
        if type(coords) == np.ndarray:
            assert self.coord_dim == np.shape(coords)[2]
            coords = list(coords)
        elif type(coords) == list:
            assert self.coord_dim == len(coords[0][0])
        elif type(coords) == torch.Tensor:
            assert self.coord_dim == coords.shape[2]
            coords = coords.detach().cpu().numpy()
        else:
            raise Exception("Unknown coords data type for GridCellSpatialRelationEncoder")

        coords_mat = np.asarray(coords).astype(float)
        batch_size, num_context_pt = coords_mat.shape[0], coords_mat.shape[1]

        # Project every offset onto the three grid directions; duplicate each
        # projection so sin/cos pairs interleave after the final reshape.
        projections = [np.expand_dims(np.matmul(coords_mat, unit_vec), axis=-1)
                       for unit_vec in (self.unit_vec1, self.unit_vec2, self.unit_vec3)]
        angle_mat = np.concatenate(
            [dup for proj in projections for dup in (proj, proj)], axis=-1)

        # Expand along a new frequency axis and scale by the per-frequency matrix.
        angle_mat = np.repeat(np.expand_dims(angle_mat, axis=-2),
                              self.frequency_num, axis=-2)
        angle_mat = angle_mat * self.freq_mat
        spr_embeds = np.reshape(angle_mat, (batch_size, num_context_pt, -1))

        # Even slots -> sin, odd slots -> cos.
        spr_embeds[:, :, 0::2] = np.sin(spr_embeds[:, :, 0::2])
        spr_embeds[:, :, 1::2] = np.cos(spr_embeds[:, :, 1::2])
        return spr_embeds

    def forward(self, coords):
        """
        Given a list of coords (deltaX, deltaY), give their spatial relation embedding
        Args:
            coords: a python list with shape (batch_size, num_context_pt, coord_dim)
        Return:
            sprenc: Tensor shape (batch_size, num_context_pt, spa_embed_dim)
        """
        spr_embeds = torch.FloatTensor(self.make_input_embeds(coords))
        if self.ffn is not None:
            return self.ffn(spr_embeds)
        return spr_embeds
# Module-level encoder instance created as an import-time side effect; uses the
# default geometric frequency schedule (16 frequencies, radii 1000..10000) and
# no ffn head. NOTE(review): presumably shared by importers — confirm before
# making construction lazy.
theoryencoder=TheoryGridCellSpatialRelationEncoder(spa_embed_dim=8)
|
|
class GFusion(nn.Module):
    """Multi-source fusion network: per data source, edge geometry is encoded
    by a small MLP, passed through SPNN interaction blocks, and scored by a
    final MLP. Each stage's weights can optionally be shared across sources.

    Args:
        h_channel: hidden channel width used throughout.
        input_featuresize: size of the per-node input feature vector.
        localdepth: depth of the geometry-encoding MLPs.
        num_interactions: number of SPNN interaction blocks per source.
        finaldepth: depth parameter for the final MLP and each SPNN.
        num_of_datasources: number of parallel data sources.
        share: three sharing flags consumed as share[0]/share[1]/share[2]
            (geometry MLPs / interaction blocks / final MLPs), e.g. "101".
            A bool is normalized to "111"/"000".
        batchnorm: "True" to insert BatchNorm1d after every Linear
            (string flag, matching the callers' convention).
    """

    def __init__(self, h_channel=16, input_featuresize=32, localdepth=2,
                 num_interactions=3, finaldepth=3, num_of_datasources=2,
                 share=True, batchnorm="False"):
        super(GFusion, self).__init__()
        # Bug fix: `share` is indexed positionally below; the historical
        # default share=True is not indexable and crashed with TypeError.
        # Normalize bools to the equivalent all-on/all-off flag string.
        if isinstance(share, bool):
            share = "111" if share else "000"
        # NOTE(review): shadows nn.Module's train/eval `training` flag; kept
        # byte-for-byte for compatibility with existing behavior.
        self.training = True
        self.h_channel = h_channel
        self.input_featuresize = input_featuresize
        self.localdepth = localdepth
        self.num_interactions = num_interactions
        self.finaldepth = finaldepth
        self.batchnorm = batchnorm
        self.activation = nn.ReLU()

        # Edge geometry features: inverse distance (1) + smeared angle (12).
        num_gaussians = (1, 12)
        self.theta_expansion = GaussianSmearing(-PI, PI, num_gaussians[1])

        # Geometry-encoding MLPs: one shared instance or one per source.
        self.mlps_list = ModuleList()
        if int(share[0]) == 1:
            mlp_geo = self._build_geo_mlp(sum(num_gaussians))
            for _ in range(num_of_datasources):
                # Appending the same ModuleList shares its parameters.
                self.mlps_list.append(mlp_geo)
        else:
            for _ in range(num_of_datasources):
                self.mlps_list.append(self._build_geo_mlp(sum(num_gaussians)))

        # Fallback MLPs for the raw-coordinate path (edge_rep=False); input is
        # the concatenation of the two 2-D endpoint coordinates (4 features).
        self.mlps_list_backup = ModuleList()
        for _ in range(num_of_datasources):
            self.mlps_list_backup.append(self._build_geo_mlp(4))

        # Last two input columns are dropped, one extra column remains (+1).
        self.translinear = Linear(input_featuresize + 1, self.h_channel)

        # Interaction blocks: one shared stack or one per source.
        self.interactions_list = ModuleList()
        if int(share[1]) == 1:
            interactions = self._build_interactions()
            for _ in range(num_of_datasources):
                self.interactions_list.append(interactions)
        else:
            for _ in range(num_of_datasources):
                self.interactions_list.append(self._build_interactions())

        # Final scoring MLPs: one shared instance or one per source.
        self.finalMLP_list = ModuleList()
        if int(share[2]) == 1:
            finalMLP = self._build_final_mlp()
            for _ in range(num_of_datasources):
                self.finalMLP_list.append(finalMLP)
        else:
            for _ in range(num_of_datasources):
                self.finalMLP_list.append(self._build_final_mlp())

        self.reset_parameters()

    def _build_geo_mlp(self, in_features):
        """One geometry MLP: localdepth x (Linear [+ BatchNorm] + activation)."""
        mlp_geo = ModuleList()
        for i in range(self.localdepth):
            mlp_geo.append(Linear(in_features if i == 0 else self.h_channel,
                                  self.h_channel))
            if self.batchnorm == "True":
                mlp_geo.append(nn.BatchNorm1d(self.h_channel))
            mlp_geo.append(self.activation)
        return mlp_geo

    def _build_interactions(self):
        """num_interactions SPNN blocks for one data source."""
        interactions = ModuleList()
        for _ in range(self.num_interactions):
            interactions.append(SPNN(
                in_ch=self.input_featuresize,
                hidden_channels=self.h_channel,
                activation=self.activation,
                finaldepth=self.finaldepth,
                batchnorm=self.batchnorm,
                num_input_geofeature=self.h_channel,
            ))
        return interactions

    def _build_final_mlp(self):
        """Final scoring MLP: (finaldepth+1) hidden layers, then a 1-dim head."""
        finalMLP = ModuleList()
        for _ in range(self.finaldepth + 1):
            finalMLP.append(Linear(self.h_channel, self.h_channel))
            if self.batchnorm == "True":
                finalMLP.append(nn.BatchNorm1d(self.h_channel))
            finalMLP.append(self.activation)
        finalMLP.append(Linear(self.h_channel, 1))
        return finalMLP

    def reset_parameters(self):
        """Xavier-uniform weights / zero biases for every Linear; delegate to
        each SPNN block. (mlps_list_backup keeps its default init, as before.)"""
        for mlps in self.mlps_list:
            for lin in mlps:
                if isinstance(lin, Linear):
                    torch.nn.init.xavier_uniform_(lin.weight)
                    lin.bias.data.fill_(0)
        for interactions in self.interactions_list:
            for block in interactions:
                block.reset_parameters()
        for finalMLP in self.finalMLP_list:
            for lin in finalMLP:
                if isinstance(lin, Linear):
                    torch.nn.init.xavier_uniform_(lin.weight)
                    lin.bias.data.fill_(0)

    def single_forward(self, coords, edge_index, edge_index_2rd, edx_2nd, batch,
                       input_feature, is_source, edge_rep, datasource_idx):
        """Run one data source through geometry encoding + interaction blocks.

        Args:
            coords: node coordinates (2-D).
            edge_index: (2, num_edges) source/target node indices.
            edge_index_2rd, edx_2nd: triplet indices for angle computation.
            batch: node-to-graph assignment (unused here).
            input_feature: per-node input features; the last two columns are
                dropped before the linear projection.
            is_source: forwarded to each SPNN block.
            edge_rep: True for the geometric (distance+angle) encoding,
                False for the raw-coordinate fallback path.
            datasource_idx: which per-source module stack to use.

        Returns:
            Node features of shape (num_nodes, h_channel).
        """
        if edge_rep:
            # Geometric encoding: inverse pairwise distance plus a signed
            # triplet angle (min-reduced per edge, smeared into Gaussians).
            i, j, k = edge_index_2rd
            dist = (coords[edge_index[0]] - coords[edge_index[1]]).norm(p=2, dim=1)
            theta_ijk = get_angle(coords[j] - coords[i], coords[k] - coords[j])
            # z-component of the (padded) cross product orients the angle.
            v1 = torch.cross(F.pad(coords[j] - coords[i], (0, 1)),
                             F.pad(coords[k] - coords[j], (0, 1)), dim=1)[..., 2]
            flag = torch.sign(v1)
            flag[flag == 0] = -1
            theta = scatter(theta_ijk * flag, edx_2nd, dim=0,
                            dim_size=edge_index.shape[1], reduce='min')
            theta = self.theta_expansion(theta)
            geo_encoding_1st = dist[:, None]
            geo_encoding_1st[geo_encoding_1st == 0] = 1E-10  # avoid div-by-zero
            geo_encoding_1st = torch.pow(geo_encoding_1st, -1)
            geo_encoding = torch.cat([geo_encoding_1st, theta], dim=-1)
            for lin in self.mlps_list[datasource_idx]:
                geo_encoding = lin(geo_encoding)
        else:
            # Fallback: raw endpoint coordinates through the backup MLP.
            coords_j = coords[edge_index[0]]
            coords_i = coords[edge_index[1]]
            geo_encoding = torch.cat([coords_j, coords_i], dim=-1)
            for lin in self.mlps_list_backup[datasource_idx]:
                geo_encoding = lin(geo_encoding)
            # NOTE(review): the encoding is zeroed immediately after being
            # computed, discarding the backup MLP's output — looks like a
            # deliberate ablation, but worth confirming with the authors.
            geo_encoding = torch.zeros_like(geo_encoding, device=geo_encoding.device,
                                            dtype=geo_encoding.dtype)
        node_feature = self.translinear(input_feature[:, :-2])
        for interaction in self.interactions_list[datasource_idx]:
            node_feature = interaction(node_feature, geo_encoding, edge_index, is_source)
        return node_feature

    def forward(self, coords, edge_index, edge_index_2rd, edx_2nd, batch,
                input_feature, is_source, edge_rep):
        """Run every data source and apply its final scoring MLP.

        Each argument (except edge_rep) is a per-source sequence indexed in
        lockstep. Returns a list of per-source outputs of shape (num_nodes, 1).
        """
        outputs = []
        for idx in range(len(coords)):
            out = self.single_forward(coords[idx], edge_index[idx],
                                      edge_index_2rd[idx], edx_2nd[idx],
                                      batch[idx], input_feature[idx],
                                      is_source[idx], edge_rep, idx)
            for lin in self.finalMLP_list[idx]:
                out = lin(out)
            outputs.append(out)
        return outputs
| |
class SPNN(torch.nn.Module):
    """Attention-weighted message-passing block.

    For each edge, an MLP scores the concatenation of [target feature,
    source feature, edge geometry]; the score is turned into per-edge
    attention via a softmax over each target node's incoming edges, and the
    attention-weighted sum of source features is added back onto the input
    node features (residual update).
    """

    def __init__(self, in_ch, hidden_channels, activation=torch.nn.ReLU(),
                 finaldepth=3, batchnorm="False", num_input_geofeature=13):
        """
        Args:
            in_ch: accepted for interface compatibility; not used internally.
            hidden_channels: width of node features and of the edge MLP.
            activation: activation module inserted after every Linear.
            finaldepth: the edge MLP has finaldepth + 1 Linear layers.
            batchnorm: "True" to insert BatchNorm1d after every Linear.
            num_input_geofeature: width of the per-edge geometry encoding.
        """
        super(SPNN, self).__init__()
        self.activation = activation
        self.finaldepth = finaldepth
        self.batchnorm = batchnorm
        self.num_input_geofeature = num_input_geofeature
        # Per-channel attention vector, broadcast against edge scores.
        self.att = Parameter(torch.Tensor(1, hidden_channels), requires_grad=True)

        # Edge-scoring MLP over [target, source, geometry] features.
        self.WMLP = ModuleList()
        for i in range(self.finaldepth + 1):
            if i == 0:
                self.WMLP.append(Linear(hidden_channels * 2 + num_input_geofeature,
                                        hidden_channels))
            else:
                self.WMLP.append(Linear(hidden_channels, hidden_channels))
            if self.batchnorm == "True":
                self.WMLP.append(nn.BatchNorm1d(hidden_channels))
            self.WMLP.append(self.activation)
        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-init every Linear, zero biases, glorot-init the attention."""
        for lin in self.WMLP:
            if isinstance(lin, Linear):
                torch.nn.init.xavier_uniform_(lin.weight)
                lin.bias.data.fill_(0)
        glorot(self.att)

    def forward(self, node_feature, geo_encoding, edge_index, is_source):
        """Compute one residual attention-aggregation step.

        Args:
            node_feature: (num_nodes, hidden_channels) node embeddings.
            geo_encoding: (num_edges, num_input_geofeature) edge geometry.
            edge_index: (2, num_edges) [source, target] node indices.
            is_source: unused; kept for interface compatibility with callers.

        Returns:
            Updated node features with the same shape as node_feature.

        Raises:
            ValueError: if node_feature is None.
        """
        j, i = edge_index  # j = source nodes, i = target nodes (per edge)
        # Bug fix: the original cloned node_feature BEFORE checking it for
        # None, so the None branch was unreachable (AttributeError on
        # .clone() fired first) and the None path was broken downstream
        # anyway. Fail explicitly with a clear message instead.
        if node_feature is None:
            raise ValueError("SPNN.forward requires node_feature, got None")
        input_feature = node_feature.clone()

        # Per-edge score input: target feature, source feature, geometry.
        concatenated_vector = torch.cat(
            [node_feature[i], node_feature[j], geo_encoding], dim=-1)
        x_i = concatenated_vector
        for lin in self.WMLP:
            x_i = lin(x_i)
        x_i = F.leaky_relu(x_i)

        # Scalar attention logit per edge, normalized over each target node's
        # incoming edges.
        alpha = F.leaky_relu(x_i * self.att).sum(dim=-1)
        alpha = softmax(alpha, edge_index[1])

        # Attention-weighted aggregation of source features + residual.
        message = input_feature[edge_index[0]] * alpha.unsqueeze(-1)
        out_feature = scatter(message, edge_index[1], dim=0, reduce='add')
        return input_feature + out_feature
|
|
|
|
|
|