| | from modules import *
|
| | import os, sys
|
| | import numpy as np
|
| | from tqdm import tqdm
|
| | import torch
|
| | from torch import nn
|
| | from config import CFG
|
| | import utils
|
| | import json
|
| | import pandas as pd
|
| | import pickle
|
| | from rdkit import Chem
|
| | from rdkit.Chem import inchi
|
| |
|
def smiles_to_inchikey(smiles, nostereo=True):
    """Convert a SMILES string to an InChIKey, returning None on any failure.

    Args:
        smiles: SMILES string passed to ``Chem.MolFromSmiles``.
        nostereo: When true, the InChI is generated with the option string
            ``"-SNon"`` (presumably dropping stereo layers, per the parameter
            name); otherwise RDKit's default options are used.

    Returns:
        The result of ``inchi.InchiToInchiKey`` for the generated InChI, or
        ``None`` when the SMILES cannot be parsed, the InChI string is empty,
        or any exception is raised (the exception is printed, not re-raised).
    """
    try:
        molecule = Chem.MolFromSmiles(smiles)
        if molecule is None:
            return None

        inchi_text = (
            inchi.MolToInchi(molecule, options="-SNon")
            if nostereo
            else inchi.MolToInchi(molecule)
        )

        return inchi.InchiToInchiKey(inchi_text) if inchi_text else None
    except Exception as err:
        print(f"转换失败: {err}")
        return None
|
| |
|
def calc_mol_embeddings(model, smis, cfg):
    """Embed a batch of candidate molecules with the model's molecule branch.

    Args:
        model: Object providing ``eval()``, ``mol_projection`` and, when
            ``'gnn'`` is enabled, ``mol_gnn_encoder``.
        smis: Sequence of candidate records; element ``[1]`` of each record
            is used as the SMILES string.
        cfg: Configuration providing ``mol_encoder`` and ``device`` and, for
            the fingerprint encoder, ``fptype`` and ``mol_embedding_dim``.

    Returns:
        ``(mol_embeddings, valid_smis)``: the output of
        ``model.mol_projection`` (computed under ``torch.no_grad()``) and the
        input records whose featurization succeeded, in input order.

    Raises:
        ValueError: If no record in ``smis`` could be featurized.
    """
    model.eval()
    use_gnn = 'gnn' in cfg.mol_encoder
    use_fp = 'fp' in cfg.mol_encoder
    use_fm = 'fm' in cfg.mol_encoder

    gnn_featsl, fp_featsl, fm_featsl = [], [], []
    valid_smis = []

    for record in smis:
        smi = record[1]
        try:
            gnn_feats = utils.mol_graph_featurizer(smi) if use_gnn else None
            fp_feats = (
                utils.mol_fp_encoder(
                    smi, tp=cfg.fptype, nbits=cfg.mol_embedding_dim
                ).to(cfg.device)
                if use_fp
                else None
            )
            fm_feats = utils.smi2fmvec(smi).to(cfg.device) if use_fm else None
        except Exception as e:
            print(smi, e)
            continue

        # Append only after every enabled featurizer succeeded; appending as
        # we go would leave the per-encoder lists with different lengths (and
        # out of step with valid_smis) when a later featurizer fails.
        if use_gnn:
            gnn_featsl.append(gnn_feats)
        if use_fp:
            fp_featsl.append(fp_feats)
        if use_fm:
            fm_featsl.append(fm_feats)
        valid_smis.append(record)

    if not valid_smis:
        # torch.stack / the encoders cannot handle an empty batch; fail with
        # a clear message instead of an opaque downstream error.
        raise ValueError('no SMILES in the batch could be featurized')

    # Inference only: nothing below needs autograd bookkeeping.
    with torch.no_grad():
        mol_feat_list = []

        if use_gnn:
            vl = [b['V'] for b in gnn_featsl if 'V' in b]
            al = [b['A'] for b in gnn_featsl if 'A' in b]
            msl = [b['mol_size'] for b in gnn_featsl if 'mol_size' in b]

            bat = {}
            if vl and al and msl:
                # Pad every graph to the largest first dimension in the batch
                # so the tensors can be stacked.
                max_n = max(v.shape[0] for v in vl)
                bat['V'] = torch.stack(
                    [utils.pad_V(v, max_n) for v in vl]
                ).to(cfg.device)
                bat['A'] = torch.stack(
                    [utils.pad_A(a, max_n) for a in al]
                ).to(cfg.device)
                bat['mol_size'] = torch.cat(msl, dim=0).to(cfg.device)

            mol_feat_list.append(model.mol_gnn_encoder(bat))

        if use_fp:
            mol_feat_list.append(torch.stack(fp_featsl).to(cfg.device))

        if use_fm:
            mol_feat_list.append(torch.stack(fm_featsl).to(cfg.device))

        # Several encoders: concatenate their features per molecule.
        if len(mol_feat_list) > 1:
            mol_features = torch.cat(mol_feat_list, dim=1).to(cfg.device)
        else:
            mol_features = mol_feat_list[0].to(cfg.device)

        mol_embeddings = model.mol_projection(mol_features)

    return mol_embeddings, valid_smis
