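"""Training and validation loops for a text-conditioned segmentation model.

Implements per-epoch training and validation with optional mixed precision
(AMP) and distributed (DDP) execution, checkpointing, and TensorBoard logging
of the per-channel Dice scores (tumor core, whole tumor, enhancing tumor)
used for BraTS-style evaluation.
"""
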
import os
import shutil
import time

import numpy as np
import torch
import torch.nn.parallel
import torch.utils.data.distributed
from tensorboardX import SummaryWriter
from torch.amp import GradScaler, autocast
from utils.utils import AverageMeter, distributed_all_gather

from monai.data import decollate_batch


def train_epoch(model, loader, optimizer, scaler, epoch, loss_func, args):
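    """Run one training epoch and return the average loss.

    Each batch supplies an image, a label, and a text feature, and the model
    is called as ``model(data, text)``. Uses AMP when ``args.amp`` is set and
    averages the step loss across ranks when ``args.distributed`` is set.
    """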
    model.train()
    start_time = time.time()
    run_loss = AverageMeter()
    for idx, batch_data in enumerate(loader):
        if isinstance(batch_data, list):
            data, target, text = batch_data
        else:
            data, target, text = batch_data["image"], batch_data["label"], batch_data["text_feature"]
        data, target, text = data.cuda(args.rank), target.cuda(args.rank), text.cuda(args.rank)
        optimizer.zero_grad(set_to_none=True)
        with autocast("cuda", enabled=args.amp):
            logits = model(data, text)
            loss = loss_func(logits, target)
        if args.amp:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()
        if args.distributed:
            # Average the step loss across all ranks before accumulating.
            loss_list = distributed_all_gather([loss], out_numpy=True, is_valid=idx < loader.sampler.valid_length)
            run_loss.update(
                np.mean(np.mean(np.stack(loss_list, axis=0), axis=0), axis=0), n=args.batch_size * args.world_size
            )
        else:
            run_loss.update(loss.item(), n=args.batch_size)
        if args.rank == 0:
            print(
                "Epoch {}/{} {}/{}".format(epoch, args.max_epochs, idx, len(loader)),
                "loss: {:.4f}".format(run_loss.avg),
                "time {:.2f}s".format(time.time() - start_time),
            )
        start_time = time.time()
    optimizer.zero_grad(set_to_none=True)
    return run_loss.avg


def val_epoch(model, loader, epoch, acc_func, args, post_sigmoid=None, post_pred=None):
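    """Run one validation epoch and return the per-channel mean Dice.

    Decollates each batch, applies ``post_sigmoid`` and ``post_pred`` to the
    logits, and accumulates the ``acc_func`` (Dice) results; rank 0 prints
    the running TC/WT/ET scores after every batch.
    """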
    model.eval()
    start_time = time.time()
    run_acc = AverageMeter()

    with torch.no_grad():
        for idx, batch_data in enumerate(loader):
            data, target, text = batch_data["image"], batch_data["label"], batch_data["text_feature"]
            data, target, text = data.cuda(args.rank), target.cuda(args.rank), text.cuda(args.rank)
            with autocast("cuda", enabled=args.amp):
                logits = model(data, text)
            val_labels_list = decollate_batch(target)
            val_outputs_list = decollate_batch(logits)
            val_output_convert = [post_pred(post_sigmoid(val_pred_tensor)) for val_pred_tensor in val_outputs_list]
            acc_func.reset()
            acc_func(y_pred=val_output_convert, y=val_labels_list)
            acc, not_nans = acc_func.aggregate()
            acc = acc.cuda(args.rank)
            if args.distributed:
                acc_list, not_nans_list = distributed_all_gather(
                    [acc, not_nans], out_numpy=True, is_valid=idx < loader.sampler.valid_length
                )
                for al, nl in zip(acc_list, not_nans_list):
                    run_acc.update(al, n=nl)
            else:
                run_acc.update(acc.cpu().numpy(), n=not_nans.cpu().numpy())

            if args.rank == 0:
                Dice_TC = run_acc.avg[0]
                Dice_WT = run_acc.avg[1]
                Dice_ET = run_acc.avg[2]
                print(
                    "Val {}/{} {}/{}".format(epoch, args.max_epochs, idx, len(loader)),
                    ", Dice_TC:",
                    Dice_TC,
                    ", Dice_WT:",
                    Dice_WT,
                    ", Dice_ET:",
                    Dice_ET,
                    ", time {:.2f}s".format(time.time() - start_time),
                )
            start_time = time.time()

    return run_acc.avg


def save_checkpoint(model, epoch, args, filename="model.pt", best_acc=0, optimizer=None, scheduler=None):
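    """Serialize model (and optional optimizer/scheduler) state under ``args.logdir``.

    For DDP-wrapped models, the unwrapped ``model.module`` state dict is saved.
    """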
    state_dict = model.state_dict() if not args.distributed else model.module.state_dict()
    save_dict = {"epoch": epoch, "best_acc": best_acc, "state_dict": state_dict}
    if optimizer is not None:
        save_dict["optimizer"] = optimizer.state_dict()
    if scheduler is not None:
        save_dict["scheduler"] = scheduler.state_dict()
    filename = os.path.join(args.logdir, filename)
    torch.save(save_dict, filename)
    print("Saving checkpoint", filename)


def run_training(
    model,
    train_loader,
    val_loader,
    optimizer,
    loss_func,
    acc_func,
    args,
    scheduler=None,
    start_epoch=0,
    post_sigmoid=None,
    post_pred=None,
    semantic_classes=None,
):
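    """Drive the full training loop and return the best mean validation Dice.

    Runs ``train_epoch`` every epoch and ``val_epoch`` every
    ``args.val_every`` epochs. On rank 0, logs to TensorBoard when
    ``args.logdir`` is set and, when ``args.save_checkpoint`` is set, saves
    the latest checkpoint to ``model_final.pt`` and copies it to ``model.pt``
    whenever the mean Dice improves.
    """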
    writer = None
    if args.logdir is not None and args.rank == 0:
        writer = SummaryWriter(log_dir=args.logdir)
        print("Writing TensorBoard logs to", args.logdir)
    scaler = None
    if args.amp:
        scaler = GradScaler("cuda")
    val_acc_max = 0.0
    for epoch in range(start_epoch, args.max_epochs):
        if args.distributed:
            train_loader.sampler.set_epoch(epoch)
            torch.distributed.barrier()
        print(args.rank, time.ctime(), "Epoch:", epoch)
        epoch_time = time.time()
        train_loss = train_epoch(
            model, train_loader, optimizer, scaler=scaler, epoch=epoch, loss_func=loss_func, args=args
        )
        if args.rank == 0:
            print(
                "Final training {}/{}".format(epoch, args.max_epochs - 1),
                "loss: {:.4f}".format(train_loss),
                "time {:.2f}s".format(time.time() - epoch_time),
            )
            if writer is not None:
                writer.add_scalar("train_loss", train_loss, epoch)
        b_new_best = False
        if (epoch + 1) % args.val_every == 0:
            if args.distributed:
                torch.distributed.barrier()
            epoch_time = time.time()
            val_acc = val_epoch(
                model,
                val_loader,
                epoch=epoch,
                acc_func=acc_func,
                args=args,
                post_sigmoid=post_sigmoid,
                post_pred=post_pred,
            )

            if args.rank == 0:
                Dice_TC = val_acc[0]
                Dice_WT = val_acc[1]
                Dice_ET = val_acc[2]
                print(
                    "Final validation stats {}/{}".format(epoch, args.max_epochs - 1),
                    ", Dice_TC:",
                    Dice_TC,
                    ", Dice_WT:",
                    Dice_WT,
                    ", Dice_ET:",
                    Dice_ET,
                    ", time {:.2f}s".format(time.time() - epoch_time),
                )

                if writer is not None:
                    writer.add_scalar("Mean_Val_Dice", np.mean(val_acc), epoch)
                    if semantic_classes is not None:
                        for val_channel_ind in range(len(semantic_classes)):
                            if val_channel_ind < val_acc.size:
                                writer.add_scalar(semantic_classes[val_channel_ind], val_acc[val_channel_ind], epoch)
                val_avg_acc = np.mean(val_acc)
                if val_avg_acc > val_acc_max:
                    print("new best ({:.6f} --> {:.6f}).".format(val_acc_max, val_avg_acc))
                    val_acc_max = val_avg_acc
                    b_new_best = True
                    if args.logdir is not None and args.save_checkpoint:
                        save_checkpoint(
                            model, epoch, args, best_acc=val_acc_max, optimizer=optimizer, scheduler=scheduler
                        )
            if args.rank == 0 and args.logdir is not None and args.save_checkpoint:
                # Save the rolling checkpoint, then promote it if it is the new best.
                save_checkpoint(model, epoch, args, best_acc=val_acc_max, filename="model_final.pt")
                if b_new_best:
                    print("Copying to model.pt new best model!!!!")
                    shutil.copyfile(os.path.join(args.logdir, "model_final.pt"), os.path.join(args.logdir, "model.pt"))

        if scheduler is not None:
            scheduler.step()

    print("Training finished! Best accuracy:", val_acc_max)

    return val_acc_max
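

# A minimal wiring sketch (an assumption, not part of this file): the MONAI
# loss, metric, and post-transforms below mirror the usual BraTS setup implied
# by ``post_sigmoid``/``post_pred`` and by ``acc_func.aggregate()`` returning
# ``(acc, not_nans)``; the model, loaders, optimizer, scheduler, and ``args``
# namespace are assumed to come from the caller.
#
#   from monai.losses import DiceLoss
#   from monai.metrics import DiceMetric
#   from monai.transforms import Activations, AsDiscrete
#
#   loss_func = DiceLoss(to_onehot_y=False, sigmoid=True)
#   acc_func = DiceMetric(include_background=True, reduction="mean_batch", get_not_nans=True)
#   run_training(
#       model, train_loader, val_loader, optimizer, loss_func, acc_func, args,
#       scheduler=scheduler,
#       post_sigmoid=Activations(sigmoid=True),
#       post_pred=AsDiscrete(threshold=0.5),
#       semantic_classes=["Dice_TC", "Dice_WT", "Dice_ET"],
#   )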