| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | import math |
| | import sys |
| | import torch |
| | import os |
| | import util.misc as misc |
| | import util.lr_sched as lr_sched |
| | import numpy as np |
| | from util.metric import accuracy, ConfusionMatrix, kappa |
| | from sklearn.metrics import ( |
| | roc_auc_score, |
| | top_k_accuracy_score, |
| | f1_score, |
| | confusion_matrix, |
| | ) |
| | from torchmetrics.classification import ( |
| | BinarySpecificityAtSensitivity, |
| | BinarySensitivityAtSpecificity, |
| | ) |
| |
|
| |
|
| | import pdb |
| |
|
| |
|
def train_one_epoch(
    model,
    data_loader,
    optimizer,
    device,
    epoch: int,
    loss_scaler,
    log_writer=None,
    args=None,
):
    """Train ``model`` for one pass over ``data_loader``.

    The loss is ``BCEWithLogitsLoss`` when ``args.dataset == "promis"`` or
    ``args.num_classes <= 1`` and ``CrossEntropyLoss`` otherwise. The learning
    rate is re-computed every iteration through
    ``lr_sched.adjust_learning_rate`` using the fractional epoch
    ``data_iter_step / len(data_loader) + epoch``.

    Args:
        model: Network called as ``model(img)``.
        data_loader: Sized iterable yielding ``(img, gt, dataidx)`` batches.
        optimizer: Optimizer stepped once per batch.
        device: Device the image and target tensors are moved to.
        epoch: Index of the current epoch (used for logging and scheduling).
        loss_scaler: Accepted for interface compatibility; not used here.
        log_writer: Optional writer exposing ``log_dir`` and ``add_scalar``.
        args: Run configuration; ``dataset`` is always read and
            ``num_classes`` is read unless the dataset is ``"promis"``.
            It is also forwarded to the LR scheduler.

    Returns:
        Dict mapping every meter name of the metric logger to its
        ``global_avg``.

    Note:
        If the loss becomes non-finite, diagnostics are printed and the
        process is terminated with ``sys.exit(1)``.
    """
    model.train(True)
    metric_logger = misc.MetricLogger(delimiter=" ")
    metric_logger.add_meter("lr", misc.SmoothedValue(window_size=1, fmt="{value:.6f}"))
    header = "Epoch: [{}]".format(epoch)
    print_freq = 20

    # Cross-entropy only for multi-class, non-"promis" runs; BCE otherwise.
    use_cross_entropy = args.dataset != "promis" and args.num_classes > 1
    if use_cross_entropy:
        criterion = torch.nn.CrossEntropyLoss()
    else:
        criterion = torch.nn.BCEWithLogitsLoss()

    optimizer.zero_grad()

    if log_writer is not None:
        print("log_dir: {}".format(log_writer.log_dir))

    # Only reported in the non-finite-loss diagnostics; never updated here.
    last_norm = 0.0
    batches = metric_logger.log_every(data_loader, print_freq, header)
    for data_iter_step, (img, gt, dataidx) in enumerate(batches):
        img = img.to(device, non_blocking=True)
        gt = gt.to(device, non_blocking=True)

        # Fractional epoch: drives both the LR schedule and the log x-axis.
        epoch_progress = data_iter_step / len(data_loader) + epoch
        lr_sched.adjust_learning_rate(optimizer, epoch_progress, args)

        logit = model(img)
        loss = criterion(logit, gt)
        loss_value = loss.item()

        if not math.isfinite(loss_value):
            # Report whether the network output or the input is the culprit.
            for tag, probe in (("nan", torch.isnan), ("inf", torch.isinf)):
                print(
                    tag,
                    probe(logit).any(),
                    probe(img).any(),
                    dataidx,
                    last_norm,
                )
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        metric_logger.update(loss=loss_value)

        lr = optimizer.param_groups[0]["lr"]
        metric_logger.update(lr=lr)

        # Called on every iteration, whether or not a writer is attached.
        loss_value_reduce = misc.all_reduce_mean(loss_value)
        if log_writer is None:
            continue

        # We use epoch_1000x as the x-axis in tensorboard.
        # This calibrates different curves when batch size changes.
        epoch_1000x = int(epoch_progress * 1000)
        log_writer.add_scalar("train_loss", loss_value_reduce, epoch_1000x)
        log_writer.add_scalar("lr", lr, epoch_1000x)

    print("Averaged stats:", metric_logger)
    return {name: meter.global_avg for name, meter in metric_logger.meters.items()}
