| |
|
|
| """ |
| train_brain2vec.py |
| |
| Trains a 3D VAE-based Brain2Vec model using MONAI. This script implements |
| autoencoder training with adversarial loss (via a patch discriminator), |
| a perceptual loss, and KL divergence regularization for robust latent |
| representations. |
| |
| Example usage: |
| python train_brain2vec.py \ |
| --dataset_csv inputs.csv \ |
| --cache_dir ./ae_cache \ |
| --output_dir ./ae_output \ |
| --n_epochs 10 |
| """ |
|
|
| import os |
| os.environ["PYTORCH_WEIGHTS_ONLY"] = "False" |
| from typing import Optional, Union |
| import pandas as pd |
| import argparse |
| import numpy as np |
| import warnings |
| import torch |
| import torch.nn as nn |
| from torch import Tensor |
| from torch.optim.optimizer import Optimizer |
| from torch.nn import L1Loss |
| from torch.utils.data import DataLoader |
| from torch.amp import autocast |
| from torch.amp import GradScaler |
| from generative.networks.nets import ( |
| AutoencoderKL, |
| PatchDiscriminator, |
| ) |
| from generative.losses import PerceptualLoss, PatchAdversarialLoss |
| from monai.data import Dataset, PersistentDataset |
| from monai.transforms.transform import Transform |
| from monai import transforms |
| from monai.utils import set_determinism |
| from monai.data.meta_tensor import MetaTensor |
| import torch.serialization |
| from numpy.core.multiarray import _reconstruct |
| from numpy import ndarray, dtype |
| torch.serialization.add_safe_globals([_reconstruct]) |
| torch.serialization.add_safe_globals([MetaTensor]) |
| torch.serialization.add_safe_globals([ndarray]) |
| torch.serialization.add_safe_globals([dtype]) |
| from tqdm import tqdm |
| import matplotlib.pyplot as plt |
| from torch.utils.tensorboard import SummaryWriter |
|
|
| |
# Working isotropic voxel spacing in mm (matches SpacingD(pixdim=2) in train()).
RESOLUTION = 2

# Volume grid at 1 mm spacing — presumably the MNI152 template grid; confirm.
INPUT_SHAPE_1mm = (182, 218, 182)

# Same volume resampled to 1.5 mm spacing.
INPUT_SHAPE_1p5mm = (122, 146, 122)

# Spatial size fed to the autoencoder (after 2 mm resampling + pad/crop;
# matches ResizeWithPadOrCropD(spatial_size=(80, 96, 80)) in train()).
INPUT_SHAPE_AE = (80, 96, 80)

# Autoencoder latent shape as (channels, D, H, W).
LATENT_SHAPE_AE = (1, 10, 12, 10)
|
|
|
|
def load_if(checkpoints_path: Optional[str], network: nn.Module) -> nn.Module:
    """
    Load pretrained weights into `network` if a checkpoint path is given.

    Args:
        checkpoints_path (Optional[str]): path of the checkpoint file, or
            None to skip loading and return the network as-is.
        network (nn.Module): the neural network to initialize

    Returns:
        nn.Module: the (possibly initialized) neural network

    Raises:
        FileNotFoundError: if `checkpoints_path` is set but does not exist.
    """
    if checkpoints_path is not None:
        # Explicit raise instead of `assert`: asserts are stripped under -O.
        if not os.path.exists(checkpoints_path):
            raise FileNotFoundError(f'Invalid checkpoint path: {checkpoints_path}')
        # map_location='cpu' lets GPU-saved checkpoints load on CPU-only
        # hosts; the caller moves the network to the target device afterwards.
        network.load_state_dict(torch.load(checkpoints_path, map_location='cpu'))
    return network
|
|
|
|
def init_autoencoder(checkpoints_path: Optional[str] = None) -> nn.Module:
    """
    Build the 3D KL autoencoder, optionally restoring pretrained weights.

    Args:
        checkpoints_path (Optional[str], optional): path of the checkpoints. Defaults to None.

    Returns:
        nn.Module: the KL autoencoder
    """
    config = dict(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        latent_channels=1,
        num_channels=(64, 128, 128, 128),
        num_res_blocks=2,
        norm_num_groups=32,
        norm_eps=1e-06,
        attention_levels=(False, False, False, False),
        with_decoder_nonlocal_attn=False,
        with_encoder_nonlocal_attn=False,
    )
    return load_if(checkpoints_path, AutoencoderKL(**config))
