| | from rdkit import Chem
|
| | from rdkit.Chem import AllChem, MACCSkeys
|
| | from rdkit.Chem.rdmolops import FastFindRings
|
| | from rdkit.Chem.rdMolDescriptors import CalcMolFormula
|
| | import torch
|
| | import numpy as np
|
| | import scipy
|
| | import scipy.sparse as ss
|
| | import scipy.sparse.linalg
|
| | import math
|
| | import json
|
| | import itertools as it
|
| | import re
|
| | from GNN import featurizer as ft
|
| |
|
# Quiet RDKit: the Python-side logger is raised to ERROR level and the
# 'rdApp.error' log channel is disabled through rdBase.
import rdkit.RDLogger as rkl
import rdkit.rdBase as rkrb

logger = rkl.logger()
logger.setLevel(rkl.ERROR)
rkrb.DisableLog('rdApp.error')
|
| |
|
| |
|
# Bit positions selected from the 2048-bit Morgan fingerprint when
# mol_fp_encoder0 is called with tp='morgan1'.
FPBitIdx = [
    1, 5, 13, 41, 69, 80, 84, 94, 114, 117, 118, 119, 125, 133, 145,
    147, 191, 192, 197, 202, 222, 227, 231, 249, 283, 294, 310, 314,
    322, 333, 352, 361, 378, 387, 389, 392, 401, 406, 441, 478, 486,
    489, 519, 521, 524, 555, 561, 591, 598, 599, 610, 622, 650, 656,
    667, 675, 677, 679, 680, 694, 695, 715, 718, 722, 729, 736, 739,
    745, 750, 760, 775, 781, 787, 794, 798, 802, 807, 811, 823, 835,
    841, 849, 869, 872, 874, 875, 881, 890, 896, 926, 935, 980, 991,
    1004, 1009, 1017, 1019, 1027, 1028, 1035, 1037, 1039, 1057, 1060,
    1066, 1070, 1077, 1088, 1097, 1114, 1126, 1136, 1142, 1143, 1145,
    1152, 1154, 1160, 1162, 1171, 1181, 1195, 1199, 1202, 1218, 1234,
    1236, 1243, 1257, 1267, 1274, 1279, 1283, 1292, 1294, 1309, 1313,
    1323, 1325, 1349, 1356, 1357, 1366, 1380, 1381, 1385, 1386, 1391,
    1399, 1436, 1440, 1441, 1444, 1452, 1454, 1457, 1475, 1476, 1477,
    1480, 1487, 1516, 1536, 1544, 1558, 1564, 1573, 1599, 1602, 1604,
    1607, 1619, 1648, 1670, 1683, 1693, 1716, 1722, 1737, 1738, 1745,
    1747, 1750, 1754, 1755, 1764, 1781, 1803, 1808, 1810, 1816, 1838,
    1844, 1847, 1855, 1860, 1866, 1873, 1905, 1911, 1917, 1921, 1923,
    1928, 1933, 1950, 1951, 1970, 1977, 1980, 1984, 1991, 2002, 2033,
    2034, 2038,
]
|
| |
|
class ConfigDict(dict):
    '''
    Makes a dictionary behave like an object, with attribute-style access.

    ``cfg.key`` reads ``cfg['key']`` and ``cfg.key = v`` stores
    ``cfg['key'] = v``. Reading a missing key through attribute access
    raises ``AttributeError`` so ``getattr``/``hasattr`` work as usual.
    '''

    def __getattr__(self, name):
        # __getattr__ is only consulted after normal attribute lookup
        # fails, so dict methods are never shadowed by stored keys.
        try:
            return self[name]
        except KeyError:
            # Only a missing key is translated; anything else propagates.
            raise AttributeError(name) from None

    def __setattr__(self, name, value):
        self[name] = value

    def save(self, fn):
        '''Write the dictionary to the file ``fn`` as JSON (indent=2).'''
        # Context manager so the handle is flushed and closed even when
        # serialization fails part-way.
        with open(fn, 'w') as fh:
            json.dump(self, fh, indent=2)

    def load_dict(self, dic):
        '''Copy every key/value pair of ``dic`` into this dictionary.'''
        for k, v in dic.items():
            self[k] = v

    def load(self, fn):
        '''
        Merge the JSON object stored in the file ``fn`` into this dictionary.

        Loading is best-effort: any error (missing file, invalid JSON, ...)
        is printed and otherwise ignored.
        '''
        try:
            with open(fn, 'r') as fh:
                d = json.load(fh)
            self.load_dict(d)
        except Exception as e:
            print(e)
