| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from collections import defaultdict |
|
|
|
|
def dice_loss(input_mask, cls_gt):
    """Return the mean smoothed soft-Dice loss over all objects and samples.

    Args:
        input_mask: Predicted masks indexed as ``input_mask[:, k]`` for
            object channel ``k``. Presumably shaped (batch, objects, H, W) --
            confirm against the caller.
        cls_gt: Integer label map compared against ``k + 1`` for channel
            ``k``; label 0 is never matched by any channel.

    Returns:
        Scalar tensor: the mean of ``1 - (2*|P.G| + 1) / (|P| + |G| + 1)``
        taken over every (object, sample) pair.
    """
    per_object = []
    for obj_idx in range(input_mask.shape[1]):
        pred = input_mask[:, obj_idx].flatten(start_dim=1)
        # Channel k is matched against label k + 1 in the ground truth.
        target = (cls_gt == (obj_idx + 1)).float().flatten(start_dim=1)

        overlap = 2 * (pred * target).sum(-1)
        total = pred.sum(-1) + target.sum(-1)
        # The +1 smoothing keeps the ratio defined when both masks are empty.
        per_object.append(1 - (overlap + 1) / (total + 1))

    return torch.cat(per_object).mean()
|
|
def l1_loss(input, target):
    """Return the mean absolute difference between ``input`` and ``target``."""
    return (input - target).abs().mean()
|
|
|
|
| |
class BootstrappedCE(nn.Module):
    """Cross-entropy averaged over only the hardest fraction of elements.

    The kept fraction ``p`` follows a schedule driven by the iteration
    counter ``it``:

    * ``it < start_warm``: ordinary mean cross-entropy, ``p`` reported as 1.0.
    * ``start_warm <= it < end_warm``: ``p`` decays linearly from 1.0 down
      to ``top_p``.
    * ``it >= end_warm``: ``p`` is fixed at ``top_p``.
    """

    def __init__(self, start_warm, end_warm, top_p=0.15):
        """Store the schedule boundaries and the final kept fraction."""
        super().__init__()

        self.start_warm = start_warm
        self.end_warm = end_warm
        self.top_p = top_p

    def forward(self, input, target, it):
        """Compute the bootstrapped loss for iteration ``it``.

        Args:
            input: Logits forwarded unchanged to ``F.cross_entropy``.
            target: Targets forwarded unchanged to ``F.cross_entropy``.
            it: Current iteration number used to pick the kept fraction.

        Returns:
            Tuple ``(loss, p)``: the scalar loss tensor and the fraction of
            elements that contributed to it.
        """
        if it < self.start_warm:
            return F.cross_entropy(input, target), 1.0

        raw_loss = F.cross_entropy(input, target, reduction='none').view(-1)
        num_pixels = raw_loss.numel()

        # ``>=`` rather than ``>``: the value at it == end_warm is top_p either
        # way, but this keeps start_warm == end_warm from dividing by zero.
        if it >= self.end_warm:
            this_p = self.top_p
        else:
            remaining = (self.end_warm - it) / (self.end_warm - self.start_warm)
            this_p = self.top_p + (1 - self.top_p) * remaining

        # Keep at least one element so the mean is not NaN on small inputs,
        # but never ask topk for more elements than exist.
        k = min(num_pixels, max(1, int(num_pixels * this_p)))
        loss, _ = torch.topk(raw_loss, k, sorted=False)
        return loss.mean(), this_p
|
|
|
|
class LossComputer:
    """Collect per-frame training losses into one dictionary.

    Both public methods read the batch and time sizes from the first two
    dimensions of ``data['rgb']`` and iterate over frames ``1 .. t-1``
    (frame 0 is skipped).
    """

    def __init__(self, config):
        """Build the bootstrapped cross-entropy criterion from ``config``.

        Args:
            config: Mapping providing ``'start_warm'`` and ``'end_warm'``.
        """
        super().__init__()
        self.config = config
        self.bce = BootstrappedCE(config['start_warm'], config['end_warm'])

    def compute(self, data, num_objects, it):
        """Compute bootstrapped cross-entropy plus Dice loss for each frame.

        Args:
            data: Mapping with ``'rgb'``, ``'cls_gt'`` and, for every frame
                index ``ti >= 1``, ``f'logits_{ti}'`` and ``f'masks_{ti}'``.
            num_objects: Per-sample object counts; sample ``bi`` uses the
                first ``num_objects[bi] + 1`` logit channels.
            it: Current iteration, forwarded to the bootstrapped criterion.

        Returns:
            ``defaultdict(int)`` holding ``'total_loss'``, ``'p'`` (the kept
            fraction averaged over samples and frames), and per-frame
            ``f'ce_loss_{ti}'`` / ``f'dice_loss_{ti}'`` entries.
        """
        losses = defaultdict(int)

        b, t = data['rgb'].shape[:2]
        cls_gt = data['cls_gt']

        losses['total_loss'] = 0
        for ti in range(1, t):
            ce_key = f'ce_loss_{ti}'
            dice_key = f'dice_loss_{ti}'
            logits = data[f'logits_{ti}']

            # Cross-entropy is evaluated per sample because each sample may
            # use a different number of logit channels.
            for bi in range(b):
                loss, p = self.bce(
                    logits[bi:bi + 1, :num_objects[bi] + 1],
                    cls_gt[bi:bi + 1, ti, 0],
                    it,
                )
                losses['p'] += p / b / (t - 1)
                losses[ce_key] += loss / b

            losses['total_loss'] += losses[ce_key]
            losses[dice_key] = dice_loss(data[f'masks_{ti}'], cls_gt[:, ti, 0])
            losses['total_loss'] += losses[dice_key]

        return losses

    def compute_l1loss(self, data, num_objects, it):
        """Compute a per-frame L1 loss between predicted masks and ``cls_gt``.

        The cross-entropy entries and ``'p'`` are reported as 0 so the result
        has the same keys as :meth:`compute`. The L1 value is stored under
        ``f'dice_loss_{ti}'``, the same key :meth:`compute` uses for Dice.

        Args:
            data: Mapping with ``'rgb'``, ``'cls_gt'`` and ``f'masks_{ti}'``
                for every frame index ``ti >= 1``.
            num_objects: Unused; kept for signature parity with ``compute``.
            it: Unused; kept for signature parity with ``compute``.

        Returns:
            ``defaultdict(int)`` with ``'total_loss'`` and the per-frame
            entries described above.
        """
        losses = defaultdict(int)

        _, t = data['rgb'].shape[:2]
        cls_gt = data['cls_gt']

        losses['total_loss'] = 0
        for ti in range(1, t):
            ce_key = f'ce_loss_{ti}'
            l1_key = f'dice_loss_{ti}'

            # These placeholders are constant, so no per-sample loop is needed.
            losses['p'] = 0
            losses[ce_key] = 0

            losses['total_loss'] += losses[ce_key]
            losses[l1_key] = l1_loss(data[f'masks_{ti}'], cls_gt[:, ti])
            losses['total_loss'] += losses[l1_key]

        return losses
|
|