| import torch |
| import torch.nn as nn |
| import numpy as np |
|
|
|
|
class GraphUnet(nn.Module):
    """Graph U-Net: encoder/decoder GCN stack with top-k pooling and unpooling.

    The encoder alternates a GCN with a Pool layer (one pair per entry in
    `ks`), the decoder mirrors it with Unpool + GCN, and pooled-level
    features are added back as skip connections.  NOTE(review): `in_dim`
    and `out_dim` are accepted but unused here — presumably consumed by a
    wrapping model; kept for interface compatibility.
    """

    def __init__(self, ks, in_dim, out_dim, dim, act, drop_p):
        super(GraphUnet, self).__init__()
        self.ks = ks
        self.l_n = len(ks)
        self.bottom_gcn = GCN(dim, dim, act, drop_p)
        self.down_gcns = nn.ModuleList()
        self.up_gcns = nn.ModuleList()
        self.pools = nn.ModuleList()
        self.unpools = nn.ModuleList()
        # One (GCN, Pool) / (Unpool, GCN) pair per pooling ratio in `ks`.
        for k in ks:
            self.down_gcns.append(GCN(dim, dim, act, drop_p))
            self.up_gcns.append(GCN(dim, dim, act, drop_p))
            self.pools.append(Pool(k, dim, drop_p))
            self.unpools.append(Unpool(dim, dim, drop_p))

    def forward(self, g, h):
        """Run the U-shaped pass; returns the list of decoder outputs plus
        the final features with the input residual added."""
        adj_stack, idx_stack, skip_feats = [], [], []
        hs = []
        h_input = h
        # --- encoder: convolve, remember state, then pool ---
        for down_gcn, pool in zip(self.down_gcns, self.pools):
            h = down_gcn(g, h)
            adj_stack.append(g)
            skip_feats.append(h)
            g, h, idx = pool(g, h)
            idx_stack.append(idx)
        h = self.bottom_gcn(g, h)
        # --- decoder: unpool back through the levels in reverse order ---
        for depth, (unpool, up_gcn) in enumerate(zip(self.unpools, self.up_gcns)):
            level = self.l_n - depth - 1
            g = adj_stack[level]
            g, h = unpool(g, h, skip_feats[level], idx_stack[level])
            h = up_gcn(g, h)
            # Skip connection from the matching encoder level.
            h = h.add(skip_feats[level])
            hs.append(h)
        # Residual connection from the raw input features.
        h = h.add(h_input)
        hs.append(h)
        return hs
|
|
|
|
class GCN(nn.Module):
    """One graph-convolution layer: dropout -> neighborhood aggregation
    (matmul with the adjacency) -> linear projection -> activation."""

    def __init__(self, in_dim, out_dim, act, p):
        super(GCN, self).__init__()
        self.proj = nn.Linear(in_dim, out_dim)
        self.act = act
        # Identity when p == 0 keeps forward() branch-free.
        self.drop = nn.Identity() if p <= 0.0 else nn.Dropout(p=p)

    def forward(self, g, h):
        # Aggregate each node's (normalized) neighborhood, then project and activate.
        aggregated = torch.matmul(g, self.drop(h))
        return self.act(self.proj(aggregated))
|
|
|
|
class Pool(nn.Module):
    """Top-k graph pooling: scores each node with a learned linear gate and
    keeps the highest-scoring fraction `k` of nodes (see top_k_graph)."""

    def __init__(self, k, in_dim, p):
        super(Pool, self).__init__()
        self.k = k  # fraction of nodes to retain (0 < k <= 1)
        self.sigmoid = nn.Sigmoid()
        self.proj = nn.Linear(in_dim, 1)  # scalar score per node
        self.drop = nn.Dropout(p=p) if p > 0 else nn.Identity()

    def forward(self, g, h):
        Z = self.drop(h)
        # Squeeze ONLY the trailing feature dim: a bare .squeeze() collapses
        # a single-node graph's (1, 1) score to a 0-d tensor (breaking
        # torch.topk downstream) and would also drop any other singleton dim.
        weights = self.proj(Z).squeeze(-1)
        # Sigmoid gates the kept features in top_k_graph, keeping the
        # node selection differentiable.
        scores = self.sigmoid(weights)
        return top_k_graph(scores, g, h, self.k)