|
| |
|
def find_matches(model, ms, smis, cfg, n=10, batch_size=64):
    """Rank candidate molecules against one spectrum by embedding similarity.

    Args:
        model: Object providing ``eval()`` and ``ms_projection`` (plus
            whatever ``calc_mol_embeddings`` needs).
        ms: Spectrum passed to ``utils.ms_binner``.
        smis: Sequence of candidate records, embedded in slices of
            ``batch_size`` through ``calc_mol_embeddings``.
        cfg: Configuration providing ``min_mz``, ``max_mz``, ``bin_size``,
            ``add_nl``, ``binary_intn`` and ``device``.
        n: Number of best matches to return. ``-1`` (or any negative value,
            or a value larger than the number of valid candidates) returns
            all valid candidates.
        batch_size: Number of candidates embedded per call.

    Returns:
        ``(matchsmis, scores, indices)``: the matched candidate records in
        ``torch.topk`` order, a NumPy array of their cosine similarities
        multiplied by 100, and their positions in the list of valid
        candidates. All three are empty when there are no valid candidates.
    """
    model.eval()
    with torch.no_grad():
        ms_features = utils.ms_binner(
            ms,
            min_mz=cfg.min_mz,
            max_mz=cfg.max_mz,
            bin_size=cfg.bin_size,
            add_nl=cfg.add_nl,
            binary_intn=cfg.binary_intn,
        ).to(cfg.device)
        # Add a leading batch dimension of size 1 for the projection.
        ms_embeddings = model.ms_projection(ms_features.unsqueeze(0))
        # ``nn`` is imported at file level; ``F`` is not imported in this
        # file, so go through ``nn.functional`` explicitly.
        ms_embeddings_n = nn.functional.normalize(ms_embeddings, p=2, dim=1)

        all_valid_smis = []
        all_embeddings = []
        for start in tqdm(range(0, len(smis), batch_size)):
            batch_embeddings, valid_smis = calc_mol_embeddings(
                model, smis[start:start + batch_size], cfg
            )
            all_embeddings.append(batch_embeddings)
            all_valid_smis.extend(valid_smis)

        if not all_valid_smis:
            # torch.cat rejects an empty list; report "no matches" instead.
            return [], np.zeros(0, dtype=np.float32), []

        all_embeddings = torch.cat(all_embeddings, dim=0)
        all_embeddings_n = nn.functional.normalize(all_embeddings, p=2, dim=1)

        similarities = nn.functional.cosine_similarity(
            all_embeddings_n, ms_embeddings_n, dim=1
        )

        if n < 0 or n > len(all_valid_smis):
            n = len(all_valid_smis)

        values, topk_indices = torch.topk(similarities, n)

    topk_indices_list = topk_indices.cpu().tolist()
    matchsmis = [all_valid_smis[idx] for idx in topk_indices_list]

    return matchsmis, values.cpu().numpy() * 100, topk_indices_list
