| import os
|
| import torch
|
| import numpy as np
|
|
|
|
|
class EarlyStopping():
    """Decides when training should stop, given a history of a metric
    that is being minimized (e.g. validation loss).

    Two independent criteria are supported; setting a parameter to its
    zero value disables the corresponding criterion:

    * ``patience``  — stop when the best (lowest) value in the history
      is ``patience`` or more entries old.
    * ``min_delta`` — stop when the most recent step improved the
      metric by less than ``min_delta``.
    """

    def __init__(self, patience=5, min_delta=0.0):
        # patience: epochs allowed without a new best before stopping
        # (0 disables the patience criterion).
        self.patience = patience
        # min_delta: minimum decrease between consecutive entries that
        # counts as real improvement (0.0 disables this criterion).
        self.min_delta = min_delta
        # Kept for external inspection; not updated by the checks below.
        self.best_metric = None
        self.best_index = 0

    def check_patience(self, metric_history):
        """Return True when the best (minimum) entry of
        ``metric_history`` is at least ``patience`` entries old."""
        if self.patience == 0:
            return False

        best_metric = min(metric_history)
        best_index = metric_history.index(best_metric)

        # If the best value is `patience` (or more) epochs in the past,
        # training has stalled.
        if len(metric_history) - best_index >= self.patience:
            return True

        return False

    def check_improvement(self, metric_history):
        """Return True when the last step failed to improve the metric
        by at least ``min_delta``.

        BUG FIX: the original returned True when the metric *did*
        improve by >= min_delta, which would have stopped training on
        every good epoch; the comparison is now inverted.
        """
        if self.min_delta == 0.0:
            return False

        # Need at least two entries to measure a step.
        if len(metric_history) < 2:
            return False

        # Improvement is a decrease; stop when it falls short of min_delta.
        if metric_history[-2] - metric_history[-1] < self.min_delta:
            return True

        return False

    def __call__(self, metric_history):
        """Return True if training should stop.

        BUG FIX: the original guards were swapped (``min_delta != 0``
        dispatched the *patience* check and vice versa), so a
        configuration enabling only one criterion could never trigger,
        because each check's internal "disabled" guard returned False.
        """
        if self.patience != 0:
            return self.check_patience(metric_history)

        if self.min_delta != 0.0:
            return self.check_improvement(metric_history)

        return False
|
|
|
class SaveBestModel():
    """Persists model weights whenever the tracked metric improves.

    Parameters
    ----------
    folder : str
        Directory the checkpoint file is written to.
    mode : str
        ``'min'`` if lower metric values are better, ``'max'`` if
        higher values are better.

    Raises
    ------
    ValueError
        If ``mode`` is not ``'min'`` or ``'max'``.  (The original
        accepted any mode and then silently never saved again after
        the first call.)
    """

    def __init__(self, folder="./", mode='min'):
        if mode not in ('min', 'max'):
            raise ValueError(f"Unsupported mode: {mode}")
        self.best_metric = None
        self.folder = folder
        self.mode = mode

    def __call__(self, model, current_metric, model_name="best.pth"):
        """Save ``model.state_dict()`` to ``folder/model_name`` when
        ``current_metric`` beats the best value seen so far.

        The very first call always saves (any value beats no history).
        """
        if self.best_metric is None:
            improved = True
        elif self.mode == 'min':
            improved = current_metric < self.best_metric
        else:  # mode == 'max'
            improved = current_metric > self.best_metric

        if improved:
            self.best_metric = current_metric
            save_path = os.path.join(self.folder, model_name)
            torch.save(model.state_dict(), save_path)
