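"""Entry point for training or evaluating the solubility guidance classifier.

Builds the membrane data module, a Lightning trainer with W&B logging and
checkpointing, and then fits or tests the SolubilityClassifier depending on
`config.training.mode`.
"""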
import os
import wandb
import lightning.pytorch as pl
|
|
from omegaconf import OmegaConf
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor
|
|
from src.guidance.solubility_module import SolubilityClassifier
from src.guidance.dataloader import MembraneDataModule, get_datasets
|
|
|
|
# Load the guidance-training config and authenticate with Weights & Biases.
# wandb.login() reads the API key from the WANDB_API_KEY environment variable
# (or a previous `wandb login`), so no key is hardcoded here.
config = OmegaConf.load("/scratch/sgoel/MeMDLM_v2/src/configs/guidance.yaml")
wandb.login()
|
|
# Build the train/val/test datasets and wrap them in a LightningDataModule.
datasets = get_datasets(config)
data_module = MembraneDataModule(
    config=config,
    train_dataset=datasets['train'],
    val_dataset=datasets['val'],
    test_dataset=datasets['test'],
)
|
|
# W&B logger configured from the `wandb` section of the config.
wandb_logger = WandbLogger(**config.wandb)
|
|
# Track the learning rate and keep only the best checkpoint by validation loss.
lr_monitor = LearningRateMonitor(logging_interval="step")
checkpoint_callback = ModelCheckpoint(
    monitor="val/loss",
    save_top_k=1,
    mode="min",
    dirpath=config.checkpointing.save_dir,
    filename="best_model",
)
|
|
# Lightning trainer: single CUDA device with W&B logging, checkpointing,
# and learning-rate monitoring.
trainer = pl.Trainer(
    max_steps=config.training.max_steps,
    accelerator="cuda",
    devices=1,
    callbacks=[checkpoint_callback, lr_monitor],
    logger=wandb_logger,
    log_every_n_steps=config.training.log_every_n_steps,
)
|
|
# Make sure the checkpoint directory exists before training or loading.
ckpt_dir = config.checkpointing.save_dir
os.makedirs(ckpt_dir, exist_ok=True)
|
|
model = SolubilityClassifier(config)
|
|
# Train from scratch, or evaluate the best checkpoint on the test set.
if config.training.mode == "train":
    trainer.fit(model, datamodule=data_module)
elif config.training.mode == "test":
    ckpt_path = os.path.join(ckpt_dir, "best_model.ckpt")
    state_dict = model.get_state_dict(ckpt_path)
    model.load_state_dict(state_dict)
    trainer.test(model, datamodule=data_module, ckpt_path=ckpt_path)
else:
    raise ValueError(f"Invalid mode '{config.training.mode}'. Must be 'train' or 'test'.")
|
|
wandb.finish()
|
|