|
| |
|
def conv_out_dim(length_in, kernel, stride, padding, dilation):
    '''
    Return the output length given by the convolution size formula
    ``(length_in + 2*padding - dilation*(kernel-1) - 1) // stride + 1``.
    '''
    padded_length = length_in + 2 * padding
    kernel_span = dilation * (kernel - 1) + 1
    return (padded_length - kernel_span) // stride + 1
|
| |
|
def filter_ms(ms, thr=0.05, max_mz=2000):
    '''
    Select the peaks of a spectrum that lie below ``max_mz`` and whose
    intensity relative to the strongest such peak exceeds ``thr``.

    Parameters
    ----------
    ms : iterable of (mz, intensity) pairs
        May be any iterable (including a generator); it is consumed once.
    thr : float
        Relative-intensity cut-off; peaks with ``intensity / max <= thr``
        are dropped.
    max_mz : float
        Only peaks with ``mz < max_mz`` are considered.

    Returns
    -------
    (list, list)
        The kept m/z values and their intensities rescaled so that the
        strongest peak is 100, rounded to two decimals. Both lists are
        empty when there is no peak with positive intensity below
        ``max_mz``.
    '''
    # Materialize once so the input can be a one-shot iterator.
    peaks = [(m, i) for m, i in ms if m < max_mz]
    maxi = max((i for _, i in peaks), default=0)
    if maxi <= 0:
        # No usable base peak: avoid dividing by zero below.
        return [], []

    mz = []
    intn = []
    for m, i in peaks:
        rel = i / maxi
        if rel > thr:
            mz.append(m)
            intn.append(round(rel * 100, 2))

    return mz, intn
|
| |
|
def calc_nls(ms, thr=0.05, max_mz=2000):
    '''
    Compute pairwise m/z differences between the filtered peaks of a
    spectrum (presumably neutral losses, going by the ``nl`` naming).

    Parameters
    ----------
    ms : iterable of (mz, intensity) pairs
        Raw spectrum; it is first reduced with ``filter_ms``.
    thr : float
        Relative-intensity threshold forwarded to ``filter_ms``.
    max_mz : float
        Upper m/z bound forwarded to ``filter_ms``.

    Returns
    -------
    list of (float, float)
        Sorted ``(difference, intensity)`` tuples for every pair of peaks
        whose difference lies strictly between 0 and 200. The difference
        is rounded to 5 decimals and the intensity is the mean of the two
        peak intensities, also rounded to 5 decimals.
    '''
    # Forward the caller's settings (they were previously hard-coded here).
    mz, intn = filter_ms(ms, thr=thr, max_mz=max_mz)

    # Carry each intensity alongside its m/z so no index lookup is needed;
    # this also keeps the right intensity when an m/z value occurs twice.
    peaks = list(zip(mz, intn))[::-1]

    nls = []
    for (mz_a, intn_a), (mz_b, intn_b) in it.combinations(peaks, 2):
        nl = mz_a - mz_b
        if 0 < nl < 200:
            nls.append((round(nl, 5), round((intn_a + intn_b) / 2., 5)))

    return sorted(nls)
|
| |
|
def _binned_tensor(bins, weights, num_bins, l2_scale):
    """Sum ``weights`` into ``num_bins`` bins and return a 1-D float tensor,
    optionally rescaled so that its L2 norm is 100."""
    if len(bins) == 0:
        return torch.zeros(num_bins)

    # Going through a sparse matrix sums weights that share a bin.
    vec = ss.csr_matrix(
        (weights,
         (np.repeat(0, len(weights)), bins)),
        shape=(1, num_bins),
        dtype=np.float32)

    if l2_scale:
        vec = (vec / scipy.sparse.linalg.norm(vec) * 100)

    return torch.FloatTensor(vec.todense()).view(-1)


