| | import pandas as pd |
| | import numpy as np |
| | import torch |
| | import torch.nn as nn |
| | import torch.optim as optim |
| | from torch.utils.data import Dataset, DataLoader |
| | from transformers import BertConfig, BertModel, AutoTokenizer |
| | from rdkit import Chem, RDLogger |
| | from rdkit.Chem.Scaffolds import MurckoScaffold |
| | import copy |
| | from tqdm import tqdm |
| | import os |
| | from sklearn.metrics import roc_auc_score, root_mean_squared_error, mean_absolute_error |
| | from itertools import compress |
| | from collections import defaultdict |
| | from sklearn.metrics.pairwise import cosine_similarity |
# Silence RDKit's per-molecule parse warnings; MoleculeNet SMILES contain
# many entries that fail to parse and would otherwise flood the log.
RDLogger.DisableLog('rdApp.*')

# Trade a little float32 matmul precision for speed (TF32 on supported GPUs).
torch.set_float32_matmul_precision('high')
| |
|
| | |
class SmilesEnumerator:
    """Generates randomized (non-canonical) SMILES strings for data augmentation."""

    def randomize_smiles(self, smiles):
        """Return a randomized SMILES for ``smiles``.

        Falls back to the input unchanged when RDKit cannot parse or
        rewrite the molecule.
        """
        try:
            mol = Chem.MolFromSmiles(smiles)
            return Chem.MolToSmiles(mol, doRandom=True, canonical=False) if mol else smiles
        except Exception:
            # Bug fix: a bare `except:` would also swallow KeyboardInterrupt
            # and SystemExit; only RDKit/runtime errors should be absorbed.
            return smiles
| |
|
| |
|
def compute_embedding_similarity(encoder, smiles_list, tokenizer, device, max_len=256):
    """Measure how invariant ``encoder`` is to SMILES randomization.

    Each SMILES is embedded twice — once as given and once after atom-order
    randomization via :class:`SmilesEnumerator` — and the cosine similarity
    of the two embeddings is computed per molecule.

    Args:
        encoder: callable module mapping (input_ids, attention_mask) to an
            embedding tensor.
        smiles_list: iterable of SMILES strings.
        tokenizer: HuggingFace-style tokenizer returning ``input_ids`` and
            ``attention_mask``.
        device: torch device the encoder lives on.
        max_len: tokenizer padding/truncation length.

    Returns:
        np.ndarray of shape (len(smiles_list),) with one cosine similarity
        per input molecule.
    """
    encoder.eval()
    enumerator = SmilesEnumerator()

    embeddings_orig = []
    embeddings_aug = []

    with torch.no_grad():
        for smi in smiles_list:
            encoding_orig = tokenizer(
                smi,
                truncation=True,
                padding='max_length',
                max_length=max_len,
                return_tensors='pt'
            )

            smi_aug = enumerator.randomize_smiles(smi)
            encoding_aug = tokenizer(
                smi_aug,
                truncation=True,
                padding='max_length',
                max_length=max_len,
                return_tensors='pt'
            )

            input_ids_orig = encoding_orig.input_ids.to(device)
            attention_mask_orig = encoding_orig.attention_mask.to(device)
            input_ids_aug = encoding_aug.input_ids.to(device)
            attention_mask_aug = encoding_aug.attention_mask.to(device)

            emb_orig = encoder(input_ids_orig, attention_mask_orig).cpu().numpy().flatten()
            emb_aug = encoder(input_ids_aug, attention_mask_aug).cpu().numpy().flatten()

            embeddings_orig.append(emb_orig)
            embeddings_aug.append(emb_aug)

    a = np.array(embeddings_orig)
    b = np.array(embeddings_aug)

    # Improvement: row-wise cosine similarity in one vectorized pass instead
    # of one sklearn cosine_similarity() call (two 1x1 allocations) per pair.
    numerator = np.sum(a * b, axis=1)
    denominator = np.linalg.norm(a, axis=1) * np.linalg.norm(b, axis=1)
    # Zero-norm embeddings map to similarity 0, matching sklearn's convention.
    return np.divide(numerator, denominator,
                     out=np.zeros_like(numerator), where=denominator > 0)
