| import sys |
| import os |
| sys.path.append('/scratch/pranamlab/tong/ReDi_discrete/smiles') |
| import xgboost as xgb |
| import torch |
| import numpy as np |
| from transformers import AutoModelForMaskedLM |
| from smiles_tokenizer.my_tokenizers import SMILES_SPE_Tokenizer |
| import warnings |
| import numpy as np |
| from rdkit import Chem, rdBase, DataStructs |
|
|
|
|
# Quiet RDKit's error-level log output, then ignore deprecation, user and
# future warnings process-wide (filters are installed in this order, each
# prepended to the warnings filter list).
rdBase.DisableLog('rdApp.error')
for _ignored_category in (DeprecationWarning, UserWarning, FutureWarning):
    warnings.filterwarnings("ignore", category=_ignored_category)
del _ignored_category  # don't leave the loop variable as a module global
|
|
class Solubility:
    """Solubility scorer: PeptideCLM embeddings fed into an XGBoost booster.

    Each SMILES string is tokenized with ``SMILES_SPE_Tokenizer``, embedded by
    the ``.roformer`` backbone of a masked-LM checkpoint (mean-pooled over the
    token axis), and scored by a pre-trained ``xgb.Booster``.
    """

    # Default artifact locations. These were previously hard-coded inside
    # __init__; they can now be overridden through keyword arguments.
    DEFAULT_CHECKPOINT = '/scratch/pranamlab/tong/ReDi_discrete/smiles/scoring/checkpoints/solubility-xgboost.json'
    DEFAULT_EMB_MODEL = 'aaronfeller/PeptideCLM-23M-all'
    DEFAULT_VOCAB = '/scratch/pranamlab/tong/ReDi_discrete/smiles/smiles_tokenizer/old_vocab.txt'
    DEFAULT_SPLITS = '/scratch/pranamlab/tong/ReDi_discrete/smiles/smiles_tokenizer/old_splits.txt'

    def __init__(self, device, *, checkpoint_path=DEFAULT_CHECKPOINT,
                 emb_model_name=DEFAULT_EMB_MODEL, vocab_path=DEFAULT_VOCAB,
                 splits_path=DEFAULT_SPLITS):
        """Load the booster, the embedding backbone and the tokenizer.

        Args:
            device: torch device (e.g. ``'cpu'``, ``'cuda:0'``) on which the
                embedding model runs.
            checkpoint_path: XGBoost model file passed to ``xgb.Booster``.
            emb_model_name: Hugging Face id or local path of the masked-LM
                whose ``.roformer`` backbone produces the embeddings.
            vocab_path: vocabulary file for ``SMILES_SPE_Tokenizer``.
            splits_path: splits file for ``SMILES_SPE_Tokenizer``.
        """
        self.predictor = xgb.Booster(model_file=checkpoint_path)
        # Only the encoder is used (the MLM head is discarded). Calling
        # .eval() makes inference mode (no dropout) explicit instead of
        # relying on from_pretrained's default.
        self.emb_model = (
            AutoModelForMaskedLM.from_pretrained(emb_model_name)
            .roformer.to(device)
            .eval()
        )
        self.tokenizer = SMILES_SPE_Tokenizer(vocab_path, splits_path)
        self.device = device

    def generate_embeddings(self, sequences):
        """Embed each SMILES string as the token-mean of its last hidden state.

        Sequences are encoded one at a time, so no padding tokens enter the mean.

        Args:
            sequences: iterable of SMILES strings. A bare ``str`` is treated as
                a single sequence rather than being split into characters.

        Returns:
            ``np.ndarray`` with one embedding row per sequence. It is an empty
            1-D array when ``sequences`` is empty.
        """
        if isinstance(sequences, str):
            sequences = [sequences]
        embeddings = []
        with torch.no_grad():
            for sequence in sequences:
                tokenized = self.tokenizer(sequence, return_tensors='pt').to(self.device)
                output = self.emb_model(**tokenized)
                # assumes last_hidden_state is (1, seq_len, hidden) — TODO confirm;
                # mean over dim 1 then squeeze(0) yields one vector per sequence.
                pooled = output.last_hidden_state.mean(dim=1).squeeze(0)
                embeddings.append(pooled.cpu().numpy())
        return np.array(embeddings)

    def get_scores(self, input_seqs: list):
        """Predict one solubility score per SMILES string.

        Args:
            input_seqs: list of SMILES strings. A bare ``str`` is treated as a
                single sequence.

        Returns:
            ``np.ndarray`` of scores from the XGBoost booster, in input order.
            For empty input this is an empty float32 array.
        """
        if isinstance(input_seqs, str):
            input_seqs = [input_seqs]
        if len(input_seqs) == 0:
            # Skip the model entirely. float32 is used to match the dtype that
            # XGBoost predictions normally have, so empty and non-empty
            # results agree.
            return np.zeros(0, dtype=np.float32)

        features = self.generate_embeddings(input_seqs)
        # Sanitize before building the DMatrix: NaN -> 0 (nan_to_num also maps
        # +/-inf to large finite values), then clamp into float32 range.
        features = np.nan_to_num(features, nan=0.)
        features = np.clip(features, np.finfo(np.float32).min, np.finfo(np.float32).max)

        return self.predictor.predict(xgb.DMatrix(features))

    def __call__(self, input_seqs: list):
        """Score ``input_seqs`` and return the scores as a ``torch.Tensor``."""
        return torch.tensor(self.get_scores(input_seqs))
| |
def unittest(device=None):
    """Smoke test: score one hard-coded peptide SMILES and print the result.

    NOTE: this name shadows the stdlib ``unittest`` module if that is ever
    imported into this file.

    Args:
        device: torch device for the embedding model. When ``None`` it is
            ``'cuda'`` if CUDA is available, else ``'cpu'``. It was previously
            hard-coded to ``'cuda:6'``, which fails on hosts without that GPU.

    Returns:
        The score tensor produced by ``Solubility.__call__``.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    solubility = Solubility(device=device)
    seq = ["N[C@@H](CC(=O)N)C(=O)N[C@@H](CCC(=O)N)C(=O)N[C@@H](CCC(=O)N)C(=O)N[C@@H]([C@H](O)C)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CCC(=O)N)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CCC(=O)N)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](C)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CC(=O)N)C(=O)N[C@@H](CCC(=O)N)C(=O)O"]
    scores = solubility(input_seqs=seq)
    print(scores)
    return scores


if __name__ == '__main__':
    unittest()