| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from torch.utils.data import Dataset, DataLoader |
| | import pandas as pd |
| | import numpy as np |
| | from sklearn.metrics import roc_auc_score, average_precision_score |
| | from transformers import BertModel, BertConfig |
| | import os |
| | import json |
| | from collections import defaultdict |
| | from rdkit import Chem |
| | from rdkit.Chem import Scaffolds |
| | import warnings |
| | warnings.filterwarnings('ignore') |
| | from transformers import AutoTokenizer |
| |
|
| | |
def global_ap(x, dim=1):
    """Global average pooling: mean of `x` along `dim` (default: the token axis)."""
    return x.mean(dim=dim)
| |
|
class SimSonClassifier(nn.Module):
    """BERT encoder (pooler disabled) + mean pooling + linear head.

    Produces one logit per label, suitable for BCEWithLogitsLoss on
    single- or multi-label binary classification tasks.
    """

    def __init__(self, config: BertConfig, max_len: int, num_labels: int, dropout: float = 0.1):
        super().__init__()
        self.config = config
        self.max_len = max_len
        self.num_labels = num_labels

        self.bert = BertModel(config, add_pooling_layer=False)
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask=None):
        # Build a mask from non-zero token ids when none is supplied
        # (assumes the pad token id is 0 — TODO confirm with the tokenizer).
        mask = input_ids.ne(0) if attention_mask is None else attention_mask

        encoded = self.bert(input_ids=input_ids, attention_mask=mask)
        tokens = self.dropout(encoded.last_hidden_state)

        # NOTE(review): global_ap averages over ALL positions, padding
        # included — the attention mask is not applied to the pooling.
        pooled = global_ap(tokens)
        return self.classifier(pooled)

    def load_encoder_weights(self, encoder_path):
        """Load pretrained SimSonEncoder weights into the classifier."""
        state = torch.load(encoder_path, map_location='cpu')

        # Keep only the encoder-side tensors; the fresh classifier head stays random.
        subset = {k: v for k, v in state.items()
                  if k.startswith('bert.') or k.startswith('dropout.')}

        # strict=False: the checkpoint has no weights for the new head.
        self.load_state_dict(subset, strict=False)
        print(f"Loaded encoder weights from {encoder_path}")
| |
|
| |
|
| |
|
def load_moleculenet_data(dataset_name):
    """Download a MoleculeNet benchmark and return (smiles, labels).

    Single-task sets (bbbp, hiv, bace) return labels as a pd.Series;
    multi-task sets (clintox, sider, tox21) return a pd.DataFrame with
    one column per task.

    Raises ValueError for an unrecognized dataset name.
    """
    if dataset_name == 'bbbp':
        df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/BBBP.csv')
        return df.smiles, df.p_np

    if dataset_name == 'clintox':
        df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/clintox.csv.gz', compression='gzip')
        return df.smiles, df.drop(['smiles'], axis=1)

    if dataset_name == 'hiv':
        df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/HIV.csv')
        return df.smiles, df.HIV_active

    if dataset_name == 'sider':
        df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/sider.csv.gz', compression='gzip')
        return df.smiles, df.drop(['smiles'], axis=1)

    if dataset_name == 'tox21':
        df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/tox21.csv.gz', compression='gzip')
        # tox21 has missing task labels; keep only fully-labeled rows.
        df = df.dropna(axis=0, how='any').reset_index(drop=True)
        return df.smiles, df.drop(['mol_id', 'smiles'], axis=1)

    if dataset_name == 'bace':
        df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/bace.csv')
        return df.mol, df.Class

    raise ValueError(f"Dataset {dataset_name} not supported")
| |
|
class MoleculeDataset(Dataset):
    """Tokenized SMILES strings paired with float label tensors.

    Args:
        smiles_list: sequence of SMILES strings.
        labels: pd.Series (single task) or pd.DataFrame (multi task).
        tokenizer: HF-style callable returning 'input_ids'/'attention_mask'.
        max_length: fixed padded/truncated sequence length.

    Each item yields a 1-D label tensor of length num_labels, so a batch
    collates to (batch, num_labels) — matching the classifier's logits.
    """

    def __init__(self, smiles_list, labels, tokenizer, max_length=512):
        self.smiles = smiles_list
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.smiles)

    def __getitem__(self, idx):
        smiles = self.smiles[idx]

        encoding = self.tokenizer(
            smiles,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )

        if isinstance(self.labels, pd.Series):
            # BUGFIX: wrap the scalar so the label has shape (1,). A 0-dim
            # tensor would collate to a (batch,) target, which
            # BCEWithLogitsLoss rejects against the model's (batch, 1) logits.
            label = torch.tensor([self.labels.iloc[idx]], dtype=torch.float32)
        else:
            label = torch.tensor(self.labels.iloc[idx].values, dtype=torch.float32)

        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': label
        }