| |
|
| | |
def load_lists_from_url(data):
    """Download a MoleculeNet dataset and return ``(smiles, labels)``.

    Args:
        data: dataset key, one of 'bbbp', 'clintox', 'hiv', 'sider', 'esol',
            'freesolv', 'lipophicility', 'tox21', 'bace', 'qm8'.

    Returns:
        Tuple of (pd.Series of SMILES, labels as a Series or DataFrame —
        multi-task datasets return a DataFrame with one column per task).

    Raises:
        ValueError: if ``data`` is not a recognized dataset key.
    """
    if data == 'bbbp':
        df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/BBBP.csv')
        smiles, labels = df.smiles, df.p_np
    elif data == 'clintox':
        df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/clintox.csv.gz', compression='gzip')
        smiles = df.smiles
        labels = df.drop(['smiles'], axis=1)
    elif data == 'hiv':
        df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/HIV.csv')
        smiles, labels = df.smiles, df.HIV_active
    elif data == 'sider':
        df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/sider.csv.gz', compression='gzip')
        smiles = df.smiles
        labels = df.drop(['smiles'], axis=1)
    elif data == 'esol':
        df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/delaney-processed.csv')
        smiles = df.smiles
        labels = df['ESOL predicted log solubility in mols per litre']
    elif data == 'freesolv':
        df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/SAMPL.csv')
        smiles = df.smiles
        labels = df.calc
    elif data == 'lipophicility':
        df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/Lipophilicity.csv')
        smiles, labels = df.smiles, df['exp']
    elif data == 'tox21':
        df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/tox21.csv.gz', compression='gzip')
        # Tox21 has many missing task labels; keep only complete rows.
        df = df.dropna(axis=0, how='any').reset_index(drop=True)
        smiles = df.smiles
        labels = df.drop(['mol_id', 'smiles'], axis=1)
    elif data == 'bace':
        df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/bace.csv')
        smiles, labels = df.mol, df.Class
    elif data == 'qm8':
        df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/qm8.csv')
        df = df.dropna(axis=0, how='any').reset_index(drop=True)
        smiles = df.smiles
        labels = df.drop(['smiles', 'E2-PBE0.1', 'E1-PBE0.1', 'f1-PBE0.1', 'f2-PBE0.1'], axis=1)
    else:
        # Bug fix: an unknown key previously fell through to the return and
        # raised UnboundLocalError; fail fast with a clear message instead.
        raise ValueError(f"Unknown dataset: {data}")
    return smiles, labels
