# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# DeiT: https://github.com/facebookresearch/deit
# BEiT: https://github.com/microsoft/unilm/tree/master/beit
# --------------------------------------------------------
import math
import os
import sys
from typing import Iterable

import numpy as np
import torch
from monai.inferers import sliding_window_inference
from monai.losses import DiceCELoss, DiceLoss

import util.lr_sched as lr_sched
import util.misc as misc
from util.meter import DiceMeter, HausdorffMeter, SurfaceDistanceMeter


def train_one_epoch(
    model: torch.nn.Module,
    data_loader: Iterable,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
    epoch: int,
    loss_scaler,
    log_writer=None,
    args=None,
):
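    """Run one training epoch.

    The loss is DiceCELoss (sigmoid for a single output channel, softmax with
    one-hot targets otherwise), the learning rate follows a per-iteration
    schedule from util.lr_sched, and training aborts if the loss becomes
    non-finite. `loss_scaler` is accepted for interface compatibility but is
    unused while the AMP path is commented out. Returns the epoch-averaged
    meter values.
    """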
model.train(True)
metric_logger = misc.MetricLogger(delimiter=" ")
metric_logger.add_meter("lr", misc.SmoothedValue(window_size=1, fmt="{value:.6f}"))
header = "Epoch: [{}]".format(epoch)
print_freq = 20
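    # Binary segmentation uses a sigmoid head; multi-class segmentation uses
    # softmax with one-hot-encoded targets and excludes the background
    # channel from the loss.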
if args.out_channels == 1:
loss_cal = DiceCELoss(sigmoid=True)
else:
loss_cal = DiceCELoss(to_onehot_y=True, softmax=True, include_background=False)
optimizer.zero_grad()
if log_writer is not None:
print("log_dir: {}".format(log_writer.log_dir))
last_norm = 0.0
for data_iter_step, (img, gt, dataidx) in enumerate(
metric_logger.log_every(data_loader, print_freq, header)
):
        img, gt = img.to(device, non_blocking=True), gt.to(device, non_blocking=True)
        # we use a per iteration (instead of per epoch) lr scheduler
        lr_sched.adjust_learning_rate(
            optimizer, data_iter_step / len(data_loader) + epoch, args
        )
# with torch.cuda.amp.autocast():
logit = model(img)
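        # A list output presumably comes from a deep-supervision head; the
        # auxiliary prediction is weighted by a fixed 0.4.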
if isinstance(logit, list):
            loss = loss_cal(logit[0], gt) + 0.4 * loss_cal(logit[1], gt)
else:
loss = loss_cal(logit, gt)
loss_value = loss.item()
        if not math.isfinite(loss_value):
            # report diagnostics before aborting; logit may be a list under
            # deep supervision, so inspect the main head only
            main_logit = logit[0] if isinstance(logit, list) else logit
            print(
                "nan",
                torch.isnan(main_logit).any(),
                torch.isnan(img).any(),
                dataidx,
                last_norm,
            )
            print(
                "inf",
                torch.isinf(main_logit).any(),
                torch.isinf(img).any(),
                dataidx,
                last_norm,
            )
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)
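        # plain full-precision update; the commented-out loss_scaler lines
        # below are the AMP variant that would replace backward()/step()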
optimizer.zero_grad()
loss.backward()
# torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
# last_norm = loss_scaler(loss, optimizer, parameters=model.parameters())
# optimizer.zero_grad()
# torch.cuda.synchronize()
metric_logger.update(loss=loss_value)
lr = optimizer.param_groups[0]["lr"]
metric_logger.update(lr=lr)
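        # average the loss across distributed processes before logging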
loss_value_reduce = misc.all_reduce_mean(loss_value)
if log_writer is not None:
"""We use epoch_1000x as the x-axis in tensorboard.
This calibrates different curves when batch size changes.
"""
epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
log_writer.add_scalar("train_loss", loss_value_reduce, epoch_1000x)
log_writer.add_scalar("lr", lr, epoch_1000x)
    # gather the stats from all processes (cross-process sync is currently disabled)
    # metric_logger.synchronize_between_processes()
print("Averaged stats:", metric_logger)
return {k: meter.global_avg for k, meter in metric_logger.meters.items()}


def validation(model, data_loader_val, device, epoch, args):
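    """Compute the mean Dice loss over the validation set.

    Uses the same loss configuration as training (sigmoid for a single
    output channel, softmax with one-hot targets otherwise) and returns
    the average loss over all validation batches.
    """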
model.eval()
if args.out_channels == 1:
dice_loss = DiceLoss(sigmoid=True)
else:
dice_loss = DiceLoss(to_onehot_y=True, softmax=True, include_background=False)
with torch.no_grad():
loss_summary = []
for idx, (img, gt, _) in enumerate(data_loader_val):
img, gt = img.to(device), gt.to(device)
            mask = model(img)
            if isinstance(mask, list):
                # keep only the main head if the model returns deep-supervision outputs
                mask = mask[0]
            loss = dice_loss(mask, gt)
loss_summary.append(loss.detach().cpu().numpy())
            print(
                "epoch: {}/{}, iter: {}/{}, loss: {:.4f}".format(
                    epoch, args.epochs, idx, len(data_loader_val), loss.item()
                )
            )
avg_loss = np.mean(loss_summary)
print("Averaged stats:", str(avg_loss))
return avg_loss


def test(model, test_loader, args, sliding_window=False):
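    """Evaluate the best checkpoint on the test set.

    Loads `best.pth.tar` from args.output_dir, runs whole-volume or
    sliding-window inference, and reports Dice, Hausdorff95, and average
    surface distance, both class-wise and averaged over classes.
    """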
model.eval()
    filepath_best = os.path.join(args.output_dir, "best.pth.tar")
    # `weights_only` is an argument of torch.load, not of load_state_dict
    checkpoint = torch.load(filepath_best, weights_only=False)
    model.load_state_dict(checkpoint["model"])
dice_meter = DiceMeter(args)
hausdorff_meter = HausdorffMeter(args)
sd_meter = SurfaceDistanceMeter(args)
with torch.no_grad():
for idx, (img, gt, _) in enumerate(test_loader):
img, gt = img.to(args.device), gt.to(args.device)
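            # Sliding-window inference tiles the volume into crops of
            # args.crop_spatial_size (4 windows per batch, 50% overlap) and
            # stitches the predictions; otherwise the whole volume is
            # predicted in one forward pass.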
if sliding_window:
pred = sliding_window_inference(
img, args.crop_spatial_size, 4, model, overlap=0.5
)
else:
pred = model(img)
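            # binarise single-channel logits at 0.5; otherwise take the
            # channel-wise argmax of the softmax probabilities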
if args.num_classes == 1:
pred = torch.sigmoid(pred) > 0.5
else:
pred = torch.softmax(pred, dim=1)
pred = torch.argmax(pred, dim=1, keepdim=True)
dice_meter.update(pred, gt)
hausdorff_meter.update(pred, gt)
sd_meter.update(pred, gt)
print("- Test metrics Dice: ")
dice_class_avg, dice_avg = dice_meter.get_average()
print("Class wise: ", dice_class_avg)
print("Avg.: ", dice_avg)
print("- Test metrics Hausdorff95: ")
hsd_class_avg, hsd_avg = hausdorff_meter.get_average()
print("Class wise: ", hsd_class_avg)
print("Avg.: ", hsd_avg)
print("- Test metrics SurfaceDistance: ")
sd_class_avg, sd_avg = sd_meter.get_average()
print("Class wise: ", sd_class_avg)
print("Avg.: ", sd_avg)
    def _to_py(x):
        # numpy arrays are converted to lists so the stats are JSON-serialisable
        return x.tolist() if isinstance(x, np.ndarray) else x

    log_stats = {
        "dice_class_avg": _to_py(dice_class_avg),
        "dice_avg": _to_py(dice_avg),
        "hsd_class_avg": _to_py(hsd_class_avg),
        "hsd_avg": _to_py(hsd_avg),
        "sd_class_avg": _to_py(sd_class_avg),
        "sd_avg": _to_py(sd_avg),
    }
return log_stats