def ms_binner(ms, nls=None, min_mz=20, max_mz=2000, bin_size=0.05, add_nl=False, binary_intn=False):
    """
    Convert the given spectrum to a dense binned 1-D torch tensor.

    Parameters
    ----------
    ms : iterable of (mz, intensity) pairs
        The spectrum; it is reduced with ``filter_ms`` (default settings)
        before binning.
    nls : list of (mass, intensity) pairs, optional
        Precomputed m/z differences as produced by ``calc_nls``. Only used
        when ``add_nl`` is true; computed from ``ms`` when empty or None.
    min_mz : float
        The minimum m/z to include in the vector.
    max_mz : float
        The maximum m/z to include in the vector.
    bin_size : float
        The bin size in m/z used to divide the m/z range.
    add_nl : bool
        When true, a binned vector of the ``nls`` entries with mass in
        [0, 200) is prepended to the spectrum vector.
    binary_intn : bool
        When true, every kept peak gets weight 1 and the spectrum vector is
        not norm-scaled. The ``nls`` vector is always norm-scaled.

    Returns
    -------
    torch.Tensor
        1-D float tensor with ``ceil((max_mz - min_mz) / bin_size)``
        entries, preceded by ``ceil(200 / bin_size)`` entries when
        ``add_nl`` is true. Unless ``binary_intn`` is set, a non-empty
        vector is scaled to L2 norm 100.
    """
    nltensor = None

    if add_nl:
        if not nls:
            nls = calc_nls(ms, max_mz=max_mz)

        # Only masses in [0, 200) map onto a valid bin of the nl vector.
        kept = [(m, 1 if binary_intn else i) for m, i in nls if 0 <= m < 200]
        nlmass = np.array([m for m, _ in kept])
        nlintn = np.array([i for _, i in kept])
        if len(nlintn) > 0:
            nlintn = nlintn / nlintn.max()

        num_nlbins = math.ceil(200 / bin_size)
        nlbins = (nlmass / bin_size).astype(np.int32)
        nltensor = _binned_tensor(nlbins, nlintn, num_nlbins, l2_scale=True)

    mz, intn = filter_ms(ms)
    mz = np.asarray(mz, dtype=float)
    intn = np.asarray(intn, dtype=float)

    # Peaks below min_mz would get a negative bin index, which csr_matrix
    # rejects, so they are dropped together with those above max_mz.
    in_range = (mz >= min_mz) & (mz <= max_mz)
    mz = mz[in_range]
    intn = intn[in_range]

    if binary_intn:
        intn[intn > 0] = 1.0
    elif len(intn) > 0:
        intn = intn / intn.max()

    num_bins = math.ceil((max_mz - min_mz) / bin_size)
    bins = ((mz - min_mz) / bin_size).astype(np.int32)
    # A peak sitting exactly on max_mz would land one past the last bin;
    # fold it into the last bin instead of failing.
    bins = np.minimum(bins, num_bins - 1)

    mstensor = _binned_tensor(bins, intn, num_bins, l2_scale=not binary_intn)

    if nltensor is not None:
        return torch.cat([nltensor, mstensor], dim=0)

    return mstensor
|
| |
|
def formula2vec(formula, elements=['C', 'H', 'O', 'N', 'P', 'S', 'P', 'F', 'Cl', 'Br']):
    '''
    Turn a molecular formula string into a vector of element counts.

    Each ``<Element><count>`` token of ``formula`` (a missing count means 1)
    adds its count to the slot of that element in ``elements``; elements
    not listed are ignored.

    Note: the default ``elements`` list contains 'P' twice. ``list.index``
    always returns the first occurrence, so the second 'P' slot (index 6)
    stays zero. The list is left as is because its length fixes the size
    of the returned vector.

    Returns a float numpy array of length ``len(elements)``.
    '''
    counts = np.zeros(len(elements))
    for symbol, digits in re.findall(r'([A-Z][a-z]*)(\d*)', formula):
        if symbol not in elements:
            continue
        counts[elements.index(symbol)] += int(digits) if digits else 1
    return np.array(counts)