|
| |
|
def calc(models, datal, cfg):
    """Score every spectrum against its candidates and tally hit ranks.

    Args:
        models: Iterable of models; each is run over all of ``datal``.
        datal: Sequence of dicts with keys ``'ms'``, ``'candidates'``,
            ``'smiles'`` and (as a fallback key) ``'ikey'``. Each candidate
            record is read as ``[0]`` -> stored as ``'idx'``, ``[1]`` ->
            SMILES, ``[-1]`` -> stored as ``'iscor'`` (presumably the
            "is correct" flag; verify against the data producer).
        cfg: Configuration forwarded to ``find_matches``.

    Returns:
        ``(sumtop3, dc, dicall)`` where ``sumtop3`` is the number of entries
        whose first ``True`` flag is at rank 1-3, ``dc`` maps
        ``'Hit 001'``..``'Hit 100'`` to single-element count lists, and
        ``dicall`` maps each InChIKey to ``{smiles: {'score', 'iscor',
        'idx'}}``.
    """
    dicall = {}

    for model in models:
        # ``ms_no`` rather than ``nn`` so the imported ``torch.nn`` module is
        # not shadowed inside this function.
        for ms_no, data in enumerate(datal):
            print(f'Calculating {ms_no}-th MS...')

            try:
                smis, scores, _ = find_matches(
                    model, data['ms'], data['candidates'], cfg, 50
                )
            except Exception as e:
                print(131, e)
                continue

            # A repeated SMILES keeps the values of its last occurrence.
            dic = {}
            for record, score in zip(smis, scores):
                dic[record[1]] = {
                    'score': score,
                    'iscor': record[-1],
                    'idx': record[0],
                }

            ikey = smiles_to_inchikey(data['smiles'], True)
            if ikey is None:
                ikey = data['ikey']

            if ikey not in dicall:
                dicall[ikey] = dic
                continue

            # Same compound seen before: average the shared candidates.
            # NOTE(review): with more than two spectra per key this pairwise
            # average weights later spectra more than a true mean would;
            # confirm that is intended.
            merged = dicall[ikey]
            for smi, entry in dic.items():
                if smi in merged:
                    merged[smi]['score'] += entry['score']
                    merged[smi]['score'] /= 2
                else:
                    merged[smi] = entry

    coridxd = {}
    for dic in dicall.values():
        scorel = [d['score'] for d in dic.values()]
        iscorl = [d['iscor'] for d in dic.values()]

        n = min(100, len(scorel))
        _, indices = torch.topk(torch.tensor(scorel), n)
        ranked_flags = [iscorl[i] for i in indices.cpu().tolist()]

        # Explicit membership test instead of a bare ``except`` around
        # ``list.index``: entries without a correct candidate are skipped.
        if True in ranked_flags:
            k = 'Hit %.3d' % (ranked_flags.index(True) + 1)
            coridxd[k] = coridxd.get(k, 0) + 1

    dc = {k: [coridxd[k]] for k in sorted(coridxd)}
    sumtop3 = sum(coridxd.get(k, 0) for k in ('Hit 001', 'Hit 002', 'Hit 003'))

    # Make sure every rank from 1 to 100 has a column, even with zero hits.
    for i in range(100):
        dc.setdefault('Hit %.3d' % (i + 1), [0])

    return sumtop3, dc, dicall
|
| |
|
def calc_rank(dicall):
    """Build the per-compound ranking report from accumulated scores.

    Args:
        dicall: Mapping of key -> ``{smiles: {'score', 'iscor', 'idx'}}``.

    Returns:
        Dict mapping each key that has a ``True`` ``'iscor'`` flag among its
        top 100 candidates to ``{'Hit': rank, 'Rank': lines}``, where
        ``rank`` is the 1-based position of the first ``True`` flag and each
        line is ``'<score>:<smiles>:<inchikey>'`` in descending score order.
        Keys without such a candidate are omitted.
    """
    rankd = {}

    for ikey, dic in dicall.items():
        smis = list(dic)
        scorel = [d['score'] for d in dic.values()]
        iscorl = [d['iscor'] for d in dic.values()]

        n = min(100, len(scorel))
        values, indices = torch.topk(torch.tensor(scorel), n)
        order = indices.tolist()

        ranked_flags = [iscorl[i] for i in order]
        # Check for a hit first: entries without one are dropped anyway, so
        # there is no point computing an InChIKey for each of their
        # candidates. This also replaces the former bare ``except``.
        if True not in ranked_flags:
            continue

        sl = [
            f'{score}:{smis[i]}:{smiles_to_inchikey(smis[i])}'
            for score, i in zip(values, order)
        ]
        rankd[ikey] = {'Hit': ranked_flags.index(True) + 1, 'Rank': sl}

    return rankd
