| | import torch
|
| | from torch import nn
|
| | import torch.nn.functional as F
|
| | from config import CFG
|
| | import utils
|
| | import math
|
| | import numpy as np
|
| | from cliplayers import QuickGELU, Transformer as MSTsfmEncoder
|
| | from GNN import layers as gly
|
| |
|
class MolGNNEncoder(nn.Module):
    """Graph encoder for molecules: a stack of graph-convolution blocks,
    multi-head global attention pooling, and a small MLP readout.

    Produces one embedding of size ``outdim`` per molecule.
    """

    def __init__(self,
                 outdim,
                 n_feats=74,
                 n_filters_list=None,
                 n_head=4,
                 mols=1,
                 adj_chans=6,
                 readout_layers=2,
                 bias=True):
        """
        Args:
            outdim: dimension of the readout embedding.
            n_feats: number of input node features.
            n_filters_list: hidden sizes of the graph-conv stack; ``None``
                entries are skipped. Defaults to [256, 256, 256].
            n_head: number of heads in the global attention pooling layer.
            mols: passed through to gly.GConvBlockNoGF — TODO confirm semantics.
            adj_chans: number of adjacency channels, passed to the conv blocks.
            readout_layers: number of Linear layers in the readout MLP.
            bias: whether conv/linear layers use a bias term.

        Raises:
            ValueError: if no filter sizes remain after dropping None entries
                (the original code hit an opaque NameError in that case).
        """
        super().__init__()

        # None default avoids the shared mutable-default-argument pitfall
        # (the original used a list literal as the default).
        if n_filters_list is None:
            n_filters_list = [256, 256, 256]
        filters = [f for f in n_filters_list if f is not None]
        if not filters:
            raise ValueError("n_filters_list must contain at least one filter size")

        # Chain layer dims: n_feats -> filters[0] -> filters[1] -> ...
        in_dims = [n_feats] + filters[:-1]
        self.block_layers = nn.ModuleList(
            gly.GConvBlockNoGF(d_in, d_out, mols, adj_chans, bias)
            for d_in, d_out in zip(in_dims, filters)
        )

        last = filters[-1]
        self.attention_layer = gly.MultiHeadGlobalAttention(last, n_head=n_head, concat=True, bias=bias)
        # First readout maps the concatenated attention heads -> outdim;
        # subsequent layers are outdim -> outdim (default bias, as before).
        self.readout_layers = nn.ModuleList(
            [nn.Linear(last * n_head, outdim, bias=bias)]
            + [nn.Linear(outdim, outdim) for _ in range(readout_layers - 1)]
        )
        self.gelu = QuickGELU()

    def forward(self, batch):
        """Encode a batch dict with keys 'V' (node features), 'A' (adjacency)
        and 'mol_size' (nodes per molecule) into per-molecule embeddings."""
        V = batch['V']
        A = batch['A']
        mol_size = batch['mol_size']

        for block in self.block_layers:
            V = block(V, A)

        # Pool node features into one vector per molecule.
        X = self.attention_layer(V, mol_size)

        for linear in self.readout_layers:
            X = self.gelu(linear(X))

        return X
|
| |
|
class ProjectionHead(nn.Module):
    """Project an embedding to ``projection_dim``, optionally refine it with a
    Transformer encoder and/or an LSTM, then apply dropout."""

    def __init__(self,
                 embedding_dim,
                 projection_dim,
                 cfg,
                 transformer=True,
                 lstm=False):
        """
        Args:
            embedding_dim: input feature dimension.
            projection_dim: output feature dimension.
            cfg: config object; reads ``tsfm_layers``/``tsfm_heads`` when
                ``transformer`` is set, ``lstm_layers`` when ``lstm`` is set,
                and ``dropout`` always.
            transformer: if True, run the projection through MSTsfmEncoder
                instead of the plain GELU branch.
            lstm: if True, additionally pass the features through an LSTM.
        """
        super().__init__()

        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.gelu = nn.GELU()
        self.transformer = None
        if transformer:
            self.transformer = MSTsfmEncoder(projection_dim, cfg.tsfm_layers, cfg.tsfm_heads)
        self.lstm = None
        if lstm:
            self.lstm = nn.LSTM(input_size=projection_dim,
                                hidden_size=projection_dim,
                                num_layers=cfg.lstm_layers,
                                batch_first=True)
        self.dropout = nn.Dropout(cfg.dropout)

    def forward(self, x):
        """Return the projected (and optionally transformer/LSTM-refined,
        dropout-regularized) features."""
        projected = self.projection(x)
        # The transformer REPLACES the GELU branch rather than following it.
        if self.transformer is None:
            x = self.gelu(projected)
        else:
            x = self.transformer(projected)
        # Idiomatic identity test (was: `if not self.lstm is None`).
        if self.lstm is not None:
            x, (_, _) = self.lstm(x)
        x = self.dropout(x)

        return x
