import sys

# Local package root for the SMILES tokenizer; adjust to your checkout.
sys.path.append('/scratch/pranamlab/tong/ReDi_discrete/smiles')

import torch
import torch.nn as nn
import esm
from smiles_tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
from transformers import AutoModelForMaskedLM, AutoTokenizer, AutoModel


class ImprovedBindingPredictor(nn.Module):
    def __init__(self,
                 esm_dim=1280,
                 smiles_dim=768,
                 hidden_dim=512,
                 n_heads=8,
                 n_layers=3,
                 dropout=0.1):
        super().__init__()

        # Affinity thresholds for the 3-way classification head.
        self.tight_threshold = 7.5
        self.weak_threshold = 6.0

        # Project both modalities into a shared hidden space.
        self.smiles_projection = nn.Linear(smiles_dim, hidden_dim)
        self.protein_projection = nn.Linear(esm_dim, hidden_dim)
        self.protein_norm = nn.LayerNorm(hidden_dim)
        self.smiles_norm = nn.LayerNorm(hidden_dim)

        # Bidirectional cross-attention blocks. Each block's attention/FFN
        # weights are shared between the protein->SMILES and SMILES->protein
        # directions in forward().
        self.cross_attention_layers = nn.ModuleList([
            nn.ModuleDict({
                'attention': nn.MultiheadAttention(hidden_dim, n_heads, dropout=dropout),
                'norm1': nn.LayerNorm(hidden_dim),
                'ffn': nn.Sequential(
                    nn.Linear(hidden_dim, hidden_dim * 4),
                    nn.ReLU(),
                    nn.Dropout(dropout),
                    nn.Linear(hidden_dim * 4, hidden_dim)
                ),
                'norm2': nn.LayerNorm(hidden_dim)
            }) for _ in range(n_layers)
        ])

        # Shared trunk over the concatenated pooled representations.
        self.shared_head = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
        )

        # Continuous affinity prediction.
        self.regression_head = nn.Linear(hidden_dim, 1)

        # 3-way binding-strength classification (tight / medium / weak).
        self.classification_head = nn.Linear(hidden_dim, 3)
    def get_binding_class(self, affinity):
        """Convert affinity values to class indices.
        0: tight binding (>= 7.5)
        1: medium binding (6.0-7.5)
        2: weak binding (< 6.0)
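
        Example (illustrative values): 8.1 -> 0, 6.7 -> 1, 5.2 -> 2.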
| """ |
| if isinstance(affinity, torch.Tensor): |
| tight_mask = affinity >= self.tight_threshold |
| weak_mask = affinity < self.weak_threshold |
| medium_mask = ~(tight_mask | weak_mask) |
| |
| classes = torch.zeros_like(affinity, dtype=torch.long) |
| classes[medium_mask] = 1 |
| classes[weak_mask] = 2 |
| return classes |
| else: |
| if affinity >= self.tight_threshold: |
| return 0 |
| elif affinity < self.weak_threshold: |
| return 2 |
| else: |
| return 1 |

    def forward(self, protein_emb, smiles_emb):
        protein = self.protein_norm(self.protein_projection(protein_emb))
        smiles = self.smiles_norm(self.smiles_projection(smiles_emb))

        for layer in self.cross_attention_layers:
            # Protein queries attend over SMILES tokens.
            attended_protein = layer['attention'](
                protein, smiles, smiles
            )[0]
            protein = layer['norm1'](protein + attended_protein)
            protein = layer['norm2'](protein + layer['ffn'](protein))

            # SMILES queries attend over the (already updated) protein tokens,
            # reusing the same attention/FFN weights.
            attended_smiles = layer['attention'](
                smiles, protein, protein
            )[0]
            smiles = layer['norm1'](smiles + attended_smiles)
            smiles = layer['norm2'](smiles + layer['ffn'](smiles))

        # Mean-pool over the sequence dimension (dim 0 for unbatched inputs).
        protein_pool = torch.mean(protein, dim=0)
        smiles_pool = torch.mean(smiles, dim=0)

        combined = torch.cat([protein_pool, smiles_pool], dim=-1)

        shared_features = self.shared_head(combined)

        regression_output = self.regression_head(shared_features)
        classification_logits = self.classification_head(shared_features)

        return regression_output, classification_logits
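
    # Minimal shape sketch (illustrative; assumes a PyTorch version whose
    # nn.MultiheadAttention accepts unbatched (seq_len, embed_dim) inputs):
    #   model = ImprovedBindingPredictor()
    #   prot = torch.randn(10, 1280)          # 10 per-residue ESM-2 vectors
    #   pep = torch.randn(5, 768)             # 5 SMILES token embeddings
    #   affinity, logits = model(prot, pep)   # shapes: (1,) and (3,)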

class BindingAffinity:
    def __init__(self, prot_seq, device, model_type='PeptideCLM'):
        if model_type == 'PepDoRA':
            # PepDoRA peptide encoder.
            model_name = "ChatterjeeLab/PepDoRA"
            self.pep_tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.pep_model = AutoModel.from_pretrained(model_name)

            self.model = ImprovedBindingPredictor(smiles_dim=384).to(device)
            checkpoint = torch.load(
                '/scratch/pranamlab/tong/ReDi_discrete/smiles/scoring/functions/binding/best_model_optuna1.pt',
                map_location=device, weights_only=False)
            self.model.load_state_dict(checkpoint['model_state_dict'])
        else:
            # PeptideCLM encoder (RoFormer backbone) with an SPE SMILES tokenizer.
            self.pep_model = AutoModelForMaskedLM.from_pretrained('aaronfeller/PeptideCLM-23M-all').roformer
            self.pep_tokenizer = SMILES_SPE_Tokenizer(
                '/scratch/pranamlab/tong/ReDi_discrete/smiles/smiles_tokenizer/new_vocab.txt',
                '/scratch/pranamlab/tong/ReDi_discrete/smiles/smiles_tokenizer/new_splits.txt')

            self.model = ImprovedBindingPredictor(smiles_dim=768).to(device)
            checkpoint = torch.load(
                '/scratch/pranamlab/tong/ReDi_discrete/smiles/scoring/checkpoints/binding-affinity.pt',
                map_location=device, weights_only=False)
            self.model.load_state_dict(checkpoint['model_state_dict'])

        self.pep_model.to(device)
        self.pep_model.eval()
        self.model.eval()

        # Protein encoder: ESM-2 (650M parameters, 33 layers).
        self.esm_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
        self.esm_model.to(device)
        self.esm_model.eval()
        self.prot_tokenizer = alphabet.get_batch_converter()

        # Embed the target protein once; the pooled embedding is reused for
        # every peptide scored against this target.
        data = [("target", prot_seq)]
        _, _, prot_tokens = self.prot_tokenizer(data)
        prot_tokens = prot_tokens.to(device)
        with torch.no_grad():
            results = self.esm_model(prot_tokens, repr_layers=[33])
        prot_emb = results["representations"][33]

        # Mean-pool the per-residue embeddings to a single (1, esm_dim) vector.
        self.prot_emb = torch.mean(prot_emb[0], dim=0, keepdim=True)

        self.device = device

    def forward(self, input_seqs):
        scores = []
        with torch.no_grad():
            for seq in input_seqs:
                pep_tokens = self.pep_tokenizer(seq, return_tensors='pt', padding=True)
                input_ids = pep_tokens['input_ids'].to(self.device)
                attention_mask = pep_tokens['attention_mask'].to(self.device)

                emb = self.pep_model(input_ids=input_ids,
                                     attention_mask=attention_mask,
                                     output_hidden_states=True)

                # Mean-pool the peptide token embeddings to a (1, smiles_dim) vector.
                pep_emb = emb.last_hidden_state.squeeze(0)
                pep_emb = torch.mean(pep_emb, dim=0, keepdim=True)

                score, _ = self.model(self.prot_emb, pep_emb)
                scores.append(score.item())
        return torch.tensor(scores)

    def __call__(self, input_seqs: list):
        return self.forward(input_seqs)
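
# Usage sketch (illustrative SMILES; requires the checkpoint and tokenizer
# paths above to exist — see unittest() below for a runnable example):
#   scorer = BindingAffinity(prot_seq, device='cuda:0')
#   affinities = scorer(['CC(=O)N[C@@H](C)C(=O)O'])  # 1-D tensor, one score per input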


def unittest():
    # Example target sequences (only `amhr` is scored below).
    amhr = 'MLGSLGLWALLPTAVEAPPNRRTCVFFEAPGVRGSTKTLGELLDTGTELPRAIRCLYSRCCFGIWNLTQDRAQVEMQGCRDSDEPGCESLHCDPSPRAHPSPGSTLFTCSCGTDFCNANYSHLPPPGSPGTPGSQGPQAAPGESIWMALVLLGLFLLLLLLLGSIILALLQRKNYRVRGEPVPEPRPDSGRDWSVELQELPELCFSQVIREGGHAVVWAGQLQGKLVAIKAFPPRSVAQFQAERALYELPGLQHDHIVRFITASRGGPGRLLSGPLLVLELHPKGSLCHYLTQYTSDWGSSLRMALSLAQGLAFLHEERWQNGQYKPGIAHRDLSSQNVLIREDGSCAIGDLGLALVLPGLTQPPAWTPTQPQGPAAIMEAGTQRYMAPELLDKTLDLQDWGMALRRADIYSLALLLWEILSRCPDLRPDSSPPPFQLAYEAELGNTPTSDELWALAVQERRRPYIPSTWRCFATDPDGLRELLEDCWDADPEARLTAECVQQRLAALAHPQESHPFPESCPRGCPPLCPEDCTSIPAPTILPCRPQRSACHFSVQQGPCSRNPQPACTLSPV'
    tfr = 'MMDQARSAFSNLFGGEPLSYTRFSLARQVDGDNSHVEMKLAVDEEENADNNTKANVTKPKRCSGSICYGTIAVIVFFLIGFMIGYLGYCKGVEPKTECERLAGTESPVREEPGEDFPAARRLYWDDLKRKLSEKLDSTDFTGTIKLLNENSYVPREAGSQKDENLALYVENQFREFKLSKVWRDQHFVKIQVKDSAQNSVIIVDKNGRLVYLVENPGGYVAYSKAATVTGKLVHANFGTKKDFEDLYTPVNGSIVIVRAGKITFAEKVANAESLNAIGVLIYMDQTKFPIVNAELSFFGHAHLGTGDPYTPGFPSFNHTQFPPSRSSGLPNIPVQTISRAAAEKLFGNMEGDCPSDWKTDSTCRMVTSESKNVKLTVSNVLKEIKILNIFGVIKGFVEPDHYVVVGAQRDAWGPGAAKSGVGTALLLKLAQMFSDMVLKDGFQPSRSIIFASWSAGDFGSVGATEWLEGYLSSLHLKAFTYINLDKAVLGTSNFKVSASPLLYTLIEKTMQNVKHPVTGQFLYQDSNWASKVEKLTLDNAAFPFLAYSGIPAVSFCFCEDTDYPYLGTTMDTYKELIERIPELNKVARAAAEVAGQFVIKLTHDVELNLDYERYNSQLLSFVRDLNQYRADIKEMGLSLQWLYSARGDFFRATSRLTTDFGNAEKTDRFVMKKLNDRVMRVEYHFLSPYVSPKESPFRHVFWGSGSHTLPALLENLKLRKQNNGAFNETLFRNQLALATWTIQGAANALSGDVWDIDNEF'
    gfap = 'MERRRITSAARRSYVSSGEMMVGGLAPGRRLGPGTRLSLARMPPPLPTRVDFSLAGALNAGFKETRASERAEMMELNDRFASYIEKVRFLEQQNKALAAELNQLRAKEPTKLADVYQAELRELRLRLDQLTANSARLEVERDNLAQDLATVRQKLQDETNLRLEAENNLAAYRQEADEATLARLDLERKIESLEEEIRFLRKIHEEEVRELQEQLARQQVHVELDVAKPDLTAALKEIRTQYEAMASSNMHEAEEWYRSKFADLTDAAARNAELLRQAKHEANDYRRQLQSLTCDLESLRGTNESLERQMREQEERHVREAASYQEALARLEEEGQSLKDEMARHLQEYQDLLNVKLALDIEIATYRKLLEGEENRITIPVQTFSNLQIRETSLDTKSVSEGHLKRNIVVKTVEMRDGEVIKESKQEHKDVM'
    glp1 = 'MAGAPGPLRLALLLLGMVGRAGPRPQGATVSLWETVQKWREYRRQCQRSLTEDPPPATDLFCNRTFDEYACWPDGEPGSFVNVSCPWYLPWASSVPQGHVYRFCTAEGLWLQKDNSSLPWRDLSECEESKRGERSSPEEQLLFLYIIYTVGYALSFSALVIASAILLGFRHLHCTRNYIHLNLFASFILRALSVFIKDAALKWMYSTAAQQHQWDGLLSYQDSLSCRLVFLLMQYCVAANYYWLLVEGVYLYTLLAFSVLSEQWIFRLYVSIGWGVPLLFVVPWGIVKYLYEDEGCWTRNSNMNYWLIIRLPILFAIGVNFLIFVRVICIVVSKLKANLMCKTDIKCRLAKSTLTLIPLLGTHEVIFAFVMDEHARGTLRFIKLFTELSFTSFQGLMVAILYCFVNNEVQLEFRKSWERWRLEHLHIQRDSSMKPLKCPTSSLSSGATAGSSMYTATCQASCS'
    glast = 'MTKSNGEEPKMGGRMERFQQGVRKRTLLAKKKVQNITKEDVKSYLFRNAFVLLTVTAVIVGTILGFTLRPYRMSYREVKYFSFPGELLMRMLQMLVLPLIISSLVTGMAALDSKASGKMGMRAVVYYMTTTIIAVVIGIIIVIIIHPGKGTKENMHREGKIVRVTAADAFLDLIRNMFPPNLVEACFKQFKTNYEKRSFKVPIQANETLVGAVINNVSEAMETLTRITEELVPVPGSVNGVNALGLVVFSMCFGFVIGNMKEQGQALREFFDSLNEAIMRLVAVIMWYAPVGILFLIAGKIVEMEDMGVIGGQLAMYTVTVIVGLLIHAVIVLPLLYFLVTRKNPWVFIGGLLQALITALGTSSSSATLPITFKCLEENNGVDKRVTRFVLPVGATINMDGTALYEALAAIFIAQVNNFELNFGQIITISITATAASIGAAGIPQAGLVTMVIVLTSVGLPTDDITLIIAVDWFLDRLRTTTNVLGDSLGAGIVEHLSRHELKNRDVEMGNSVIEENEMKKPYQLIAQDNETEKPIDSETKM'
    ncam = 'LQTKDLIWTLFFLGTAVSLQVDIVPSQGEISVGESKFFLCQVAGDAKDKDISWFSPNGEKLTPNQQRISVVWNDDSSSTLTIYNANIDDAGIYKCVVTGEDGSESEATVNVKIFQKLMFKNAPTPQEFREGEDAVIVCDVVSSLPPTIIWKHKGRDVILKKDVRFIVLSNNYLQIRGIKKTDEGTYRCEGRILARGEINFKDIQVIVNVPPTIQARQNIVNATANLGQSVTLVCDAEGFPEPTMSWTKDGEQIEQEEDDEKYIFSDDSSQLTIKKVDKNDEAEYICIAENKAGEQDATIHLKVFAKPKITYVENQTAMELEEQVTLTCEASGDPIPSITWRTSTRNISSEEKASWTRPEKQETLDGHMVVRSHARVSSLTLKSIQYTDAGEYICTASNTIGQDSQSMYLEVQYAPKLQGPVAVYTWEGNQVNITCEVFAYPSATISWFRDGQLLPSSNYSNIKIYNTPSASYLEVTPDSENDFGNYNCTAVNRIGQESLEFILVQADTPSSPSIDQVEPYSSTAQVQFDEPEATGGVPILKYKAEWRAVGEEVWHSKWYDAKEASMEGIVTIVGLKPETTYAVRLAALNGKGLGEISAASEF'

    binding = BindingAffinity(amhr, device='cuda:0')
    # Query peptide given as SMILES.
    seq = ['N[C@@H](CC(=O)N)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CCSC(F)F)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](Cc1ccc(O)cc1)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CCC(=O)N)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CC(C)C)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CCC(=O)N)C(=O)N[C@@H](CCC(=O)N)C(=O)N[C@@H](CC(=O)N)C(=O)O']

    scores = binding(seq)
    print(scores)
    print(len(scores))


if __name__ == '__main__':
    unittest()