| |
|
| |
|
def validation(model, data_loader_val, device, epoch, args):
    """Evaluate ``model`` on ``data_loader_val`` and return the mean batch loss.

    The model is switched to eval mode and all batches are processed under
    ``torch.no_grad()``. The loss is ``BCEWithLogitsLoss`` when
    ``args.dataset == "promis"`` or ``args.num_classes <= 1`` and
    ``CrossEntropyLoss`` otherwise. One progress line is printed per batch.

    Args:
        model: Network called as ``model(img)``.
        data_loader_val: Sized iterable yielding ``(img, gt, _)`` batches.
        device: Device the image and target tensors are moved to.
        epoch: Current epoch index (only used in the progress line).
        args: Run configuration; reads ``dataset``, ``epochs`` and, unless
            the dataset is ``"promis"``, ``num_classes``.

    Returns:
        ``np.mean`` over the per-batch loss values (each batch counts
        equally, regardless of its size).
    """
    model.eval()

    # Cross-entropy only for multi-class, non-"promis" runs; BCE otherwise.
    use_cross_entropy = args.dataset != "promis" and args.num_classes > 1
    if use_cross_entropy:
        criterion = torch.nn.CrossEntropyLoss()
    else:
        criterion = torch.nn.BCEWithLogitsLoss()

    batch_losses = []
    with torch.no_grad():
        for idx, (img, gt, _) in enumerate(data_loader_val):
            img = img.to(device)
            gt = gt.to(device)
            batch_loss = criterion(model(img), gt).detach().cpu().numpy()
            batch_losses.append(batch_loss)

            progress = "epoch: {}/{}, iter: {}/{}".format(
                epoch, args.epochs, idx, len(data_loader_val)
            )
            print(progress + " loss:" + str(batch_loss.flatten()[0]))

    avg_loss = np.mean(batch_losses)
    print("Averaged stats:", str(avg_loss))
    return avg_loss
| |
|
| |
|
def test(model, test_loader, args):
    """Load the best checkpoint, run inference and report dataset metrics.

    The checkpoint ``<args.output_dir>/best.pth.tar`` is loaded and its
    ``"model"`` entry is restored into ``model``. Raw model outputs and
    targets of every batch are collected and handed to the metric routine
    selected by ``args.dataset``.

    Args:
        model: Network whose weights are replaced by the checkpoint.
        test_loader: Iterable yielding ``(img, gt, _)`` batches.
        args: Run configuration; reads ``output_dir``, ``device`` and
            ``dataset``.

    Returns:
        Whatever the selected metric routine returns (``test_risk``,
        ``test_screening`` or ``test_promis``).

    Raises:
        NotImplementedError: If ``args.dataset`` is not one of ``"risk"``,
            ``"screening"`` or ``"promis"``.
    """
    filepath_best = os.path.join(args.output_dir, "best.pth.tar")

    # `weights_only` is an argument of `torch.load`, not of
    # `Module.load_state_dict` (which would raise TypeError on it).
    # NOTE(security): weights_only=False unpickles arbitrary Python objects;
    # only load checkpoints produced by a trusted training run.
    # map_location="cpu" lets a GPU-saved checkpoint load on any host;
    # load_state_dict copies the values into the model's own parameters.
    checkpoint = torch.load(filepath_best, map_location="cpu", weights_only=False)
    model.load_state_dict(checkpoint["model"])
    model.eval()

    prob, gts = [], []
    with torch.no_grad():
        for img, gt, _ in test_loader:
            img, gt = img.to(args.device), gt.to(args.device)
            logit = model(img)
            prob.append(logit)
            gts.append(gt)

    if args.dataset == "risk":
        return test_risk(prob, gts)
    elif args.dataset == "screening":
        return test_screening(prob, gts)
    elif args.dataset == "promis":
        return test_promis(prob, gts)
    else:
        raise NotImplementedError(f"unknown dataset: {args.dataset}")