|
|
|
|
|
class ModelLoss():
    """Loss factory for the supported training tasks.

    * ``'segmentation'`` — focal (or cross-entropy) loss plus a dice
      term; logits are expected as input (sigmoid is applied
      internally by the binary losses).
    * ``'mae'`` — plain mean-squared-error reconstruction loss.

    Parameters
    ----------
    task : str
        One of ``'segmentation'`` or ``'mae'``.
    loss : str
        Segmentation loss selector: ``'focal'``, ``'default'`` (alias
        for focal), or ``'cross_entropy'``.  Ignored for ``'mae'``.
    focal_alpha, focal_gamma : float
        Focal-loss weighting and focusing parameters.

    Raises
    ------
    ValueError
        On an unsupported ``task`` or segmentation ``loss``.
    """

    def __init__(self, task='segmentation', loss='focal', focal_alpha=0.25, focal_gamma=2.0):
        if task not in ['segmentation', 'mae']:
            raise ValueError(f"Unsupported task: {task}")

        self.task = task
        self.loss = loss
        self.focal_alpha = focal_alpha
        self.focal_gamma = focal_gamma
        # BUG FIX: 'default' is handled in __call__ as an alias for the
        # focal loss, but the original validation rejected it, making
        # that branch unreachable.
        if self.task == 'segmentation' and self.loss not in ['focal', 'default', 'cross_entropy']:
            raise ValueError(f"Unsupported loss for segmentation task: {self.loss}")

    @staticmethod
    def l1(input, target):
        """Mean absolute error."""
        return torch.nn.functional.l1_loss(input, target)

    @staticmethod
    def l2(input, target):
        """Mean squared error."""
        return torch.nn.functional.mse_loss(input, target)

    @staticmethod
    def cross_entropy(input, target):
        """Standard cross-entropy on raw logits."""
        return torch.nn.functional.cross_entropy(input, target)

    @staticmethod
    def dice_loss(input: torch.Tensor, target: torch.Tensor, eps=1e-6):
        """Soft dice loss on logits (sigmoid applied internally).

        `eps` keeps the ratio finite when both input and target are
        all-zero.
        """
        input = torch.sigmoid(input)
        input = input.view(-1)
        target = target.view(-1)

        intersection = (input * target).sum()
        dice = (2. * intersection + eps) / (input.sum() + target.sum() + eps)
        return 1 - dice

    @staticmethod
    def focal_loss(input: torch.Tensor, target: torch.Tensor, alpha=0.25, gamma=2.0, eps=1e-6):
        """Binary focal loss (Lin et al.) on logits.

        Down-weights easy examples via the (1 - p_t)^gamma factor so
        training focuses on hard pixels.
        """
        input = torch.sigmoid(input)
        input = input.view(-1)
        target = target.view(-1)

        bce_loss = torch.nn.functional.binary_cross_entropy(input, target, reduction='none')
        # p_t is the model's probability for the true class of each pixel.
        p_t = input * target + (1 - input) * (1 - target)
        focal_loss = alpha * (1 - p_t) ** gamma * bce_loss
        return focal_loss.mean()

    def __call__(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute the configured loss.

        Segmentation returns the selected pixel loss plus the dice
        term; 'mae' returns the MSE reconstruction loss.
        """
        if self.task == 'segmentation':
            loss = 0.0
            if self.loss == 'focal' or self.loss == 'default':
                loss = ModelLoss.focal_loss(input, target, alpha=self.focal_alpha, gamma=self.focal_gamma)
            elif self.loss == 'cross_entropy':
                loss = ModelLoss.cross_entropy(input, target)
            dice = ModelLoss.dice_loss(input, target)
            return loss + dice
        elif self.task == 'mae':
            return ModelLoss.l2(input, target)
        return None
|
|
|
|
|
|
|
class ModelMetrics():
    """Evaluation metrics for the supported training tasks.

    Parameters
    ----------
    task : str
        One of ``'segmentation'`` or ``'mae'``.
    device : str
        Stored for callers; not used by the computations here.
    threshold : float
        Probability cut-off used to binarize sigmoid outputs for the
        segmentation metrics.
    """

    def __init__(self, task='segmentation', device='cpu', threshold=0.5):
        if task not in ['segmentation', 'mae']:
            raise ValueError(f"Unsupported task: {task}")

        self.task = task
        self.device = device
        self.threshold = threshold

    @staticmethod
    def iou_score(pred, target, eps=1e-6):
        """Mean intersection-over-union of binarized (B,1,H,W) masks."""
        intersection = (pred * target).sum(dim=(1, 2, 3))
        union = pred.sum(dim=(1, 2, 3)) + target.sum(dim=(1, 2, 3)) - intersection
        return ((intersection + eps) / (union + eps)).mean()

    @staticmethod
    def dice_score(pred, target, eps=1e-6):
        """Mean dice coefficient of binarized (B,1,H,W) masks."""
        intersection = (pred * target).sum(dim=(1, 2, 3))
        return ((2 * intersection + eps) /
                (pred.sum(dim=(1, 2, 3)) + target.sum(dim=(1, 2, 3)) + eps)).mean()

    @staticmethod
    def pixel_accuracy(pred, target):
        """Fraction of pixels where prediction equals target."""
        return (pred == target).float().mean()

    @staticmethod
    def pixel_precision(pred, target, eps=1e-6):
        """TP / predicted-positive over all pixels."""
        true_positive = (pred * target).sum()
        predicted_positive = pred.sum()
        return (true_positive.float() + eps) / (predicted_positive.float() + eps)

    @staticmethod
    def recall(pred, target, eps=1e-6):
        """TP / actual-positive over all pixels."""
        true_positive = (pred * target).sum()
        actual_positive = target.sum()
        return (true_positive.float() + eps) / (actual_positive.float() + eps)

    @staticmethod
    def l1(input, target):
        """Mean absolute error."""
        return torch.nn.functional.l1_loss(input, target)

    @staticmethod
    def l2(input, target):
        """Mean squared error."""
        return torch.nn.functional.mse_loss(input, target)

    def getLabels(self):
        """Metric names, in the order __call__ returns their values."""
        if self.task == 'segmentation':
            return ['iou', 'dice', 'accuracy', 'precision', 'recall']
        elif self.task == 'mae':
            return ['l1', 'l2']
        return []

    def __call__(self, pred: torch.Tensor, target: torch.Tensor):
        """
        pred: (B, 1, H, W) => output of model BEFORE sigmoid
        target: (B, H, W) or (B, 1, H, W)
        """
        if self.task == 'segmentation':
            pred = torch.sigmoid(pred)
            # BUG FIX: the constructor's `threshold` was ignored — the
            # binarization cut-off was hard-coded to 0.5.
            pred = (pred > self.threshold).float()
            # BUG FIX: a (B, H, W) target, allowed by the docstring,
            # crashed the dim=(1,2,3) reductions; add the channel dim.
            if target.dim() == pred.dim() - 1:
                target = target.unsqueeze(1)
            iou = self.iou_score(pred, target)
            dice = self.dice_score(pred, target)
            acc = self.pixel_accuracy(pred, target)
            pres = self.pixel_precision(pred, target)
            recall = self.recall(pred, target)
            return iou.item(), dice.item(), acc.item(), pres.item(), recall.item()

        if self.task == 'mae':
            l1 = self.l1(pred, target)
            l2 = self.l2(pred, target)
            return l1.item(), l2.item()
|
|
|
def validate_mae(model, val_loader, metrics, device='cuda'):
    """Run one evaluation pass of a MAE model over `val_loader`.

    Parameters
    ----------
    model : torch.nn.Module
        Model to evaluate; switched to eval mode (not restored).
    val_loader : iterable
        Yields (sources, targets) batch pairs.
    metrics : callable
        Called as ``metrics(pred, targets)``; must return a sequence of
        per-batch metric values.
    device : str or torch.device, default 'cuda'
        Device batches are moved to.  Generalized from the original
        hard-coded ``.cuda()`` calls; the default preserves the old
        behavior.

    Returns
    -------
    numpy.ndarray
        Mean of each metric across all batches.
    """
    model.eval()
    validation_accumulator = []
    with torch.no_grad():
        for sources, targets in val_loader:
            sources = sources.to(device)
            targets = targets.to(device)

            pred = model(sources)

            validation_accumulator.append(metrics(pred, targets))

    return np.mean(validation_accumulator, axis=0)