| |
|
def get_loss_fn(num_labels):
    """Return the training criterion for `num_labels` binary targets.

    Both single- and multi-label tasks here are independent binary
    decisions, so BCEWithLogitsLoss applies in every case (the original
    if/else had two identical branches). `num_labels` is kept for
    interface compatibility with callers.
    """
    return nn.BCEWithLogitsLoss()
| |
|
def compute_metrics(predictions, labels, num_labels):
    """Compute ROC-AUC for single- or multi-label classification.

    Args:
        predictions: raw logits (sigmoid is applied here).
        labels: 0/1 targets, same leading shape as predictions.
        num_labels: number of tasks; >1 triggers per-column AUC averaging.

    A task whose labels contain only one class (AUC undefined) contributes
    the chance level 0.5. Bare `except:` replaced with `except ValueError`
    — the specific error roc_auc_score raises for one-class labels — so
    real bugs (KeyboardInterrupt, indexing errors) are no longer swallowed.
    """
    predictions = torch.sigmoid(predictions).cpu().numpy()
    labels = labels.cpu().numpy()

    if num_labels == 1:
        try:
            return {'roc_auc': roc_auc_score(labels, predictions)}
        except ValueError:
            return {'roc_auc': 0.5}  # only one class present: AUC undefined

    aucs = []
    for i in range(num_labels):
        try:
            aucs.append(roc_auc_score(labels[:, i], predictions[:, i]))
        except ValueError:
            aucs.append(0.5)  # chance level for a degenerate column
    return {'roc_auc': np.mean(aucs), 'individual_aucs': aucs}
| |
|
def train_epoch(model, dataloader, optimizer, loss_fn, device):
    """Run one optimization pass over `dataloader`; return the mean batch loss."""
    model.train()
    running = 0.0

    for batch in dataloader:
        ids = batch['input_ids'].to(device)
        mask = batch['attention_mask'].to(device)
        targets = batch['labels'].to(device)

        optimizer.zero_grad()
        loss = loss_fn(model(ids, mask), targets)
        loss.backward()
        optimizer.step()

        running += loss.item()

    return running / len(dataloader)
| |
|
def evaluate(model, dataloader, loss_fn, num_labels, device):
    """Run inference over `dataloader`; return (mean batch loss, metrics dict)."""
    model.eval()
    running = 0.0
    logit_chunks, label_chunks = [], []

    with torch.no_grad():
        for batch in dataloader:
            ids = batch['input_ids'].to(device)
            mask = batch['attention_mask'].to(device)
            targets = batch['labels'].to(device)

            logits = model(ids, mask)
            running += loss_fn(logits, targets).item()

            logit_chunks.append(logits)
            label_chunks.append(targets)

    metrics = compute_metrics(torch.cat(logit_chunks),
                              torch.cat(label_chunks),
                              num_labels)
    return running / len(dataloader), metrics