| |
|
| | |
class ScaffoldSplitter:
    """Splits a MoleculeNet dataset into train/val/test by Bemis–Murcko scaffold.

    Whole scaffold groups are assigned to a single split, so structurally
    similar molecules never leak across splits.
    """

    def __init__(self, data, seed, train_frac=0.8, val_frac=0.1, test_frac=0.1, include_chirality=True):
        # data: dataset key understood by load_lists_from_url (e.g. 'bbbp').
        # seed: seeds the RandomState used to shuffle scaffold groups.
        self.data = data
        self.seed = seed
        self.include_chirality = include_chirality
        self.train_frac = train_frac
        self.val_frac = val_frac
        self.test_frac = test_frac

    def generate_scaffold(self, smiles):
        """Return the Bemis–Murcko scaffold SMILES for one molecule.

        Assumes ``smiles`` parses; scaffold_split pre-filters invalid SMILES
        before calling this (MolFromSmiles returning None would raise here).
        """
        mol = Chem.MolFromSmiles(smiles)
        scaffold = MurckoScaffold.MurckoScaffoldSmiles(mol=mol, includeChirality=self.include_chirality)
        return scaffold

    def scaffold_split(self):
        """Return (train_idx, val_idx, test_idx) index lists into the raw dataset."""
        smiles, labels = load_lists_from_url(self.data)
        # Boolean mask over rows, initially all False; flipped True for usable rows.
        non_null = np.ones(len(smiles)) == 0

        if self.data in {'tox21', 'sider', 'clintox'}:
            # Multi-task datasets: also require every label column to be present.
            for i in range(len(smiles)):
                if Chem.MolFromSmiles(smiles[i]) and labels.loc[i].isnull().sum() == 0:
                    non_null[i] = 1
        else:
            for i in range(len(smiles)):
                if Chem.MolFromSmiles(smiles[i]):
                    non_null[i] = 1

        # Keep (original_index, smiles) pairs for the valid rows only, so the
        # returned indices address the unfiltered dataset.
        smiles_list = list(compress(enumerate(smiles), non_null))
        rng = np.random.RandomState(self.seed)

        # Group original indices by their scaffold string.
        scaffolds = defaultdict(list)
        for i, sms in smiles_list:
            scaffold = self.generate_scaffold(sms)
            scaffolds[scaffold].append(i)

        scaffold_sets = list(scaffolds.values())
        rng.shuffle(scaffold_sets)
        n_total_val = int(np.floor(self.val_frac * len(smiles_list)))
        n_total_test = int(np.floor(self.test_frac * len(smiles_list)))
        train_idx, val_idx, test_idx = [], [], []

        # Greedily fill val, then test, up to their quotas; every scaffold
        # group that fits neither quota goes to train.
        # NOTE(review): train_frac is not used directly — train receives the
        # remainder, so the realized train fraction can exceed train_frac.
        for scaffold_set in scaffold_sets:
            if len(val_idx) + len(scaffold_set) <= n_total_val:
                val_idx.extend(scaffold_set)
            elif len(test_idx) + len(scaffold_set) <= n_total_test:
                test_idx.extend(scaffold_set)
            else:
                train_idx.extend(scaffold_set)
        return train_idx, val_idx, test_idx
| |
|
| | |
def random_split_indices(n, seed=42, train_frac=0.8, val_frac=0.1, test_frac=0.1):
    """Randomly partition ``range(n)`` into train/val/test index lists.

    ``test_frac`` is accepted for interface symmetry but the test split simply
    receives the remainder after train and val are taken.

    Bug fix: uses a local RandomState instead of np.random.seed, so the
    caller's global numpy RNG state is not clobbered. The legacy RandomState
    produces the same permutation as the previous seed+permutation calls.

    Returns:
        Tuple of three lists of ints: (train_idx, val_idx, test_idx).
    """
    rng = np.random.RandomState(seed)
    indices = rng.permutation(n)
    n_train = int(n * train_frac)
    n_val = int(n * val_frac)
    train_idx = indices[:n_train]
    val_idx = indices[n_train:n_train + n_val]
    test_idx = indices[n_train + n_val:]
    return train_idx.tolist(), val_idx.tolist(), test_idx.tolist()
| |
|
| | |
class MoleculeDataset(Dataset):
    """Torch dataset that tokenizes SMILES strings lazily, per item.

    Labels may be a pandas Series (single task) or DataFrame (multi-task);
    either way ``item['labels']`` is a 1-D float tensor.
    """

    def __init__(self, smiles_list, labels, tokenizer, max_len=512):
        self.smiles_list = smiles_list
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.smiles_list)

    def __getitem__(self, idx):
        encoded = self.tokenizer(
            self.smiles_list[idx],
            truncation=True,
            padding='max_length',
            max_length=self.max_len,
            return_tensors='pt'
        )
        # Drop the batch dimension the tokenizer adds via return_tensors='pt'.
        sample = {name: tensor.squeeze(0) for name, tensor in encoded.items()}

        target = self.labels.iloc[idx]
        # Multi-task labels arrive as a Series row; scalars are wrapped so the
        # label tensor is always 1-D.
        values = (target.values.astype(np.float32)
                  if isinstance(target, pd.Series)
                  else np.array([target], dtype=np.float32))
        sample['labels'] = torch.tensor(values, dtype=torch.float)
        return sample
| |
|
| | |
def global_ap(x):
    """Global average pooling over dim 1 (the token/sequence dimension).

    For an input of shape (batch, seq, hidden) this returns (batch, hidden).
    """
    batch, seq = x.size(0), x.size(1)
    flattened = x.view(batch, seq, -1)
    return flattened.mean(dim=1)
