# clique/GraphUNets/src/utils/ops.py
# qingy2024's picture
# Upload folder using huggingface_hub
# bf620c6 verified
import torch
import torch.nn as nn
import numpy as np
class GraphUnet(nn.Module):
    """Encoder-decoder over a graph assembled from GCN, Pool and Unpool stages.

    One down-sampling stage (GCN + Pool) and one up-sampling stage
    (Unpool + GCN) are created per entry of ``ks``, plus a single GCN that
    runs at the coarsest level.

    Args:
        ks: Sequence of pooling settings; entry ``i`` is handed to the
            ``i``-th ``Pool``.
        in_dim: Accepted for interface compatibility; not used by this class.
        out_dim: Accepted for interface compatibility; not used by this class.
        dim: Feature width given to every GCN and Pool stage.
        act: Activation callable forwarded to every GCN.
        drop_p: Dropout probability forwarded to every GCN and Pool.
    """

    def __init__(self, ks, in_dim, out_dim, dim, act, drop_p):
        super().__init__()
        self.ks = ks
        self.bottom_gcn = GCN(dim, dim, act, drop_p)
        self.down_gcns = nn.ModuleList()
        self.up_gcns = nn.ModuleList()
        self.pools = nn.ModuleList()
        self.unpools = nn.ModuleList()
        self.l_n = len(ks)
        # The four modules of a level are built together, level by level.
        for ratio in ks:
            self.down_gcns.append(GCN(dim, dim, act, drop_p))
            self.up_gcns.append(GCN(dim, dim, act, drop_p))
            self.pools.append(Pool(ratio, dim, drop_p))
            self.unpools.append(Unpool(dim, dim, drop_p))

    def forward(self, g, h):
        """Run the down path, the bottom GCN and then the up path.

        Args:
            g: Graph matrix handed to every GCN / Pool stage.
            h: Node feature matrix.

        Returns:
            A list holding the output of each up-sampling stage (coarsest
            level first), followed by the final stage's output with the
            original input ``h`` added to it.
        """
        input_h = h
        skip_graphs, skip_feats, kept_indices = [], [], []

        # Down path: remember graph, features and kept indices of each level.
        for down_gcn, pool in zip(self.down_gcns, self.pools):
            h = down_gcn(g, h)
            skip_graphs.append(g)
            skip_feats.append(h)
            g, h, kept = pool(g, h)
            kept_indices.append(kept)

        h = self.bottom_gcn(g, h)

        # Up path: stage i consumes the records of level (l_n - 1 - i).
        outputs = []
        up_levels = zip(
            self.unpools,
            self.up_gcns,
            reversed(skip_graphs),
            reversed(skip_feats),
            reversed(kept_indices),
        )
        for unpool, up_gcn, g, skip_h, kept in up_levels:
            g, h = unpool(g, h, skip_h, kept)
            h = up_gcn(g, h)
            h = h.add(skip_h)
            outputs.append(h)

        outputs.append(h.add(input_h))
        return outputs
class GCN(nn.Module):
    """Graph-convolution layer: dropout, propagate over ``g``, project, activate.

    Args:
        in_dim: Input feature width of the linear projection.
        out_dim: Output feature width of the linear projection.
        act: Callable applied to the projected features.
        p: Dropout probability; dropout is replaced by an identity when
            ``p`` is not greater than zero.
    """

    def __init__(self, in_dim, out_dim, act, p):
        super().__init__()
        self.proj = nn.Linear(in_dim, out_dim)
        self.act = act
        if p > 0.0:
            self.drop = nn.Dropout(p=p)
        else:
            self.drop = nn.Identity()

    def forward(self, g, h):
        """Return ``act(proj(g @ drop(h)))``."""
        propagated = torch.matmul(g, self.drop(h))
        return self.act(self.proj(propagated))
class Pool(nn.Module):
    """Score nodes with a learned projection and keep the top-scoring ones.

    Args:
        k: Pooling setting forwarded unchanged to ``top_k_graph``.
        in_dim: Width of the node features being scored.
        p: Dropout probability applied before scoring; an identity is used
            when ``p`` is not greater than zero.
    """

    def __init__(self, k, in_dim, p):
        super().__init__()
        self.k = k
        self.sigmoid = nn.Sigmoid()
        self.proj = nn.Linear(in_dim, 1)
        if p > 0:
            self.drop = nn.Dropout(p=p)
        else:
            self.drop = nn.Identity()

    def forward(self, g, h):
        """Compute sigmoid scores for ``h`` and delegate to ``top_k_graph``.

        Dropout only affects the scores; the features handed on to
        ``top_k_graph`` are the original ``h``.
        """
        logits = self.proj(self.drop(h)).squeeze()
        node_scores = self.sigmoid(logits)
        return top_k_graph(node_scores, g, h, self.k)