| |
|
| |
|
def _report_risk_cutoff(prob, gts, first_positive, title, key_prefix, log_stats):
    """Print and record binary metrics for ``class >= first_positive``.

    The positive score is the summed probability of all classes from
    ``first_positive`` upwards; the positive label is ``gts > first_positive - 1``.
    Results are written into ``log_stats`` under ``key_prefix``.
    """
    sig_prob = np.sum(prob[:, first_positive:], -1)
    sig_gts = (gts > first_positive - 1).astype(int)
    sig_auc = roc_auc_score(sig_gts, sig_prob) * 100
    sig_f1 = f1_score(sig_gts, sig_prob > 0.5) * 100

    print(title)
    print("auc \t f1 ")
    print(f"{sig_auc:.2f} \t {sig_f1:.2f}")

    log_stats[f"{key_prefix}_auc"] = f"{sig_auc:.2f}"
    log_stats[f"{key_prefix}_f1"] = f"{sig_f1:.2f}"

    # The tensors do not change between operating points; convert once.
    preds = torch.from_numpy(sig_prob)
    target = torch.from_numpy(sig_gts)
    for i in [0.8, 0.9]:
        sig_spec = BinarySpecificityAtSensitivity(min_sensitivity=i, thresholds=None)
        sig_specificity, _ = sig_spec(preds, target)
        sig_specificity = sig_specificity * 100
        sig_sens = BinarySensitivityAtSpecificity(min_specificity=i, thresholds=None)
        sig_sensitivity, _ = sig_sens(preds, target)
        sig_sensitivity = sig_sensitivity * 100

        print(f"min: {i}")
        print("Specificity at Sensitivity \t Sensitivity at Specificity")
        print(f"{sig_specificity:.2f} \t {sig_sensitivity:.2f} ")
        log_stats[f"{key_prefix}_specificity_at_{i}"] = f"{sig_specificity:.2f}"
        log_stats[f"{key_prefix}_sensitivity_at_{i}"] = f"{sig_sensitivity:.2f}"


def test_risk(prob, gts):
    """Compute multi-class and binarised metrics for the "risk" dataset.

    Args:
        prob: List of per-batch logit tensors; concatenated along dim 0 and
            passed through a softmax over the last dimension.
        gts: List of per-batch target tensors; concatenated along dim 0.

    Returns:
        Dict of formatted metric strings. Multi-class entries use the
        ``"4-class_*"`` keys (accuracy, one-vs-rest AUC, kappa, macro F1).
        The binarised entries use the ``"leq3_*"`` and ``"leq4_*"`` keys
        for the splits printed as "Pirads >=3" (label > 0) and
        "Pirads >=4" (label > 1); the key names are kept as-is for
        compatibility with existing consumers of the result dict.
    """
    log_stats = {}
    prob = torch.cat(prob, 0)
    prob = torch.softmax(prob, dim=-1).cpu().numpy()
    gts = torch.cat(gts, 0).cpu().numpy()

    pred_class = np.argmax(prob, 1)
    score_acc = top_k_accuracy_score(gts, prob, k=1) * 100
    score_qwk = kappa(gts, pred_class)
    score_auc = roc_auc_score(gts, prob, multi_class="ovr") * 100
    score_f1 = f1_score(gts, pred_class, average="macro") * 100

    print("score")
    print("acc\t auc \t qwk \t f1")
    print(f"{score_acc:.2f} \t {score_auc:.2f} \t {score_qwk:.4f} \t {score_f1:.2f}")
    log_stats["4-class_acc"] = f"{score_acc:.2f}"
    log_stats["4-class_auc"] = f"{score_auc:.2f}"
    log_stats["4-class_qwk"] = f"{score_qwk:.4f}"
    log_stats["4-class_f1"] = f"{score_f1:.2f}"

    # Same binary report at two cut-offs: classes 1+ and classes 2+ as positive.
    _report_risk_cutoff(prob, gts, 1, "Pirads >=3", "leq3", log_stats)
    _report_risk_cutoff(prob, gts, 2, "Pirads >=4", "leq4", log_stats)
    return log_stats