| |
|
def run_experiment(dataset_name, config, tokenizer, encoder_path=None,
                   batch_size=32, learning_rate=1e-4, epochs=50, device='cuda'):
    """Run a complete fine-tuning experiment for one MoleculeNet dataset.

    Downloads the dataset, scaffold-splits it, trains a SimSonClassifier
    (optionally warm-started from `encoder_path`), checkpoints on best
    validation loss, and evaluates that checkpoint on the test split.

    Returns a dict: dataset name, label count, test loss, test ROC-AUC,
    and per-task AUCs for multi-label sets (None otherwise).

    NOTE(review): relies on `scaffold_split`, which is not defined in this
    file — confirm it is provided elsewhere and returns
    (train_idx, valid_idx, test_idx) index collections.
    """
    print(f"\n=== Running experiment for {dataset_name.upper()} ===")

    smiles, labels = load_moleculenet_data(dataset_name)
    print(f"Loaded {len(smiles)} samples")

    # Series => single task; DataFrame => one column per task.
    num_labels = 1 if isinstance(labels, pd.Series) else labels.shape[1]
    print(f"Number of labels: {num_labels}")

    smiles_list = smiles.tolist()
    train_idx, valid_idx, test_idx = scaffold_split(smiles_list)

    print(f"Split sizes - Train: {len(train_idx)}, Valid: {len(valid_idx)}, Test: {len(test_idx)}")

    train_smiles = [smiles_list[i] for i in train_idx]
    valid_smiles = [smiles_list[i] for i in valid_idx]
    test_smiles = [smiles_list[i] for i in test_idx]

    # .iloc slices Series and DataFrame identically, so the original
    # duplicated if/else (two byte-identical branches) is collapsed.
    train_labels = labels.iloc[list(train_idx)]
    valid_labels = labels.iloc[list(valid_idx)]
    test_labels = labels.iloc[list(test_idx)]

    train_dataset = MoleculeDataset(train_smiles, train_labels, tokenizer)
    valid_dataset = MoleculeDataset(valid_smiles, valid_labels, tokenizer)
    test_dataset = MoleculeDataset(test_smiles, test_labels, tokenizer)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    model = SimSonClassifier(config, max_len=512, num_labels=num_labels).to(device)

    if encoder_path:
        model.load_encoder_weights(encoder_path)

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    loss_fn = get_loss_fn(num_labels)

    best_valid_loss = float('inf')
    best_model_path = f'best_{dataset_name}_model.pth'

    for epoch in range(epochs):
        train_loss = train_epoch(model, train_loader, optimizer, loss_fn, device)
        valid_loss, valid_metrics = evaluate(model, valid_loader, loss_fn, num_labels, device)

        # Checkpoint on validation loss (model selection).
        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            torch.save(model.state_dict(), best_model_path)

        if epoch % 10 == 0:
            print(f"Epoch {epoch}: Train Loss = {train_loss:.4f}, "
                  f"Valid Loss = {valid_loss:.4f}, Valid AUC = {valid_metrics['roc_auc']:.4f}")

    # Restore the best checkpoint before the final test pass. map_location
    # makes the load robust regardless of which device saved the tensors.
    model.load_state_dict(torch.load(best_model_path, map_location=device))
    test_loss, test_metrics = evaluate(model, test_loader, loss_fn, num_labels, device)

    print(f"Final Test Results - Loss: {test_loss:.4f}, ROC-AUC: {test_metrics['roc_auc']:.4f}")

    # Clean up the temporary checkpoint file.
    os.remove(best_model_path)

    return {
        'dataset': dataset_name,
        'num_labels': num_labels,
        'test_loss': test_loss,
        'test_roc_auc': test_metrics['roc_auc'],
        'individual_aucs': test_metrics.get('individual_aucs', None)
    }
| |
|
def main():
    """Fine-tune on every configured MoleculeNet benchmark and summarize results.

    Writes a CSV summary to moleculenet_results.csv and returns the
    results DataFrame. Failures on individual datasets are reported and
    skipped so the remaining experiments still run.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    tokenizer = AutoTokenizer.from_pretrained('DeepChem/ChemBERTa-77M-MTR')

    # Small BERT encoder sized against the tokenizer's vocabulary.
    config = BertConfig(
        vocab_size=tokenizer.vocab_size,
        hidden_size=768,
        num_hidden_layers=4,
        num_attention_heads=12,
        intermediate_size=2048,
        max_position_embeddings=512
    )

    encoder_path = 'simson_checkpoints_small/simson_model_single_gpu.bin'

    all_results = []
    for name in ['bbbp', 'tox21', 'sider', 'clintox', 'hiv', 'bace']:
        try:
            all_results.append(run_experiment(
                name,
                config,
                tokenizer,
                encoder_path=encoder_path,
                device=device
            ))
        except Exception as e:
            # Best-effort: report and continue with the remaining datasets.
            print(f"Error with {name}: {e}")

    print("\n" + "=" * 60)
    print("FINAL RESULTS SUMMARY")
    print("=" * 60)

    results_df = pd.DataFrame(all_results)
    print(results_df.to_string(index=False))

    results_df.to_csv('moleculenet_results.csv', index=False)
    print("\nResults saved to moleculenet_results.csv")

    return results_df
| |
|
# Script entry point: run all experiments and keep the summary DataFrame.
if __name__ == "__main__":

    results = main()
| |
|