|
| |
|
def predict(modelfnl, datal, datafn=''):
    """Evaluate each checkpoint on ``datal`` and write per-model reports.

    For every checkpoint file this loads the stored config into ``CFG``,
    builds a ``FragSimiModel``, runs ``calc`` and writes
    ``<base>-predict-rank.json`` and ``<base>-predict-summary.csv``, where
    ``<base>`` is the checkpoint path with ``'.pth'`` replaced by
    ``'-<data file stem>'``.

    Args:
        modelfnl: Iterable of checkpoint file paths. Each checkpoint is read
            as a dict with keys ``'config'`` and ``'state_dict'``.
        datal: Dataset forwarded to ``calc``.
        datafn: Path of the data file; only its base name is used, to name
            the output files.

    Returns:
        ``(maxoutt, maxtop3)``: the summary line and top-3 count of the
        checkpoint with the most top-3 hits (first one wins ties), or
        ``('', 0)`` when ``modelfnl`` is empty.
    """
    maxtop3 = 0
    maxoutt = ''

    for fn in modelfnl:
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        # map_location='cpu' lets checkpoints saved on a GPU load on machines
        # without one; load_state_dict copies onto the model's device.
        d = torch.load(fn, map_location='cpu')
        CFG.load(d['config'])
        print(d['config'])
        CFG.save('', True)

        model = FragSimiModel(CFG).to(CFG.device)
        model.load_state_dict(d['state_dict'])

        sumtop3, dc, dicall = calc([model], datal, CFG)

        hit_keys = ['Hit %.3d' % (i + 1) for i in range(100)]
        sumtop10 = sum(dc[k][0] for k in hit_keys[:10] if k in dc)
        sumtop50 = sum(dc[k][0] for k in hit_keys[:50] if k in dc)

        # Cumulative "Top k" counts via a running sum instead of re-adding
        # ranks 1..k for every k.
        tops = {}
        running = 0
        for k in hit_keys:
            present = k in dc
            if present:
                running += dc[k][0]
            tops[k.replace('Hit', 'Top')] = [running if present else 0]

        top1 = dc.get('Hit 001', [0])[0]
        outt = f'Top1: {top1}, top3: {sumtop3}, top10: {sumtop10}, top50: {sumtop50} of {len(datal)}'

        # ``not maxoutt`` keeps the first model's summary even when no model
        # scores a single top-3 hit (otherwise an empty string was returned).
        if not maxoutt or sumtop3 > maxtop3:
            maxtop3 = sumtop3
            maxoutt = outt

        basefn = fn.replace('.pth', f'-{os.path.basename(datafn).split(".")[0]}')
        rank = calc_rank(dicall)
        # Context manager so the report is flushed and the handle closed.
        with open(basefn + '-predict-rank.json', 'w') as fh:
            json.dump(rank, fh, indent=2)

        df = pd.DataFrame(tops)
        df.to_csv(basefn + '-predict-summary.csv', index=False)

    return maxoutt, maxtop3
|
| |
|
def main(datafn, fnl):
    """Run ``predict`` for each checkpoint and print the summaries.

    Args:
        datafn: Path to the JSON data file; its parsed content is passed to
            ``predict`` unchanged.
        fnl: Iterable of checkpoint file paths, evaluated one at a time.

    Prints one summary line per checkpoint (followed by the checkpoint's
    base name) and finally the list of all summary lines.
    """
    # Context manager so the data file handle is closed after parsing.
    with open(datafn, encoding='utf-8') as fh:
        datal = json.load(fh)

    outl = []
    for fn in fnl:
        out, _ = predict([fn], datal, datafn)
        print(out, os.path.basename(fn))
        outl.append(out)

    print(outl)
|
| |
|
if __name__ == '__main__':
    # Script entry point: argv[1] is the data file, the remaining arguments
    # are checkpoint files; the elapsed wall-clock time is printed at the end.
    import time

    # Exit with a usage line instead of an IndexError traceback when the
    # data file argument is missing.
    if len(sys.argv) < 2:
        sys.exit(f'usage: {sys.argv[0]} DATA_JSON [MODEL_FILE ...]')

    t0 = time.time()
    main(sys.argv[1], sys.argv[2:])
    print(300, time.time() - t0)
|
| |
|