| |
|
| |
|
def test_screening(prob, gts):
    """Print binary classification metrics for the "screening" dataset.

    Logits are concatenated, passed through a sigmoid and compared against
    the (integer-cast) targets. The probabilities and targets are also
    written to ``result.npz`` in the current working directory.

    Args:
        prob: List of per-batch logit tensors with one logit per sample.
        gts: List of per-batch target tensors.

    Returns:
        ``None``; the metrics are only printed.

    Raises:
        ValueError: If the model output has more than one value per sample.
    """
    prob = torch.cat(prob, 0)
    prob = torch.sigmoid(prob).cpu().numpy()
    gts = torch.cat(gts, 0).long().cpu().numpy()

    # Saved before flattening so the file keeps the original array shapes.
    np.savez("result.npz", gts=gts, prob=prob)

    if prob.size != prob.shape[0]:
        raise ValueError(
            f"test_screening expects one logit per sample, got shape {prob.shape}"
        )
    scores = prob.reshape(-1)
    labels = gts.reshape(-1)

    score_acc = top_k_accuracy_score(labels, scores, k=1) * 100
    score_auc = roc_auc_score(labels, scores) * 100
    # Threshold the probability: argmax over a single column is always 0,
    # which made the F1 score degenerate.
    score_f1 = f1_score(labels, scores > 0.5) * 100

    print("acc\t auc \t f1")
    print(f"{score_acc:.2f} \t {score_auc:.2f} \t {score_f1:.2f}")

    preds = torch.from_numpy(scores)
    target = torch.from_numpy(labels)
    for i in [0.8, 0.9]:
        sig_spec = BinarySpecificityAtSensitivity(min_sensitivity=i, thresholds=None)
        sig_specificity, _ = sig_spec(preds, target)
        sig_sens = BinarySensitivityAtSpecificity(min_specificity=i, thresholds=None)
        sig_sensitivity, _ = sig_sens(preds, target)

        print(f"min: {i}")
        print("Specificity at Sensitivity \t Sensitivity at Specificity")
        print(f"{sig_specificity* 100:.2f} \t {sig_sensitivity* 100:.2f} ")

    log_stats = None
    return log_stats
| |
|
| |
|
| |
|
def _report_promis_level(title, key_prefix, scores, labels, log_stats):
    """Print and record AUC and operating points for one aggregation level.

    ``scores`` and ``labels`` are 1-D numpy arrays of probabilities and
    integer targets; results are stored in ``log_stats`` under ``key_prefix``.
    """
    print(title)

    # roc_auc_score takes (y_true, y_score): labels first, scores second.
    auc = roc_auc_score(labels, scores) * 100
    print(f"AUC: {auc:.2f}")
    log_stats[f"{key_prefix}_auc"] = f"{auc:.2f}"

    preds = torch.from_numpy(scores)
    target = torch.from_numpy(labels)
    for i in [0.8, 0.9]:
        sig_spec = BinarySpecificityAtSensitivity(min_sensitivity=i, thresholds=None)
        sig_specificity, _ = sig_spec(preds, target)
        sig_sens = BinarySensitivityAtSpecificity(min_specificity=i, thresholds=None)
        sig_sensitivity, _ = sig_sens(preds, target)

        print(f"min: {i}")
        print("Specificity at Sensitivity \t Sensitivity at Specificity")
        print(f"{sig_specificity* 100:.2f} \t {sig_sensitivity* 100:.2f} ")
        log_stats[f"{key_prefix}_specificity_at_{i}"] = f"{sig_specificity * 100:.2f}"
        log_stats[f"{key_prefix}_sensitivity_at_{i}"] = f"{sig_sensitivity * 100:.2f}"


def test_promis(prob, gts):
    """Compute zone-level and patient-level metrics for the "promis" dataset.

    Logits are concatenated and passed through a sigmoid. The "zone level"
    report treats every element of the flattened arrays as one sample; the
    "patient level" report reduces scores and targets with a max over the
    last axis first.

    Args:
        prob: List of per-batch logit tensors.
        gts: List of per-batch target tensors (cast to ``int``).

    Returns:
        Dict of formatted metric strings keyed ``"zone_*"`` and
        ``"patient_*"`` (AUC plus specificity/sensitivity at 0.8 and 0.9).
    """
    log_stats = {}

    prob = torch.cat(prob, 0)
    prob = torch.sigmoid(prob).cpu().numpy()
    gts = torch.cat(gts, 0).cpu().numpy().astype(int)

    _report_promis_level(
        "zone level performance",
        "zone",
        prob.reshape(-1),
        gts.reshape(-1),
        log_stats,
    )
    _report_promis_level(
        "patient level performance",
        "patient",
        prob.max(-1),
        gts.max(-1),
        log_stats,
    )
    return log_stats