|
|
|
|
def init_patch_discriminator(checkpoints_path: Optional[str] = None) -> nn.Module:
    """
    Build the patch discriminator, optionally restoring pretrained weights.

    Args:
        checkpoints_path (Optional[str], optional): path of the checkpoints. Defaults to None.

    Returns:
        nn.Module: the patch discriminator
    """
    config = dict(
        spatial_dims=3,
        num_layers_d=3,
        num_channels=32,
        in_channels=1,
        out_channels=1,
    )
    return load_if(checkpoints_path, PatchDiscriminator(**config))
|
|
|
|
class KLDivergenceLoss:
    """
    Kullback-Leibler divergence between the encoder posterior
    N(z_mu, z_sigma^2) and a standard normal prior.
    """

    def __call__(self, z_mu: Tensor, z_sigma: Tensor) -> Tensor:
        """
        Computes the KL divergence loss for the given parameters.

        Args:
            z_mu (Tensor): The mean of the distribution.
            z_sigma (Tensor): The standard deviation of the distribution.

        Returns:
            Tensor: The computed KL divergence loss, averaged over the batch size.
        """
        variance = z_sigma.pow(2)
        # Closed-form KL( N(mu, sigma^2) || N(0, 1) ), summed over the
        # channel and three spatial dimensions of each sample.
        per_sample_kl = 0.5 * torch.sum(
            z_mu.pow(2) + variance - variance.log() - 1,
            dim=[1, 2, 3, 4],
        )
        # Average over the batch dimension.
        return per_sample_kl.mean()
|
|
|
|
| class GradientAccumulation: |
| """ |
| Implements gradient accumulation to facilitate training with larger |
| effective batch sizes than what can be physically accommodated in memory. |
| """ |
|
|
| def __init__(self, |
| actual_batch_size: int, |
| expect_batch_size: int, |
| loader_len: int, |
| optimizer: Optimizer, |
| grad_scaler: Optional[GradScaler] = None) -> None: |
| """ |
| Initializes the GradientAccumulation instance with the necessary parameters for |
| managing gradient accumulation. |
| |
| Args: |
| actual_batch_size (int): The size of the mini-batches actually used in training. |
| expect_batch_size (int): The desired (effective) batch size to simulate through gradient accumulation. |
| loader_len (int): The length of the data loader, representing the total number of mini-batches. |
| optimizer (Optimizer): The optimizer used for performing optimization steps. |
| grad_scaler (Optional[GradScaler], optional): A GradScaler for mixed precision training. Defaults to None. |
| |
| Raises: |
| AssertionError: If `expect_batch_size` is not divisible by `actual_batch_size`. |
| """ |
|
|
| assert expect_batch_size % actual_batch_size == 0, \ |
| 'expect_batch_size must be divisible by actual_batch_size' |
| self.actual_batch_size = actual_batch_size |
| self.expect_batch_size = expect_batch_size |
| self.loader_len = loader_len |
| self.optimizer = optimizer |
| self.grad_scaler = grad_scaler |
|
|
| |
| |
| self.steps_until_update = expect_batch_size / actual_batch_size |
|
|
| def step(self, loss: Tensor, step: int) -> None: |
| """ |
| Performs a backward pass for the given loss and potentially executes an optimization |
| step if the conditions for gradient accumulation are met. The optimization step is taken |
| only after a specified number of steps (defined by the expected batch size) or at the end |
| of the dataset. |
| |
| Args: |
| loss (Tensor): The loss value for the current forward pass. |
| step (int): The current step (mini-batch index) within the epoch. |
| """ |
| loss = loss / self.expect_batch_size |
| |
| if self.grad_scaler is not None: |
| self.grad_scaler.scale(loss).backward() |
| else: |
| loss.backward() |
| if (step + 1) % self.steps_until_update == 0 or (step + 1) == self.loader_len: |
| if self.grad_scaler is not None: |
| self.grad_scaler.step(self.optimizer) |
| self.grad_scaler.update() |
| else: |
| self.optimizer.step() |
| self.optimizer.zero_grad(set_to_none=True) |
|
|
|
|
class AverageLoss:
    """
    Utility class to track losses and metrics during training by
    accumulating values and reporting their running averages.
    """

    def __init__(self) -> None:
        # metric name -> list of values recorded since the last pop_avg()
        self.losses_accumulator: dict = {}

    def put(self, loss_key: str, loss_value: Union[int, float]) -> None:
        """
        Store value

        Args:
            loss_key (str): Metric name
            loss_value (int | float): Metric value to store
        """
        self.losses_accumulator.setdefault(loss_key, []).append(loss_value)

    def pop_avg(self, loss_key: str) -> Optional[float]:
        """
        Average the stored values of a given metric and reset its buffer.

        Args:
            loss_key (str): Metric name

        Returns:
            Optional[float]: average of the stored values, or None if no
            values have been recorded since the last flush.
        """
        losses = self.losses_accumulator.get(loss_key)
        # Guard the empty-buffer case too: previously a metric flushed twice
        # with no new values in between raised ZeroDivisionError.
        if not losses:
            return None
        self.losses_accumulator[loss_key] = []
        return sum(losses) / len(losses)

    def to_tensorboard(self, writer: 'SummaryWriter', step: int) -> None:
        """
        Logs the average value of all the metrics stored into Tensorboard.

        Args:
            writer (SummaryWriter): Tensorboard writer
            step (int): Tensorboard logging global step
        """
        for metric_key in list(self.losses_accumulator):
            avg = self.pop_avg(metric_key)
            # Skip metrics with no fresh values instead of logging None.
            if avg is not None:
                writer.add_scalar(metric_key, avg, step)