| |
|
class SimSonEncoder(nn.Module):
    """BERT-based SMILES encoder producing a fixed-size embedding.

    Token hidden states are mean-pooled over the sequence and projected
    to a vector of length ``max_len``.
    """

    def __init__(self, config: BertConfig, max_len: int, dropout: float = 0.1):
        super().__init__()
        self.config = config
        self.max_len = max_len
        self.bert = BertModel(config, add_pooling_layer=False)
        self.linear = nn.Linear(config.hidden_size, max_len)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_ids, attention_mask=None):
        # Derive the mask from padding token ids when the caller omits it.
        if attention_mask is None:
            attention_mask = input_ids.ne(self.config.pad_token_id)
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        dropped = self.dropout(encoded.last_hidden_state)
        return self.linear(global_ap(dropped))
| |
|
class SimSonClassifier(nn.Module):
    """Linear classification/regression head on top of a SimSonEncoder.

    Emits raw logits (no sigmoid/softmax); pair with BCEWithLogitsLoss or
    MSELoss accordingly.
    """

    def __init__(self, encoder: SimSonEncoder, num_labels: int, dropout=0.1):
        super(SimSonClassifier, self).__init__()
        self.encoder = encoder
        self.clf = nn.Linear(encoder.max_len, num_labels)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_ids, attention_mask=None):
        x = self.encoder(input_ids, attention_mask)
        x = self.relu(self.dropout(x))
        logits = self.clf(x)
        return logits

    def load_encoder_params(self, state_dict_path):
        """Load pretrained weights into the encoder only.

        Bug fix: map_location='cpu' so checkpoints saved on GPU also load on
        CPU-only hosts; load_state_dict copies tensors onto the encoder's
        own device afterwards.
        """
        state_dict = torch.load(state_dict_path, map_location='cpu')
        self.encoder.load_state_dict(state_dict)
        print("Pretrained encoder parameters loaded.")
| |
|
| | |
def get_criterion(task_type, num_labels):
    """Return the loss module for ``task_type``.

    ``num_labels`` is unused (both losses broadcast over label columns) but
    kept for interface stability.

    Raises:
        ValueError: if ``task_type`` is neither 'classification' nor 'regression'.
    """
    factories = {
        'classification': nn.BCEWithLogitsLoss,
        'regression': nn.MSELoss,
    }
    if task_type not in factories:
        raise ValueError(f"Unknown task type: {task_type}")
    return factories[task_type]()
