| """ |
| (c) Adaptation of the code from https://github.com/THUDM/GraphMAE |
| """ |
|
|
| from typing import Optional |
| from itertools import chain |
| from functools import partial |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch_geometric.utils import dropout_edge |
| from torch_geometric.utils import add_self_loops |
|
|
| from .acm_gin import ACM_GIN_model |
|
|
|
|
def sce_loss(x, y, alpha=3):
    """Scaled cosine error (GraphMAE reconstruction loss).

    Both inputs are L2-normalized along the last dimension, so the per-row
    error is ``(1 - cos_sim) ** alpha``; ``alpha > 1`` down-weights rows that
    are already well reconstructed.

    Args:
        x: predicted features, shape ``(..., dim)``.
        y: target features, same shape as ``x``.
        alpha: sharpening exponent applied to the cosine error.

    Returns:
        Scalar tensor: mean scaled cosine error over all rows.
    """
    x_unit = F.normalize(x, p=2, dim=-1)
    y_unit = F.normalize(y, p=2, dim=-1)
    cos_sim = (x_unit * y_unit).sum(dim=-1)
    return torch.pow(1.0 - cos_sim, alpha).mean()
|
|
|
|
| def setup_module( |
| m_type, |
| in_dim, |
| out_dim, |
| num_hidden, |
| num_layers, |
| activation, |
| batchnorm, |
| ) -> nn.Module: |
|
|
| if m_type == "acm_gin": |
| mod = ACM_GIN_model( |
| int(in_dim), |
| int(out_dim), |
| num_layers, |
| int(num_hidden), |
| batchnorm, |
| activation=activation, |
| ) |
| else: |
| raise NotImplementedError |
|
|
| return mod |
|
|
|
|
class PreModel(nn.Module):
    """Masked graph autoencoder for self-supervised pre-training (GraphMAE-style).

    Training procedure:
      1. ``encoding_mask_noise`` corrupts a random ``mask_rate`` fraction of the
         nodes: most receive a learnable mask token, a small ``replace_rate``
         fraction is replaced with features copied from random nodes;
      2. the encoder embeds the corrupted graph;
      3. masked nodes are re-masked in latent space (for graph decoders) and
         the decoder reconstructs their original input features;
      4. the loss (``"mse"`` or scaled cosine error ``"sce"``) is evaluated on
         the masked nodes only.

    NOTE(review): several constructor arguments (``edge_in_dim``, ``nhead``,
    ``nhead_out``, ``feat_drop``, ``attn_drop``, ``negative_slope``,
    ``residual``, ``norm``) appear to be kept for interface compatibility with
    other encoder types and are not used by the ``acm_gin`` backbone.
    """

    def __init__(
        self,
        in_dim: int,
        edge_in_dim: int,
        num_hidden: int,
        num_layers: int,
        nhead: int,
        nhead_out: int,
        activation: str,
        feat_drop: float,
        attn_drop: float,
        negative_slope: float,
        residual: bool,
        norm: Optional[str],
        mask_rate: float = 0.3,
        encoder_type: str = "gat",
        decoder_type: str = "gat",
        loss_fn: str = "sce",
        drop_edge_rate: float = 0.0,
        replace_rate: float = 0.1,
        alpha_l: float = 2,
        concat_hidden: bool = False,
        batchnorm=False,
    ):
        super(PreModel, self).__init__()
        self._mask_rate = mask_rate
        self._encoder_type = encoder_type
        self._decoder_type = decoder_type
        self._drop_edge_rate = drop_edge_rate
        self._output_hidden_size = num_hidden
        self._concat_hidden = concat_hidden

        # Among the masked nodes: `replace_rate` get random node features,
        # the remaining `1 - replace_rate` get the learnable mask token.
        self._replace_rate = replace_rate
        self._mask_token_rate = 1 - self._replace_rate

        assert num_hidden % nhead == 0
        assert num_hidden % nhead_out == 0

        enc_num_hidden = num_hidden

        dec_in_dim = num_hidden
        dec_num_hidden = num_hidden

        # Encoder: corrupted node features -> latent representations.
        self.encoder = setup_module(
            m_type=encoder_type,
            in_dim=in_dim,
            out_dim=enc_num_hidden,
            num_hidden=enc_num_hidden,
            num_layers=num_layers,
            activation=activation,
            batchnorm=batchnorm,
        )

        # Decoder: single layer mapping latents back to input feature space.
        self.decoder = setup_module(
            m_type=decoder_type,
            in_dim=dec_in_dim,
            out_dim=in_dim,
            num_hidden=dec_num_hidden,
            num_layers=1,
            activation=activation,
            batchnorm=batchnorm,
        )

        # Learnable [MASK] token added to the zeroed features of masked nodes.
        self.enc_mask_token = nn.Parameter(torch.zeros(1, in_dim))
        if concat_hidden:
            # When hidden states of all layers are concatenated, the
            # projection must accept num_layers * dec_in_dim features.
            self.encoder_to_decoder = nn.Linear(
                dec_in_dim * num_layers, dec_in_dim, bias=False
            )
        else:
            self.encoder_to_decoder = nn.Linear(dec_in_dim, dec_in_dim, bias=False)

        self.criterion = self.setup_loss_fn(loss_fn, alpha_l)

    @property
    def output_hidden_dim(self):
        """Dimensionality of the encoder output (hidden size)."""
        return self._output_hidden_size

    def setup_loss_fn(self, loss_fn, alpha_l):
        """Return the reconstruction criterion: ``"mse"`` or ``"sce"``.

        Raises:
            NotImplementedError: for any other ``loss_fn`` name.
        """
        if loss_fn == "mse":
            criterion = nn.MSELoss()
        elif loss_fn == "sce":
            criterion = partial(sce_loss, alpha=alpha_l)
        else:
            raise NotImplementedError
        return criterion

    def encoding_mask_noise(self, x, mask_rate=0.3, virtual_node_index=None):
        """Corrupt a random subset of node features.

        Args:
            x: ``(num_nodes, in_dim)`` node feature matrix.
            mask_rate: fraction of maskable nodes to corrupt.
            virtual_node_index: optional tensor of node indices that must never
                be masked nor sampled as noise sources (e.g. virtual nodes).

        Returns:
            ``(out_x, (mask_nodes, keep_nodes))`` where ``out_x`` is the
            corrupted copy of ``x``.
        """
        num_nodes = x.shape[0]
        all_indices = torch.arange(num_nodes, device=x.device)

        # Exclude virtual nodes from masking and from noise sampling.
        if virtual_node_index is not None:
            all_indices = all_indices[~torch.isin(all_indices, virtual_node_index)]

        perm = all_indices[torch.randperm(len(all_indices), device=x.device)]

        num_mask_nodes = int(mask_rate * len(perm))
        mask_nodes = perm[:num_mask_nodes]
        keep_nodes = perm[num_mask_nodes:]

        out_x = x.clone()

        num_noise_nodes = (
            int(self._replace_rate * num_mask_nodes) if self._replace_rate > 0 else 0
        )
        if num_noise_nodes > 0:
            # Split the masked nodes: most receive the mask token, a small
            # fraction is replaced with features of random non-virtual nodes.
            perm_mask = torch.randperm(num_mask_nodes, device=x.device)
            token_nodes = mask_nodes[
                perm_mask[: int(self._mask_token_rate * num_mask_nodes)]
            ]
            noise_nodes = mask_nodes[perm_mask[-num_noise_nodes:]]
            noise_to_be_chosen = all_indices[
                torch.randperm(len(perm), device=x.device)[:num_noise_nodes]
            ]

            out_x[token_nodes] = 0.0
            out_x[noise_nodes] = x[noise_to_be_chosen]
        else:
            # BUGFIX: previously, when replace_rate > 0 but
            # int(replace_rate * num_mask_nodes) == 0, the slice `[-0:]`
            # selected the WHOLE permutation, turning every masked node into a
            # noise node. Now all masked nodes receive the mask token instead.
            token_nodes = mask_nodes
            out_x[mask_nodes] = 0.0

        out_x[token_nodes] += self.enc_mask_token

        return out_x, (mask_nodes, keep_nodes)

    def forward(self, batch):
        """Compute the masked-feature reconstruction loss for a PyG batch."""
        virtual_node_index = getattr(batch, "virtual_node_index", None)
        return self.mask_attr_prediction(
            batch.x, batch.edge_index, batch.edge_attr, batch.batch, virtual_node_index
        )

    def mask_attr_prediction(self, x, edge_index, edge_attr, batch, virtual_node_index):
        """Mask node features, encode, decode, and score the reconstruction.

        ``batch`` (the node-to-graph assignment vector) is accepted for
        interface compatibility but not used by the current backbones.
        """
        use_x, (mask_nodes, keep_nodes) = self.encoding_mask_noise(
            x,
            self._mask_rate,
            virtual_node_index,
        )

        if self._drop_edge_rate > 0:
            # dropout_edge returns the kept edges plus a boolean keep-mask;
            # the mask subsets the edge attributes consistently.
            use_edge_index, edge_mask = dropout_edge(
                edge_index, self._drop_edge_rate
            )
            use_edge_attr = edge_attr[edge_mask] if edge_attr is not None else None
            use_edge_index, use_edge_attr = add_self_loops(
                use_edge_index, use_edge_attr, fill_value="min"
            )
        else:
            use_edge_index = edge_index
            use_edge_attr = edge_attr

        enc_rep, all_hidden = self.encoder(
            use_x, use_edge_index, use_edge_attr, return_hidden=True
        )
        if self._concat_hidden:
            enc_rep = torch.cat(all_hidden, dim=1)

        rep = self.encoder_to_decoder(enc_rep)

        # Re-mask the latents of masked nodes so a graph decoder must
        # reconstruct them from their neighbourhood rather than a leaked code.
        if self._decoder_type not in ("mlp", "linear"):
            rep[mask_nodes] = 0

        if self._decoder_type in ("mlp", "linear"):
            recon = self.decoder(rep)
        else:
            recon = self.decoder(rep, use_edge_index, use_edge_attr)

        # Loss is evaluated only on the nodes that were masked.
        x_init = x[mask_nodes]
        x_rec = recon[mask_nodes]

        loss = self.criterion(x_rec, x_init)

        return loss

    def embed(self, x, edge_index, edge_attr, batch):
        """Embed uncorrupted node features (for downstream use).

        ``batch`` is accepted for interface compatibility but unused.
        """
        if self._concat_hidden:
            enc_rep, all_hidden = self.encoder(
                x, edge_index, edge_attr, return_hidden=True
            )
            enc_rep = torch.cat(all_hidden, dim=1)
        else:
            enc_rep = self.encoder(x, edge_index, edge_attr)
        rep = self.encoder_to_decoder(enc_rep)
        return rep

    @property
    def enc_params(self):
        """Iterator over encoder parameters (e.g. for a separate optimizer)."""
        return self.encoder.parameters()

    @property
    def dec_params(self):
        """Iterator over decoder-side parameters (projection + decoder)."""
        return chain(*[self.encoder_to_decoder.parameters(), self.decoder.parameters()])
|
|