|
| |
|
| |
|
class FragSimiModel(nn.Module):
    """CLIP-style dual encoder pairing a mass-spectrum branch with a molecule
    branch; returns one L2-normalized embedding per branch."""

    def __init__(
        self,
        cfg
    ):
        """
        Args:
            cfg: config object. ``cfg.mol_encoder`` is a string/collection
                selecting molecule features: 'gnn' (graph encoder), 'fp'
                (fingerprints), 'fm' (formula vector).
        """
        super().__init__()

        self.cfg = cfg
        self.mol_gnn_encoder = None
        mol_embedding_dim = cfg.mol_embedding_dim

        # 'gnn': encode the molecular graph with MolGNNEncoder.
        if 'gnn' in self.cfg.mol_encoder:
            self.mol_gnn_encoder = MolGNNEncoder(outdim=cfg.mol_embedding_dim,
                                                 n_filters_list=cfg.molgnn_n_filters_list,
                                                 n_head=cfg.molgnn_nhead,
                                                 readout_layers=cfg.molgnn_readout_layers)
        # 'fp': assumes the fingerprint vector has length cfg.mol_embedding_dim,
        # so the concatenated width doubles — TODO confirm against the dataloader.
        if 'fp' in self.cfg.mol_encoder:
            mol_embedding_dim = 2 * cfg.mol_embedding_dim

        # 'fm': a fixed 10-dim formula vector is appended.
        if 'fm' in self.cfg.mol_encoder:
            mol_embedding_dim += 10

        self.ms_projection = ProjectionHead(cfg.ms_embedding_dim,
                                            cfg.projection_dim,
                                            cfg,
                                            cfg.tsfm_in_ms,
                                            cfg.lstm_in_ms)

        self.mol_projection = ProjectionHead(mol_embedding_dim,
                                             cfg.projection_dim,
                                             cfg,
                                             cfg.tsfm_in_mol,
                                             cfg.lstm_in_mol)

    def forward(self, batch):
        """Return (mol_embeddings, ms_embeddings), each L2-normalized on dim 1.

        Expects batch keys: 'ms_bins', plus (depending on cfg.mol_encoder)
        the GNN inputs, 'mol_fps', and/or 'mol_fmvec'.
        """
        ms_features = batch["ms_bins"]

        mol_feat_list = []
        if 'gnn' in self.cfg.mol_encoder:
            mol_feat_list.append(self.mol_gnn_encoder(batch))
        if 'fp' in self.cfg.mol_encoder:
            mol_feat_list.append(batch["mol_fps"])
        if 'fm' in self.cfg.mol_encoder:
            mol_feat_list.append(batch["mol_fmvec"])

        # Fail loudly on a misconfigured encoder: the original raised a bare
        # IndexError on mol_feat_list[0] when nothing was selected.
        if not mol_feat_list:
            raise ValueError("cfg.mol_encoder selects no molecule features "
                             "(expected at least one of 'gnn', 'fp', 'fm')")
        if len(mol_feat_list) > 1:
            mol_features = torch.cat(mol_feat_list, dim=1)
        else:
            mol_features = mol_feat_list[0]

        ms_embeddings = self.ms_projection(ms_features)
        mol_embeddings = self.mol_projection(mol_features)

        # Unit-normalize so downstream similarity is a cosine similarity.
        mol_embeddings = F.normalize(mol_embeddings, p=2, dim=1)
        ms_embeddings = F.normalize(ms_embeddings, p=2, dim=1)

        return mol_embeddings, ms_embeddings
|
| |
|
| |
|
| |
|
| |
|
# NOTE(review): dead, commented-out code kept as a string literal — a symmetric
# cross-entropy ("CLIP-style") loss over the mol/MS similarity matrix. The model
# now returns embeddings only, so the loss is presumably computed by the caller;
# consider deleting this or moving it into the training loop.
'''logits = mol_embeddings @ ms_embeddings.t()

ground_truth = torch.arange(ms_features.shape[0], dtype=torch.long, device=self.cfg.device)

ms_loss = loss_func(logits, ground_truth)
mol_loss = loss_func(logits.t(), ground_truth)
loss = (ms_loss + mol_loss) / 2.0 # shape: (batch_size)

return loss.mean()'''
|
| |
|