class Unpool(nn.Module):
    """Scatter pooled node features back into a zero-filled full-size matrix.

    The constructor accepts and ignores any positional arguments; the module
    has no parameters.
    """

    def __init__(self, *args):
        super().__init__()

    def forward(self, g, h, pre_h, idx):
        """Place the rows of ``h`` at positions ``idx`` of a zero matrix.

        Args:
            g: Graph matrix of the finer level; only ``g.shape[0]`` is read
                and ``g`` itself is returned untouched.
            h: Features of the kept nodes.
            pre_h: Unused.
            idx: Row positions in the output that receive the rows of ``h``.

        Returns:
            Tuple ``(g, restored)`` where ``restored`` has ``g.shape[0]``
            rows and ``h.shape[1]`` columns.
        """
        num_nodes, feat_dim = g.shape[0], h.shape[1]
        restored = h.new_zeros((num_nodes, feat_dim))
        restored[idx] = h
        return g, restored
def top_k_graph(scores, g, h, k):
    """Keep the highest-scoring nodes and build the pooled graph for them.

    Args:
        scores: Per-node scores; assumed to be 1-D with one entry per row
            of ``g`` (TODO confirm against callers).
        g: Graph matrix; only its non-zero pattern is used.
        h: Node feature matrix, indexed by node along the first axis.
        k: Fraction of nodes to keep. The kept count is
            ``int(k * num_nodes)``, raised to at least 2 and capped at
            ``num_nodes``.

    Returns:
        Tuple ``(new_g, new_h, idx)``: the normalised graph restricted to the
        kept nodes, the kept features scaled by their scores, and the indices
        of the kept nodes.
    """
    num_nodes = g.shape[0]
    # Never ask topk for more entries than there are nodes: the floor of 2
    # (or a ratio above 1) would otherwise make torch.topk raise.
    num_keep = min(num_nodes, max(2, int(k * num_nodes)))
    values, idx = torch.topk(scores, num_keep)

    # Gate the kept features by their scores.
    new_h = torch.mul(h[idx, :], torch.unsqueeze(values, -1))

    # Binarise the graph, then multiply it with itself and binarise again so
    # that nodes reachable in two steps count as connected.
    un_g = g.bool().float()
    un_g = torch.matmul(un_g, un_g).bool().float()

    # Restrict rows and columns to the kept nodes.
    un_g = un_g[idx, :]
    un_g = un_g[:, idx]

    g = norm_g(un_g)
    return g, new_h, idx
def norm_g(g):
    """Normalise a graph matrix by its per-node degree.

    Degrees are the sums over axis 1. For a square 2-D ``g`` the degree
    vector broadcasts along the last axis, so entry ``(i, j)`` is divided by
    ``degrees[j]``.

    Nodes whose degree is zero are divided by 1 instead, which keeps their
    entries finite rather than producing NaN/inf.

    Args:
        g: Graph matrix.

    Returns:
        The degree-normalised matrix.
    """
    degrees = torch.sum(g, 1)
    # Guard isolated nodes: dividing by a zero degree would inject NaN/inf.
    safe_degrees = degrees.masked_fill(degrees == 0, 1)
    return g / safe_degrees
class Initializer(object):
    """Glorot-uniform initialisation helpers for modules and parameters."""

    @classmethod
    def _glorot_uniform(cls, w):
        """Fill ``w`` in place from U(-limit, limit).

        ``limit = sqrt(6 / (fan_in + fan_out))``. For 2-D tensors the two
        fans are the two sizes; for 3-D tensors each fan is one of the first
        two sizes multiplied by the last size; for any other rank both fans
        equal the total number of elements.
        """
        shape = w.size()
        if len(shape) == 2:
            fan_in, fan_out = shape
        elif len(shape) == 3:
            fan_in = shape[1] * shape[2]
            fan_out = shape[0] * shape[2]
        else:
            fan_in = fan_out = np.prod(shape)
        limit = float(np.sqrt(6.0 / (fan_in + fan_out)))
        w.uniform_(-limit, limit)

    @classmethod
    def _param_init(cls, m):
        """Initialise a bare parameter or an ``nn.Linear``; ignore anything else.

        Linear layers get a zero bias (when they have one) and Glorot-uniform
        weights.
        """
        if isinstance(m, nn.parameter.Parameter):
            cls._glorot_uniform(m.data)
        elif isinstance(m, nn.Linear):
            # nn.Linear(..., bias=False) stores bias as None.
            if m.bias is not None:
                m.bias.data.zero_()
            cls._glorot_uniform(m.weight.data)

    @classmethod
    def weights_init(cls, m):
        """Initialise every supported sub-module and direct parameter of ``m``.

        Walks ``m.modules()``, initialising each entry of any
        ``nn.ParameterList`` and each ``nn.Linear``. It then initialises the
        parameters registered directly on ``m``, i.e. those whose name
        contains no dot.
        """
        for module in m.modules():
            if isinstance(module, nn.ParameterList):
                for param in module:
                    cls._param_init(param)
            else:
                cls._param_init(module)

        for name, param in m.named_parameters():
            if '.' not in name:
                cls._param_init(param)