|
|
|
|
class Unpool(nn.Module):
    """Inverse of Pool: scatter pooled node features back to the pre-pool
    graph size, leaving dropped nodes as zero rows."""

    def __init__(self, *args):
        # Stateless; constructor args are accepted only for interface
        # symmetry with the other layers and ignored.
        super(Unpool, self).__init__()

    def forward(self, g, h, pre_h, idx):
        # `pre_h` is unused here — the caller adds the skip features itself.
        restored = h.new_zeros([g.shape[0], h.shape[1]])
        restored[idx] = h
        return g, restored
|
|
|
|
def top_k_graph(scores, g, h, k):
    """Keep the top-scoring fraction `k` of nodes and return the reduced
    (normalized) adjacency, gated features, and the kept node indices.

    Args:
        scores: per-node selection scores in (0, 1), shape (N,).
        g: dense adjacency, shape (N, N).
        h: node features, shape (N, F).
        k: fraction of nodes to retain.
    """
    num_nodes = g.shape[0]
    # Keep at least 2 nodes, but never more than the graph has: the bare
    # max(2, ...) could request more nodes than exist (e.g. a 1-node graph)
    # and crash torch.topk.
    n_keep = min(num_nodes, max(2, int(k * num_nodes)))
    values, idx = torch.topk(scores, n_keep)
    new_h = h[idx, :]
    values = torch.unsqueeze(values, -1)
    # Multiply kept features by their scores so selection stays differentiable.
    new_h = torch.mul(new_h, values)
    # Use 2-hop connectivity (boolean A @ A) before sub-sampling so kept
    # nodes whose only path ran through a dropped node stay connected.
    un_g = g.bool().float()
    un_g = torch.matmul(un_g, un_g).bool().float()
    un_g = un_g[idx, :]
    un_g = un_g[:, idx]
    g = norm_g(un_g)
    return g, new_h, idx
|
|
|
|
def norm_g(g):
    """Row-normalize an adjacency matrix so matmul(g, h) averages neighbor
    features (each row sums to 1, or stays all-zero for isolated nodes).

    Fixes two defects in the naive `g / torch.sum(g, 1)`:
    * broadcasting: without keepdim the (N,) degree vector divided the
      COLUMNS of g (g[i, j] / deg[j]) even though `degrees` are row sums;
    * isolated nodes: a zero-degree row produced NaNs via 0/0.
    """
    degrees = torch.sum(g, 1, keepdim=True)
    # clamp keeps zero rows at exactly zero instead of NaN.
    return g / degrees.clamp(min=1e-12)
|
|
|
|
| class Initializer(object): |
|
|
| @classmethod |
| def _glorot_uniform(cls, w): |
| if len(w.size()) == 2: |
| fan_in, fan_out = w.size() |
| elif len(w.size()) == 3: |
| fan_in = w.size()[1] * w.size()[2] |
| fan_out = w.size()[0] * w.size()[2] |
| else: |
| fan_in = np.prod(w.size()) |
| fan_out = np.prod(w.size()) |
| limit = np.sqrt(6.0 / (fan_in + fan_out)) |
| w.uniform_(-limit, limit) |
|
|
| @classmethod |
| def _param_init(cls, m): |
| if isinstance(m, nn.parameter.Parameter): |
| cls._glorot_uniform(m.data) |
| elif isinstance(m, nn.Linear): |
| m.bias.data.zero_() |
| cls._glorot_uniform(m.weight.data) |
|
|
| @classmethod |
| def weights_init(cls, m): |
| for p in m.modules(): |
| if isinstance(p, nn.ParameterList): |
| for pp in p: |
| cls._param_init(pp) |
| else: |
| cls._param_init(p) |
|
|
| for name, p in m.named_parameters(): |
| if '.' not in name: |
| cls._param_init(p) |
|
|