| |
|
def train_epoch(model, dataloader, optimizer, scheduler, criterion, device):
    """Run one training epoch and return the mean batch loss.

    Batches are dicts whose 'labels' entry is the target; every other entry
    is forwarded to the model as a keyword argument.

    Bug fix: the scheduler was accepted but never stepped, so the learning
    rate never decayed; epoch-level schedulers such as StepLR expect exactly
    one step() per epoch, done here at the end of the loop.
    """
    model.train()
    total_loss = 0.0
    for batch in dataloader:
        inputs = {k: v.to(device) for k, v in batch.items() if k != 'labels'}
        labels = batch['labels'].to(device)
        optimizer.zero_grad()
        outputs = model(**inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    scheduler.step()
    return total_loss / len(dataloader)
| |
|
def eval_epoch(model, dataloader, criterion, device):
    """Return the mean batch loss over ``dataloader`` with gradients disabled."""
    model.eval()
    running = 0.0
    with torch.no_grad():
        for batch in dataloader:
            labels = batch['labels'].to(device)
            features = {key: tensor.to(device)
                        for key, tensor in batch.items() if key != 'labels'}
            predictions = model(**features)
            running += criterion(predictions, labels).item()
    return running / len(dataloader)
| |
|
| | def test_model(model, dataloader, device): |
| | model.eval() |
| | all_preds, all_labels = [], [] |
| | with torch.no_grad(): |
| | for batch in dataloader: |
| | inputs = {k: v.to(device) for k, v in batch.items() if k != 'labels'} |
| | labels = batch['labels'] |
| | outputs = model(**inputs) |
| | preds = torch.sigmoid(outputs) |
| | all_preds.append(preds.cpu().numpy()) |
| | all_labels.append(labels.numpy()) |
| | return np.concatenate(all_preds), np.concatenate(all_labels) |
| |
|
def calc_val_metrics(model, dataloader, criterion, device, task_type):
    """Validation pass returning (mean batch loss, metric).

    For classification the metric is the macro ROC-AUC over sigmoid
    probabilities (0.0 when it cannot be computed, e.g. a single-class
    fold); for regression the metric is None.
    """
    model.eval()
    loss_sum = 0.0
    collected_preds, collected_labels = [], []
    with torch.no_grad():
        for batch in dataloader:
            labels = batch['labels'].to(device)
            features = {key: value.to(device)
                        for key, value in batch.items() if key != 'labels'}
            outputs = model(**features)
            loss_sum += criterion(outputs, labels).item()
            if task_type == 'classification':
                collected_preds.append(torch.sigmoid(outputs).cpu().numpy())
            else:
                collected_preds.append(outputs.cpu().numpy())
            collected_labels.append(labels.cpu().numpy())
    mean_loss = loss_sum / len(dataloader)
    if task_type != 'classification':
        return mean_loss, None
    y_true = np.concatenate(collected_labels)
    y_prob = np.concatenate(collected_preds)
    try:
        auc = roc_auc_score(y_true, y_prob, average='macro')
    except Exception:
        # Degenerate fold (e.g. only one class present) — report 0.
        auc = 0.0
    return mean_loss, auc
| |
|
| | |
def main():
    """Fine-tune the pretrained SimSon encoder on MoleculeNet benchmarks.

    For each configured dataset: split, build loaders, train with early
    stopping on validation loss, evaluate on the held-out test split, and
    report an embedding-invariance similarity score.
    """
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {DEVICE}")

    # Datasets to evaluate; extend with other keys from load_lists_from_url.
    DATASETS_TO_RUN = {
        'clintox': {'task_type': 'classification', 'num_labels': 2, 'split': 'random'},
    }
    PATIENCE = 15  # epochs without validation improvement before stopping
    EPOCHS = 50
    LEARNING_RATE = 1e-4
    BATCH_SIZE = 16
    MAX_LEN = 512

    TOKENIZER = AutoTokenizer.from_pretrained('DeepChem/ChemBERTa-77M-MTR')
    ENCODER_CONFIG = BertConfig(
        vocab_size=TOKENIZER.vocab_size,
        hidden_size=768,
        num_hidden_layers=4,
        num_attention_heads=12,
        intermediate_size=2048,
        max_position_embeddings=512
    )

    aggregated_results = {}

    for name, info in DATASETS_TO_RUN.items():
        print(f"\n{'='*20} Processing Dataset: {name.upper()} ({info['split']} split) {'='*20}")
        smiles, labels = load_lists_from_url(name)

        if info.get('split', 'scaffold') == 'scaffold':
            splitter = ScaffoldSplitter(data=name, seed=42)
            train_idx, val_idx, test_idx = splitter.scaffold_split()
        elif info['split'] == 'random':
            train_idx, val_idx, test_idx = random_split_indices(len(smiles), seed=42)
        else:
            raise ValueError(f"Unknown split type for {name}: {info['split']}")

        train_smiles = smiles.iloc[train_idx].reset_index(drop=True)
        train_labels = labels.iloc[train_idx].reset_index(drop=True)
        val_smiles = smiles.iloc[val_idx].reset_index(drop=True)
        val_labels = labels.iloc[val_idx].reset_index(drop=True)
        test_smiles = smiles.iloc[test_idx].reset_index(drop=True)
        test_labels = labels.iloc[test_idx].reset_index(drop=True)
        print(f"Data split - Train: {len(train_smiles)}, Val: {len(val_smiles)}, Test: {len(test_smiles)}")

        train_dataset = MoleculeDataset(train_smiles, train_labels, TOKENIZER, MAX_LEN)
        val_dataset = MoleculeDataset(val_smiles, val_labels, TOKENIZER, MAX_LEN)
        test_dataset = MoleculeDataset(test_smiles, test_labels, TOKENIZER, MAX_LEN)

        train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
        test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

        encoder = SimSonEncoder(ENCODER_CONFIG, 512)
        # NOTE(review): compiling before loading means the checkpoint's state
        # dict must carry torch.compile's `_orig_mod.` key prefix — confirm
        # the checkpoint was saved from a compiled encoder.
        encoder = torch.compile(encoder)
        model = SimSonClassifier(encoder, num_labels=info['num_labels']).to(DEVICE)
        model.load_encoder_params('../simson_checkpoints/checkpoint_best_model.bin')
        criterion = get_criterion(info['task_type'], info['num_labels'])
        optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=0.0024)
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.59298)

        # Bug fix: early stopping tracks the *minimum* validation loss, so it
        # must start at +inf (was -inf).
        best_val_loss = float('inf')
        best_model_state = None
        current_patience = 0
        for epoch in range(EPOCHS):
            train_loss = train_epoch(model, train_loader, optimizer, scheduler, criterion, DEVICE)
            # Bug fix: pass DEVICE instead of the hard-coded 'cuda' string,
            # so CPU-only runs do not crash.
            val_loss, val_metric = calc_val_metrics(model, val_loader, criterion, DEVICE, info['task_type'])
            # val_metric is None for regression tasks; guard the format spec.
            metric_str = f"{val_metric:.4f}" if val_metric is not None else "n/a"
            print(f"Epoch {epoch+1}/{EPOCHS} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | ROC AUC: {metric_str}")

            # Bug fix: the original compared val_metric against val_loss (two
            # unrelated quantities); select the best model on validation loss.
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                best_model_state = copy.deepcopy(model.state_dict())
                print(f" -> New best model saved with validation loss: {best_val_loss:.4f}")
                current_patience = 0
            else:
                current_patience += 1
                if current_patience >= PATIENCE:
                    print(f'Early stopping after {PATIENCE} epochs without improvement')
                    break

        print("\nTesting with the best model...")
        if best_model_state is not None:
            model.load_state_dict(best_model_state)
        test_loss = eval_epoch(model, test_loader, criterion, DEVICE)
        print(f'Test loss: {test_loss}')
        # NOTE(review): test_model applies a sigmoid to every output, which is
        # only meaningful for classification logits — verify before trusting
        # the RMSE/MAE figures for regression datasets.
        test_preds, test_true = test_model(model, test_loader, DEVICE)

        aggregated_results[name] = {
            'best_val_loss': best_val_loss,
            'test_predictions': test_preds,
            'test_labels': test_true
        }
        print(f"Finished testing for {name}.")
        test_smiles_list = list(test_smiles)
        similarities = compute_embedding_similarity(
            model.encoder, test_smiles_list, TOKENIZER, DEVICE, MAX_LEN
        )
        print(f"Similarity score: {similarities.mean():.4f}")
        # Deliberately disabled: change the guard to a real dataset key to
        # re-enable saving the fine-tuned encoder.
        if name == 'do_not_save':
            torch.save(model.encoder.state_dict(), 'moleculenet_clintox_encoder.bin')

    print(f"\n{'='*20} AGGREGATED RESULTS {'='*20}")
    for name, result in aggregated_results.items():
        if name in ['bbbp', 'tox21', 'sider', 'clintox', 'hiv', 'bace']:
            auc = roc_auc_score(result['test_labels'], result['test_predictions'], average='macro')
            print(f'{name} ROC AUC: {auc}')

        if name in ['lipophicility', 'esol', 'qm8']:
            rmse = root_mean_squared_error(result['test_labels'], result['test_predictions'])
            mae = mean_absolute_error(result['test_labels'], result['test_predictions'])
            print(f'{name} MAE: {mae}')
            print(f'{name} RMSE: {rmse}')

    print("\nScript finished.")
| |
|
# Script entry point: run all configured benchmark fine-tunes.
if __name__ == '__main__':
    main()
| |
|