|
| |
|
def mol_fp_encoder0(smiles, tp='rdkit', nbits=2048):
    '''
    Parse a SMILES string and encode the molecule as a fingerprint tensor.

    Parameters
    ----------
    smiles : str
        SMILES of the molecule.
    tp : str
        Fingerprint type: 'morgan' (radius 2, ``nbits`` bits), 'morgan1'
        (radius 2, 2048 bits, reduced to the positions in ``FPBitIdx``),
        'macc' (MACCS keys) or 'rdkit' (RDKit fingerprint with
        ``nBitsPerHash=1``).
    nbits : int
        Fingerprint length; only used by ``tp='morgan'``.

    Returns
    -------
    (torch.FloatTensor, mol) or (None, None)
        The 0/1 fingerprint and the parsed molecule, or ``(None, None)``
        when the SMILES cannot be parsed.

    Raises
    ------
    ValueError
        If ``tp`` is not one of the supported fingerprint types.
    '''
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # Second attempt without sanitization; the property cache and ring
        # info are then filled in by hand before fingerprinting.
        mol = Chem.MolFromSmiles(smiles, sanitize=False)
        if mol is not None:
            mol.UpdatePropertyCache()
            FastFindRings(mol)

    if mol is None:
        return None, None

    if tp == 'morgan':
        fp_vec = AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=nbits)
    elif tp == 'morgan1':
        # Fixed at 2048 bits because FPBitIdx indexes into that layout.
        fp_vec = AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=2048)
    elif tp == 'macc':
        fp_vec = MACCSkeys.GenMACCSKeys(mol)
    elif tp == 'rdkit':
        fp_vec = Chem.RDKFingerprint(mol, nBitsPerHash=1)
    else:
        # Previously this fell through to an UnboundLocalError on `fp`.
        raise ValueError(f"unsupported fingerprint type: {tp!r}")

    # The bit string is ASCII '0'/'1'; subtracting ord('0') gives 0/1 values.
    fp = np.frombuffer(fp_vec.ToBitString().encode(), 'u1') - ord('0')
    if tp == 'morgan1':
        fp = fp[FPBitIdx]

    return torch.FloatTensor(fp.tolist()), mol
|
| |
|
def mol_fp_encoder(smiles, tp='rdkit', nbits=2048):
    '''
    Return only the fingerprint part of ``mol_fp_encoder0``'s result
    (``None`` when the SMILES could not be parsed).
    '''
    return mol_fp_encoder0(smiles, tp, nbits)[0]
|
| |
|
def mol_fp_fm_encoder(smiles, tp='rdkit', nbits=2048):
    '''
    Encode a SMILES string as a (fingerprint, formula-vector) pair.

    The fingerprint comes from ``mol_fp_encoder0``; the formula vector is
    ``formula2vec`` applied to the molecule's formula, as a float tensor.
    The formula vector is ``None`` when no molecule could be parsed.
    '''
    fpenc, mol = mol_fp_encoder0(smiles, tp, nbits)
    if mol is None:
        return fpenc, None

    formula = CalcMolFormula(mol)
    return fpenc, torch.FloatTensor(formula2vec(formula))
|
| |
|
def smi2fmvec(smiles):
    '''
    Convert a SMILES string to its formula count vector (float tensor),
    or return ``None`` when the SMILES cannot be parsed.
    '''
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None

    formula_counts = formula2vec(CalcMolFormula(mol))
    return torch.FloatTensor(formula_counts)
|
| |
|
def mol_graph_featurizer(smiles):
    '''
    Build the molecular graph data for ``smiles`` by calling
    ``ft.calc_data_from_smile`` with a fixed set of options.
    '''
    # An earlier configuration used addh=True and with_submol_fp=True
    # (other options identical); the current one disables both.
    options = dict(
        addh=False,
        with_ring_conj=True,
        with_atom_feats=True,
        with_submol_fp=False,
        radius=2,
    )
    return ft.calc_data_from_smile(smiles, **options)
|
| |
|
def pad_V(V, max_n):
    '''
    Append all-zero rows to the 2-D tensor ``V`` until it has ``max_n``
    rows. ``V`` is returned unchanged when it already has at least
    ``max_n`` rows.
    '''
    rows, cols = V.shape
    missing = max_n - rows
    if missing <= 0:
        return V
    return torch.cat([V, torch.zeros(missing, cols)], dim=0)
|
| |
|
def pad_A(A, max_n):
    '''
    Zero-pad the 3-D tensor ``A`` of shape (N, L, N) to (max_n, L, max_n).

    The last dimension is widened first, then zero slabs are appended
    along the first dimension. ``A`` is returned unchanged when
    ``max_n <= N``.
    '''
    n, channels, _ = A.shape
    extra = max_n - n
    if extra <= 0:
        return A

    right = torch.zeros(n, channels, extra)
    widened = torch.cat([A, right], dim=-1)
    bottom = torch.zeros(extra, channels, max_n)
    return torch.cat([widened, bottom], dim=0)