|
|
|
|
def get_dataset_from_pd(df: pd.DataFrame, transforms_fn: Transform, cache_dir: Optional[str]) -> Union[Dataset, PersistentDataset]:
    """
    Build a MONAI dataset from a dataframe of image records.

    If `cache_dir` is defined, returns a `monai.data.PersistentDataset`
    (transform outputs cached on disk). Otherwise, returns a simple
    `monai.data.Dataset`.

    Args:
        df (pd.DataFrame): Dataframe describing each image in the longitudinal dataset.
        transforms_fn (Transform): Set of transformations
        cache_dir (Optional[str]): Cache directory (ensure enough storage is available)

    Returns:
        Dataset|PersistentDataset: The dataset

    Raises:
        FileNotFoundError: if `cache_dir` is set but does not exist.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if cache_dir is not None and not os.path.exists(cache_dir):
        raise FileNotFoundError(f'Invalid cache directory path: {cache_dir}')
    data = df.to_dict(orient='records')
    if cache_dir is None:
        return Dataset(data=data, transform=transforms_fn)
    return PersistentDataset(data=data, transform=transforms_fn, cache_dir=cache_dir)
|
|
|
|
def tb_display_reconstruction(writer, step, image, recon):
    """
    Log a 2x3 grid of mid-slices (original vs. reconstruction) to
    TensorBoard during AE training.
    """
    plt.style.use('dark_background')
    _, axes = plt.subplots(ncols=3, nrows=2, figsize=(7, 5))
    for axis in axes.flatten():
        axis.set_axis_off()

    # Drop a leading channel axis if present: (C, D, H, W) -> (D, H, W).
    if len(image.shape) == 4:
        image = image.squeeze(0)
    if len(recon.shape) == 4:
        recon = recon.squeeze(0)

    rows = [(image, 'original image', 'cyan'),
            (recon, 'reconstructed image', 'magenta')]
    for row, (volume, title, color) in enumerate(rows):
        axes[row, 0].set_title(title, color=color)
        # One central slice along each of the three volume axes.
        axes[row, 0].imshow(volume[volume.shape[0] // 2, :, :], cmap='gray')
        axes[row, 1].imshow(volume[:, volume.shape[1] // 2, :], cmap='gray')
        axes[row, 2].imshow(volume[:, :, volume.shape[2] // 2], cmap='gray')

    plt.tight_layout()
    writer.add_figure('Reconstruction', plt.gcf(), global_step=step)
|
|
|
|
def set_environment(seed: int = 0) -> None:
    """
    Make runs reproducible by fixing the MONAI/PyTorch random seeds.

    Args:
        seed (int, optional): Seed value. Defaults to 0.
    """
    set_determinism(seed=seed)
|
|
|
|
def train(
        dataset_csv: str,
        cache_dir: str,
        output_dir: str,
        aekl_ckpt: Optional[str] = None,
        disc_ckpt: Optional[str] = None,
        num_workers: int = 8,
        n_epochs: int = 5,
        max_batch_size: int = 2,
        batch_size: int = 16,
        lr: float = 1e-4,
        aug_p: float = 0.8,
        device: str = ('cuda' if torch.cuda.is_available() else
                       'cpu'),
) -> None:
    """
    Train the autoencoder and discriminator models.

    Args:
        dataset_csv (str): Path to the dataset CSV file.
        cache_dir (str): Directory for caching data.
        output_dir (str): Directory to save model checkpoints.
        aekl_ckpt (Optional[str], optional): Path to the autoencoder checkpoint. Defaults to None.
        disc_ckpt (Optional[str], optional): Path to the discriminator checkpoint. Defaults to None.
        num_workers (int, optional): Number of data loader workers. Defaults to 8.
        n_epochs (int, optional): Number of training epochs. Defaults to 5.
        max_batch_size (int, optional): Actual batch size per iteration. Defaults to 2.
        batch_size (int, optional): Expected (effective) batch size. Defaults to 16.
        lr (float, optional): Learning rate. Defaults to 1e-4.
        aug_p (float, optional): Augmentation probability. Defaults to 0.8.
            NOTE(review): currently unused — no augmentation transform reads
            it; confirm whether augmentation was intended.
        device (str, optional): Device to run the training on. Defaults to 'cuda' if available.
    """
    set_environment(0)

    # Preprocessing pipeline: copy the file path under the 'image' key, load
    # it, force channel-first layout, resample to 2 mm spacing, pad/crop to
    # the AE input grid, and rescale intensities to [0, 1].
    transforms_fn = transforms.Compose([
        transforms.CopyItemsD(keys={'image_path'}, names=['image']),
        transforms.LoadImageD(image_only=True, keys=['image']),
        transforms.EnsureChannelFirstD(keys=['image']),
        transforms.SpacingD(pixdim=2, keys=['image']),
        transforms.ResizeWithPadOrCropD(spatial_size=(80, 96, 80), mode='minimum', keys=['image']),
        transforms.ScaleIntensityD(minv=0, maxv=1, keys=['image'])
    ])

    # Only rows flagged as training data are used; the CSV is assumed to
    # have 'split' and 'image_path' columns.
    dataset_df = pd.read_csv(dataset_csv)
    train_df = dataset_df[dataset_df.split == 'train']
    trainset = get_dataset_from_pd(train_df, transforms_fn, cache_dir)

    train_loader = DataLoader(
        dataset=trainset,
        num_workers=num_workers,
        batch_size=max_batch_size,
        shuffle=True,
        persistent_workers=True,
        pin_memory=True,
    )

    print('Device is %s' %(device))
    autoencoder = init_autoencoder(aekl_ckpt).to(device)
    discriminator = init_patch_discriminator(disc_ckpt).to(device)

    # Relative weights of the non-reconstruction loss terms.
    adv_weight = 0.025
    perceptual_weight = 0.001
    kl_weight = 1e-7

    l1_loss_fn = L1Loss()
    kl_loss_fn = KLDivergenceLoss()
    adv_loss_fn = PatchAdversarialLoss(criterion="least_squares")

    # PerceptualLoss construction emits warnings (e.g. on weight download);
    # silence them for a cleaner log.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        perc_loss_fn = PerceptualLoss(
            spatial_dims=3,
            network_type="squeeze",
            is_fake_3d=True,
            fake_3d_ratio=0.2
        ).to(device)

    optimizer_g = torch.optim.Adam(autoencoder.parameters(), lr=lr)
    optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=lr)

    # Gradient accumulation simulates the effective `batch_size` with
    # `max_batch_size`-sized mini-batches, one accumulator per optimizer.
    gradacc_g = GradientAccumulation(
        actual_batch_size=max_batch_size,
        expect_batch_size=batch_size,
        loader_len=len(train_loader),
        optimizer=optimizer_g,
        grad_scaler=GradScaler()
    )

    gradacc_d = GradientAccumulation(
        actual_batch_size=max_batch_size,
        expect_batch_size=batch_size,
        loader_len=len(train_loader),
        optimizer=optimizer_d,
        grad_scaler=GradScaler()
    )

    avgloss = AverageLoss()
    writer = SummaryWriter()
    total_counter = 0

    for epoch in range(n_epochs):
        print(f"[DEBUG] Starting epoch {epoch}/{n_epochs-1}")
        autoencoder.train()
        progress_bar = tqdm(enumerate(train_loader), total=len(train_loader))
        progress_bar.set_description(f'Epoch {epoch}')

        for step, batch in progress_bar:

            # ------------------- generator (autoencoder) pass -------------------
            with autocast(device, enabled=True):
                images = batch["image"].to(device)
                reconstruction, z_mu, z_sigma = autoencoder(images)

                # Last element of the discriminator output is the patch logits.
                logits_fake = discriminator(reconstruction.contiguous().float())[-1]

                # Generator objective: L1 reconstruction + weighted KL,
                # perceptual and adversarial ("fool the discriminator") terms.
                rec_loss = l1_loss_fn(reconstruction.float(), images.float())
                kl_loss = kl_weight * kl_loss_fn(z_mu, z_sigma)
                per_loss = perceptual_weight * perc_loss_fn(reconstruction.float(), images.float())
                gen_loss = adv_weight * adv_loss_fn(logits_fake, target_is_real=True, for_discriminator=False)

                loss_g = rec_loss + kl_loss + per_loss + gen_loss

            gradacc_g.step(loss_g, step)

            # ------------------- discriminator pass -------------------
            # Reconstruction is detached so discriminator gradients do not
            # flow back into the autoencoder.
            with autocast(device, enabled=True):
                logits_fake = discriminator(reconstruction.contiguous().detach())[-1]
                d_loss_fake = adv_loss_fn(logits_fake, target_is_real=False, for_discriminator=True)
                logits_real = discriminator(images.contiguous().detach())[-1]
                d_loss_real = adv_loss_fn(logits_real, target_is_real=True, for_discriminator=True)
                discriminator_loss = (d_loss_fake + d_loss_real) * 0.5
                loss_d = adv_weight * discriminator_loss

            gradacc_d.step(loss_d, step)

            # ------------------- logging -------------------
            avgloss.put('Generator/reconstruction_loss', rec_loss.item())
            avgloss.put('Generator/perceptual_loss', per_loss.item())
            avgloss.put('Generator/adversarial_loss', gen_loss.item())
            avgloss.put('Generator/kl_regularization', kl_loss.item())
            avgloss.put('Discriminator/adversarial_loss', loss_d.item())

            # Flush averaged metrics and a reconstruction preview every
            # 10 iterations.
            if total_counter % 10 == 0:
                step_log = total_counter // 10
                avgloss.to_tensorboard(writer, step_log)
                tb_display_reconstruction(
                    writer,
                    step_log,
                    images[0].detach().cpu(),
                    reconstruction[0].detach().cpu()
                )

            total_counter += 1

        # Checkpoint both networks at the end of every epoch.
        os.makedirs(output_dir, exist_ok=True)
        torch.save(discriminator.state_dict(), os.path.join(output_dir, f'discriminator-ep-{epoch}.pth'))
        torch.save(autoencoder.state_dict(), os.path.join(output_dir, f'autoencoder-ep-{epoch}.pth'))

    writer.close()
    print("Training completed and models saved.")
|
|
|
|
def main():
    """
    Parse command-line arguments and launch train().
    """
    # `argparse` is already imported at module level; the previous duplicate
    # function-local import has been removed.
    parser = argparse.ArgumentParser(description="brain2vec Training Script")

    parser.add_argument('--dataset_csv', type=str, required=True, help='Path to the dataset CSV file.')
    parser.add_argument('--cache_dir', type=str, required=True, help='Directory for caching data.')
    parser.add_argument('--output_dir', type=str, required=True, help='Directory to save model checkpoints.')
    parser.add_argument('--aekl_ckpt', type=str, default=None, help='Path to the autoencoder checkpoint.')
    parser.add_argument('--disc_ckpt', type=str, default=None, help='Path to the discriminator checkpoint.')
    parser.add_argument('--num_workers', type=int, default=8, help='Number of data loader workers.')
    parser.add_argument('--n_epochs', type=int, default=5, help='Number of training epochs.')
    parser.add_argument('--max_batch_size', type=int, default=2, help='Actual batch size per iteration.')
    parser.add_argument('--batch_size', type=int, default=16, help='Expected (effective) batch size.')
    parser.add_argument('--lr', type=float, default=1e-4, help='Learning rate.')
    parser.add_argument('--aug_p', type=float, default=0.8, help='Augmentation probability.')

    args = parser.parse_args()

    train(
        dataset_csv=args.dataset_csv,
        cache_dir=args.cache_dir,
        output_dir=args.output_dir,
        aekl_ckpt=args.aekl_ckpt,
        disc_ckpt=args.disc_ckpt,
        num_workers=args.num_workers,
        n_epochs=args.n_epochs,
        max_batch_size=args.max_batch_size,
        batch_size=args.batch_size,
        lr=args.lr,
        aug_p=args.aug_p,
    )
|
|
|
|
| if __name__ == '__main__': |
| main() |
|
|