| import torch as t |
| import torch.nn as nn |
| |
|
|
|
|
class SaliencyLoss(nn.Module):
    """Batch-averaged saliency metric, selected by name on every call.

    The module holds no parameters or state; it only dispatches to the
    per-sample ``loss_*`` functions defined in this file and averages
    their results over the batch.
    """

    def __init__(self):
        super(SaliencyLoss, self).__init__()

    def forward(self, preds, labels, loss_type='cc'):
        """Average one saliency metric over the leading (batch) dimension.

        Args:
            preds: Batch of predicted maps, indexed along dim 0.
            labels: Batch of ground-truth maps, indexed along dim 0. Its
                leading size decides how many samples are evaluated.
            loss_type: Which metric to apply per sample: ``'cc'``
                (``loss_CC``), ``'kldiv'`` (``loss_KLdiv``), ``'sim'``
                (``loss_similarity``) or ``'nss'`` (``loss_NSS``).

        Returns:
            The per-sample values stacked and averaged over dim 0 with
            ``keepdim=True`` (shape ``(1,)`` when each per-sample value is
            a 0-dim tensor). The value is returned as computed by the
            metric, without any sign change.

        Raises:
            ValueError: If ``loss_type`` is not one of the names above.
        """
        # Resolve the metric once instead of repeating the loop per branch.
        if loss_type == 'cc':
            metric = loss_CC
        elif loss_type == 'kldiv':
            metric = loss_KLdiv
        elif loss_type == 'sim':
            metric = loss_similarity
        elif loss_type == 'nss':
            metric = loss_NSS
        else:
            # Without this branch the list below would stay empty and
            # t.stack would fail with an error unrelated to the real cause.
            raise ValueError(
                "Unknown loss_type {!r}; expected one of "
                "'cc', 'kldiv', 'sim', 'nss'.".format(loss_type)
            )

        # Index both tensors explicitly (rather than zip) so that a preds
        # batch shorter than labels still fails loudly instead of being
        # silently truncated.
        losses = [metric(preds[i], labels[i]) for i in range(labels.shape[0])]
        return t.stack(losses).mean(dim=0, keepdim=True)
| |
| |
def loss_KLdiv(pred_map, gt_map):
    """Kullback-Leibler divergence of ``pred_map`` from ``gt_map``.

    Each map is first divided by its own sum so that it can be read as a
    distribution, then ``sum(gt * log(eps + gt / (pred + eps)))`` is
    evaluated over all elements.

    Args:
        pred_map: Predicted saliency map (any shape).
        gt_map: Ground-truth saliency map, same shape as ``pred_map``.

    Returns:
        A 0-dim tensor holding the divergence. An all-zero ``gt_map``
        yields 0 and an all-zero ``pred_map`` yields a large finite value,
        instead of NaN.
    """
    eps = 2.2204e-16

    def _as_distribution(saliency):
        """Scale ``saliency`` to sum to one; leave an all-zero map as zeros."""
        total = t.sum(saliency)
        # A zero total would make the division 0/0 = NaN. Dividing by 1
        # instead keeps the map at zero and leaves every other input
        # untouched.
        safe_total = t.where(total != 0, total, t.ones_like(total))
        return saliency / safe_total

    pred_map = _as_distribution(pred_map)
    gt_map = _as_distribution(gt_map)
    # eps inside the ratio guards the division, eps inside the log guards
    # log(0) where the ground truth has no mass.
    ratio = gt_map / (pred_map + eps)
    return t.sum(gt_map * t.log(eps + ratio))
| |
| |
def loss_CC(pred_map, gt_map):
    """Pearson linear correlation coefficient between two maps.

    Both maps are mean-centred; the result is the sum of their
    element-wise product divided by the square root of the product of
    their sums of squares.

    Args:
        pred_map: Predicted saliency map (floating point, any shape).
        gt_map: Ground-truth saliency map, same shape as ``pred_map``.

    Returns:
        A 0-dim tensor with the correlation. When either map is constant
        (zero variance) the correlation is undefined; 0 is returned in
        that case rather than NaN.
    """
    gt_centered = gt_map - t.mean(gt_map)
    pred_centered = pred_map - t.mean(pred_map)

    cross = t.sum(gt_centered * pred_centered)
    energy = t.sum(gt_centered * gt_centered) * t.sum(pred_centered * pred_centered)
    # A constant map makes energy == 0 and the quotient 0/0 = NaN. Swapping
    # the denominator for 1 in that case gives 0 (the numerator is already
    # zero) and does not alter any non-degenerate result.
    safe_energy = t.where(energy > 0, energy, t.ones_like(energy))
    return cross / t.sqrt(safe_energy)
|
|
|
|
def loss_similarity(pred_map, gt_map):
    """Similarity (histogram intersection) between two saliency maps.

    Each map is min-max rescaled to [0, 1] and then divided by its sum;
    the score is the sum of the element-wise minimum of the two results.

    Args:
        pred_map: Predicted saliency map (any shape).
        gt_map: Ground-truth saliency map, same shape as ``pred_map``.

    Returns:
        A 0-dim tensor with the similarity score.

    Note:
        A constant map has ``max == min``, so the rescaling divides by
        zero and the score is NaN.
    """
    def _to_distribution(saliency):
        """Min-max rescale ``saliency``, then normalise it to sum to one."""
        lowest = t.min(saliency)
        rescaled = (saliency - lowest) / (t.max(saliency) - lowest)
        return rescaled / t.sum(rescaled)

    overlap = t.min(_to_distribution(gt_map), _to_distribution(pred_map))
    return t.sum(overlap)
| |
| |
def loss_NSS(pred_map, fix_map):
    """Normalized Scanpath Saliency; ground truth here is a fixation map.

    The prediction is standardised with its own mean and ``t.std`` (called
    with default arguments), and the score is the mean of the standardised
    values at positions where ``fix_map`` is greater than zero.

    Args:
        pred_map: Predicted saliency map (floating point, any shape).
        fix_map: Fixation map; entries ``> 0`` mark fixated positions.

    Returns:
        A 0-dim tensor with the NSS score.

    Note:
        The score is NaN when ``fix_map`` has no positive entry (mean of
        an empty selection) or when ``pred_map`` has zero standard
        deviation.
    """
    standardized = (pred_map - pred_map.mean()) / pred_map.std()
    fixated = fix_map > 0
    return t.masked_select(standardized, fixated).mean()
|
|