|
| |
|
class AvgMeter:
    '''Tracks the count-weighted running average of a named metric.'''

    def __init__(self, name="Metric"):
        self.name = name
        self.reset()

    def reset(self):
        '''Zero the running sum, the sample count and the average.'''
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, count=1):
        '''Add ``val`` as the value of ``count`` samples and refresh the average.'''
        self.count += count
        self.sum += val * count
        self.avg = self.sum / self.count

    def __repr__(self):
        return "{}: {:.4f}".format(self.name, self.avg)
|
| |
|
def get_lr(optimizer):
    '''
    Return the "lr" entry of the optimizer's first parameter group, or
    ``None`` when it has no parameter groups.
    '''
    first_group = next(iter(optimizer.param_groups), None)
    if first_group is None:
        return None
    return first_group["lr"]
|
| |
|
def segment_max(x, size_list):
    '''
    Split ``x`` along dim 0 into consecutive segments of the lengths in
    ``size_list`` and stack the dim-0 maximum of each segment.
    '''
    lengths = list(map(int, size_list))
    maxima = []
    for chunk in torch.split(x, lengths):
        maxima.append(torch.max(chunk, 0).values)
    return torch.stack(maxima)
|
| |
|
def segment_sum(x, size_list):
    '''
    Split ``x`` along dim 0 into consecutive segments of the lengths in
    ``size_list`` and stack the dim-0 sum of each segment.
    '''
    lengths = list(map(int, size_list))
    totals = []
    for chunk in torch.split(x, lengths):
        totals.append(torch.sum(chunk, 0))
    return torch.stack(totals)
|
| |
|
def segment_softmax(gate, size_list):
    '''
    Softmax of ``gate`` computed independently within each consecutive
    dim-0 segment whose lengths are given by ``size_list``.

    The per-segment maximum is subtracted before exponentiation and
    1e-16 is added to the denominator.
    '''
    def expand(per_segment):
        # Repeat the i-th segment statistic once per row of segment i.
        return torch.cat(
            [per_segment[idx].repeat(length, 1)
             for idx, length in enumerate(size_list)],
            dim=0)

    shifted = gate - expand(segment_max(gate, size_list))
    weights = torch.exp(shifted)
    denominator = expand(segment_sum(weights, size_list))

    return weights / (denominator + 1e-16)
|
| |
|
def pad_ms_list(ms_list, thr=0.05, min_mz=20, max_mz=2000):
    '''
    Normalize, filter and zero-pad a batch of spectra to a common length.

    Parameters
    ----------
    ms_list : sequence of spectra
        Each spectrum is a sequence of ``[mz, intensity]`` rows; an empty
        spectrum is allowed.
    thr : float
        Relative-intensity threshold (fraction of the strongest peak).
        Peaks below it are dropped; no intensity filtering when 0.
    min_mz, max_mz : float
        Inclusive m/z range of the peaks to keep.

    Returns
    -------
    (torch.FloatTensor, torch.IntTensor)
        The spectra stacked along a new first axis, each padded with zero
        rows to the longest filtered spectrum and with intensities rescaled
        so the strongest peak of a spectrum is 100, and the number of real
        (unpadded) peaks per spectrum.

    Raises
    ------
    ValueError
        If ``ms_list`` is empty.
    '''
    min_intn = thr * 100
    spectra = []
    for ms in ms_list:
        # Force float: with integer input the in-place rescale below would
        # otherwise be truncated back to integers.
        spec = np.array(ms, dtype=float)
        if spec.size == 0:
            # Give an empty spectrum the (0, 2) shape the column indexing expects.
            spec = spec.reshape(0, 2)

        base = spec[:, 1].max() if len(spec) else 0
        if base != 0:
            # Skipped for an all-zero spectrum to avoid 0/0 -> NaN.
            spec[:, 1] = spec[:, 1] / base * 100

        if min_intn > 0:
            spec = spec[spec[:, 1] >= min_intn]

        spec = spec[(spec[:, 0] >= min_mz) & (spec[:, 0] <= max_mz)]
        spectra.append(spec)

    size_list = [spec.shape[0] for spec in spectra]
    maxlen = max(size_list)

    padded = [
        np.concatenate(
            [spec, np.zeros((maxlen - len(spec), spec.shape[1]))], axis=0)
        for spec in spectra
    ]

    return torch.FloatTensor(np.stack(padded)), torch.IntTensor(size_list)
|
| |
|