| from abc import ABC, abstractmethod |
| import torch |
| from torch import nn |
| from singleVis.backend import compute_cross_entropy_tf, convert_distance_to_probability, compute_cross_entropy |
|
|
| import torch |
| torch.manual_seed(0) |
| torch.cuda.manual_seed_all(0) |
| |
|
|
| """Losses modules for preserving four propertes""" |
| |
|
|
class Loss(nn.Module, ABC):
    """Abstract base class for the loss modules in this file.

    Subclasses must implement ``forward``.  Inheriting from ``ABC`` (in
    addition to ``nn.Module``) makes ``@abstractmethod`` actually enforced:
    the original subclassed only ``nn.Module``, so ``Loss()`` could be
    instantiated despite the abstract method.
    """

    def __init__(self) -> None:
        super().__init__()

    @abstractmethod
    def forward(self, *args, **kwargs):
        pass
|
|
class UmapLoss(nn.Module):
    """UMAP cross-entropy loss between graph membership probabilities and
    low-dimensional embedding distances.

    Parameters
    ----------
    negative_sample_rate : int
        number of negative samples drawn per positive edge
    device : str or torch.device
        device the probability tensors are moved to
    _a, _b : float
        curve parameters of the distance-to-probability conversion
    repulsion_strength : float
        weight of the repulsive (negative-sample) term
    """

    def __init__(self, negative_sample_rate, device, _a=1.0, _b=1.0, repulsion_strength=1.0):
        super(UmapLoss, self).__init__()
        self._negative_sample_rate = negative_sample_rate
        # Fix: the original wrote `self._a = _a,` (trailing comma), storing a
        # 1-tuple that the `a`/`b` properties then had to unwrap with `[0]`.
        self._a = _a
        self._b = _b
        self._repulsion_strength = repulsion_strength
        self.DEVICE = torch.device(device)

    @property
    def a(self):
        return self._a

    @property
    def b(self):
        return self._b

    def forward(self, embedding_to, embedding_from):
        """Compute the UMAP cross-entropy loss for a batch of positive edges.

        Negative pairs are built by repeating each sample
        ``negative_sample_rate`` times and shuffling the `from` side.
        """
        batch_size = embedding_to.shape[0]

        # Negative samples: repeat each head, permute the tails.
        embedding_neg_to = torch.repeat_interleave(embedding_to, self._negative_sample_rate, dim=0)
        repeat_neg = torch.repeat_interleave(embedding_from, self._negative_sample_rate, dim=0)
        randperm = torch.randperm(repeat_neg.shape[0])
        embedding_neg_from = repeat_neg[randperm]

        # Euclidean distances for positive pairs followed by negative pairs.
        distance_embedding = torch.cat(
            (
                torch.norm(embedding_to - embedding_from, dim=1),
                torch.norm(embedding_neg_to - embedding_neg_from, dim=1),
            ),
            dim=0,
        )
        probabilities_distance = convert_distance_to_probability(
            distance_embedding, self.a, self.b
        )
        probabilities_distance = probabilities_distance.to(self.DEVICE)

        # Target probabilities: 1 for positive edges, 0 for negative samples.
        probabilities_graph = torch.cat(
            (torch.ones(batch_size), torch.zeros(batch_size * self._negative_sample_rate)), dim=0,
        )
        probabilities_graph = probabilities_graph.to(device=self.DEVICE)

        (_, _, ce_loss) = compute_cross_entropy(
            probabilities_graph,
            probabilities_distance,
            repulsion_strength=self._repulsion_strength,
        )

        return torch.mean(ce_loss)
|
|
|
|
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
class ReconstructionLoss(nn.Module):
    """Attention-weighted MSE between input edges and their reconstructions.

    Each squared error is scaled by ``(1 + attention) ** beta`` before
    averaging; the per-endpoint losses are then averaged together.
    """

    def __init__(self, beta=1.0, alpha=0.5):
        super(ReconstructionLoss, self).__init__()
        self._beta = beta
        self._alpha = alpha

    def forward(self, edge_to, edge_from, recon_to, recon_from, a_to, a_from):
        def _weighted_mse(target, recon, attention):
            # Scale each element's squared error by its attention weight.
            weight = torch.pow(1 + attention, self._beta)
            squared_err = torch.pow(target - recon, 2)
            return torch.mean(torch.mean(weight * squared_err, 1))

        loss_to = _weighted_mse(edge_to, recon_to, a_to)
        loss_from = _weighted_mse(edge_from, recon_from, a_from)
        return (loss_to + loss_from) / 2
|
|
|
|
|
|
class SmoothnessLoss(nn.Module):
    """Hinge penalty on how far each embedding moved from its target,
    weighted per sample by ``Coefficient``."""

    def __init__(self, margin=0.0):
        super(SmoothnessLoss, self).__init__()
        self._margin = margin

    def forward(self, embedding, target, Coefficient):
        # Per-row displacement; anything within the margin costs nothing.
        displacement = torch.norm(embedding - target, dim=1)
        hinge = torch.clamp(displacement - self._margin, min=0)
        return torch.mean(Coefficient * hinge)
|
|
|
|
class SingleVisLoss(nn.Module):
    """Total SingleVis objective: umap loss plus lambd-weighted reconstruction.

    ``forward`` returns the two components and their weighted sum.
    """

    def __init__(self, umap_loss, recon_loss, lambd):
        super(SingleVisLoss, self).__init__()
        self.umap_loss = umap_loss
        self.recon_loss = recon_loss
        self.lambd = lambd

    def forward(self, edge_to, edge_from, a_to, a_from, outputs):
        emb_to, emb_from = outputs["umap"]
        rec_to, rec_from = outputs["recon"]

        recon_l = self.recon_loss(edge_to, edge_from, rec_to, rec_from, a_to, a_from)
        umap_l = self.umap_loss(emb_to, emb_from)
        total = umap_l + self.lambd * recon_l
        return umap_l, recon_l, total
|
|
class HybridLoss(nn.Module):
    """Combined objective: umap + lambd1*reconstruction + lambd2*smoothness.

    ``forward`` returns the three unweighted components and the weighted total.
    """

    def __init__(self, umap_loss, recon_loss, smooth_loss, lambd1, lambd2):
        super(HybridLoss, self).__init__()
        self.umap_loss = umap_loss
        self.recon_loss = recon_loss
        self.smooth_loss = smooth_loss
        self.lambd1 = lambd1
        self.lambd2 = lambd2

    def forward(self, edge_to, edge_from, a_to, a_from, embeded_to, coeff, outputs):
        emb_to, emb_from = outputs["umap"]
        rec_to, rec_from = outputs["recon"]

        recon_l = self.recon_loss(edge_to, edge_from, rec_to, rec_from, a_to, a_from)
        umap_l = self.umap_loss(emb_to, emb_from)
        # Smoothness anchors the current embedding to its previous position.
        smooth_l = self.smooth_loss(emb_to, embeded_to, coeff)

        total = umap_l + self.lambd1 * recon_l + self.lambd2 * smooth_l
        return umap_l, recon_l, smooth_l, total
|
|
|
|
class TemporalLoss(nn.Module):
    """L2 penalty between a module's current parameters and a stored snapshot.

    Parameters
    ----------
    prev_w : dict
        mapping from parameter name to the previous parameter tensor;
        must contain an entry for every named parameter of the module
        passed to ``forward``
    device : str or torch.device
    """

    def __init__(self, prev_w, device) -> None:
        super(TemporalLoss, self).__init__()
        self.prev_w = prev_w
        self.device = device
        # Move the snapshot to the target device/dtype once, up front.
        for name in self.prev_w.keys():
            self.prev_w[name] = self.prev_w[name].to(device=self.device, dtype=torch.float32)

    def forward(self, curr_module):
        penalty = torch.tensor(0., requires_grad=True).to(self.device)
        for name, param in curr_module.named_parameters():
            snapshot = self.prev_w[name]
            penalty = penalty + torch.square(param - snapshot).sum()
        return penalty
|
|
|
|
class DummyTemporalLoss(nn.Module):
    """No-op temporal regularizer: always returns a zero loss.

    Used in place of :class:`TemporalLoss` when no previous weights exist.
    """

    def __init__(self, device) -> None:
        super(DummyTemporalLoss, self).__init__()
        self.device = device

    def forward(self, curr_module):
        # The module argument is accepted for interface parity but ignored.
        return torch.tensor(0., requires_grad=True).to(self.device)
| |
|
|
class PositionRecoverLoss(nn.Module):
    """Mean-squared error between a position and its recovered counterpart."""

    def __init__(self, device) -> None:
        super(PositionRecoverLoss, self).__init__()
        self.device = device
        # Build the criterion once here; the original constructed a fresh
        # nn.MSELoss on every forward call.
        self.mse_loss = nn.MSELoss().to(self.device)

    def forward(self, position, recover_position):
        return self.mse_loss(position, recover_position)
|
|
|
|
class DVILoss(nn.Module):
    """DVI objective: weighted umap + reconstruction + temporal terms.

    ``forward`` returns each *weighted* component followed by their sum.
    """

    def __init__(self, umap_loss, recon_loss, temporal_loss, lambd1, lambd2, device, umap_weight=1):
        super(DVILoss, self).__init__()
        self.umap_loss = umap_loss
        self.recon_loss = recon_loss
        self.temporal_loss = temporal_loss
        self.lambd1 = lambd1
        self.lambd2 = lambd2
        self.device = device
        self.umap_weight = umap_weight

    def forward(self, edge_to, edge_from, a_to, a_from, curr_model, outputs):
        emb_to, emb_from = outputs["umap"]
        rec_to, rec_from = outputs["recon"]

        recon_l = self.recon_loss(edge_to, edge_from, rec_to, rec_from, a_to, a_from).to(self.device)
        umap_l = self.umap_loss(emb_to, emb_from).to(self.device)
        temporal_l = self.temporal_loss(curr_model).to(self.device)

        weighted_umap = self.umap_weight * umap_l
        weighted_recon = self.lambd1 * recon_l
        weighted_temporal = self.lambd2 * temporal_l
        total = weighted_umap + weighted_recon + weighted_temporal

        return weighted_umap, weighted_recon, weighted_temporal, total
|
|
class MINE(nn.Module):
    """Mutual Information Neural Estimation critic.

    ``forward`` returns the *negative* Donsker–Varadhan lower bound
    ``-(E[T(x,y)] - log E[exp(T(x,y'))])`` so it can be minimized as a loss.
    """

    def __init__(self):
        super(MINE, self).__init__()
        # Small statistics network T over concatenated (x, y) pairs.
        self.network = nn.Sequential(
            nn.Linear(2, 100),
            nn.ReLU(),
            nn.Linear(100, 1),
        )

    def forward(self, x, y):
        joint_pairs = torch.cat((x, y), dim=1)
        # Shuffle y to approximate samples from the product of marginals.
        shuffled_y = y[torch.randperm(x.size(0))]
        marginal_pairs = torch.cat((x, shuffled_y), dim=1)

        t_joint = self.network(joint_pairs)
        t_marginal = self.network(marginal_pairs)

        mi_estimate = torch.mean(t_joint) - torch.log(torch.mean(torch.exp(t_marginal)))
        return -mi_estimate
|
|
|
|
class TVILoss(nn.Module):
    """TVI objective: umap + reconstruction + temporal + mutual-information.

    NOTE(review): the MI term is folded into the total loss but is *not*
    part of the returned tuple — callers receive only the umap, weighted
    reconstruction, and weighted temporal components plus the total.
    """

    def __init__(self, umap_loss, recon_loss, temporal_loss, MI_loss, lambd1, lambd2, lambd3, device):
        super(TVILoss, self).__init__()
        self.umap_loss = umap_loss
        self.recon_loss = recon_loss
        self.temporal_loss = temporal_loss
        self.MI_loss = MI_loss
        self.lambd1 = lambd1
        self.lambd2 = lambd2
        self.lambd3 = lambd3
        self.device = device

    def forward(self, edge_to, edge_from, a_to, a_from, curr_model, outputs):
        emb_to, emb_from = outputs["umap"]
        rec_to, rec_from = outputs["recon"]

        recon_l = self.recon_loss(edge_to, edge_from, rec_to, rec_from, a_to, a_from).to(self.device)
        umap_l = self.umap_loss(emb_to, emb_from).to(self.device)
        temporal_l = self.temporal_loss(curr_model).to(self.device)

        # MI estimated in both embedding space and input space, then averaged.
        mi_embedding = self.MI_loss(emb_to, emb_from).to(self.device)
        mi_edge = self.MI_loss(edge_to, edge_from).to(self.device)
        mi_l = (mi_embedding + mi_edge) / 2

        total = umap_l + self.lambd1 * recon_l + self.lambd2 * temporal_l + self.lambd3 * mi_l
        return umap_l, self.lambd1 * recon_l, self.lambd2 * temporal_l, total
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
|
|
| |
| |
| |
|
|
| |
| |
| |
| |
|
|
| |
|
|
| |
|
|
|
|
|
|
| import tensorflow as tf |
def umap_loss(
    batch_size,
    negative_sample_rate,
    _a,
    _b,
    repulsion_strength=1.0,
):
    """
    Generate a keras-compatible loss function for UMAP loss

    Parameters
    ----------
    batch_size : int
        size of mini-batches
    negative_sample_rate : int
        number of negative samples per positive samples to train on
    _a : float
        distance parameter in embedding space
    _b : float
        distance parameter in embedding space
    repulsion_strength : float, optional
        strength of repulsion vs attraction for cross-entropy, by default 1.0

    Returns
    -------
    loss : function
        loss function that takes in a placeholder (0) and the output of the keras network
    """

    @tf.function
    def loss(placeholder_y, embed_to_from):
        # Network output packs [embedding_to (2 dims), embedding_from (2 dims),
        # edge weight (1 dim)] along axis 1.
        embedding_to, embedding_from, weights = tf.split(
            embed_to_from, num_or_size_splits=[2, 2, 1], axis=1
        )

        # Build negative pairs: repeat each head, shuffle the tails.
        embedding_neg_to = tf.repeat(embedding_to, negative_sample_rate, axis=0)
        repeat_neg = tf.repeat(embedding_from, negative_sample_rate, axis=0)
        embedding_neg_from = tf.gather(
            repeat_neg, tf.random.shuffle(tf.range(tf.shape(repeat_neg)[0]))
        )

        # Euclidean distances for positive pairs followed by negative pairs.
        distance_embedding = tf.concat(
            (
                tf.norm(embedding_to - embedding_from, axis=1),
                tf.norm(embedding_neg_to - embedding_neg_from, axis=1),
            ),
            axis=0,
        )

        # Convert distance to membership probability: 1 / (1 + a * d^(2b)).
        probabilities_distance = 1.0 / (1.0 + _a * tf.math.pow(distance_embedding, 2 * _b))

        # Target probabilities: 1 for positive edges, 0 for negative samples.
        probabilities_graph = tf.concat(
            (tf.ones(batch_size), tf.zeros(batch_size * negative_sample_rate)), axis=0,
        )
        # NOTE(review): `probabilities` (edge weights padded with zeros) is
        # computed but never used below — presumably intended to replace
        # `probabilities_graph` for weighted edges; confirm intent.
        probabilities = tf.concat(
            (tf.squeeze(weights), tf.zeros(batch_size * negative_sample_rate)), axis=0,
        )

        # Cross-entropy between graph probabilities and embedding probabilities.
        (attraction_loss, repellant_loss, ce_loss) = compute_cross_entropy_tf(
            probabilities_graph,
            probabilities_distance,
            repulsion_strength=repulsion_strength,
        )

        return tf.reduce_mean(ce_loss)

    return loss
|
|
| |
def regularize_loss():
    '''
    Add temporal regularization L2 loss on weights

    Returns a @tf.function taking (w_prev, w_current, to_alpha) that sums the
    alpha-scaled squared differences between corresponding weight tensors.
    '''

    @tf.function
    def loss(w_prev, w_current, to_alpha):
        # The two weight lists must pair up one-to-one.
        assert len(w_prev) == len(w_current)
        # NOTE(review): if the weight lists are empty, `alldiff` is never
        # bound and this raises — confirm callers always pass non-empty lists.
        for j in range(len(w_prev)):
            # Alpha-weighted squared L2 difference for layer j.
            diff = tf.reduce_sum(tf.math.square(w_current[j] - w_prev[j]))
            diff = tf.math.multiply(to_alpha, diff)
            if j == 0:
                alldiff = tf.reduce_mean(diff)
            else:
                alldiff += tf.reduce_mean(diff)
        return alldiff

    return loss
|
|
def reconstruction_loss(
    beta=1
):
    """
    Generate a keras-compatible loss function for customized reconstruction loss

    Each squared reconstruction error is scaled by (1 + alpha)**beta before
    averaging; the to/from losses are then averaged together.

    Parameters
    ----------
    beta: hyperparameter
    Returns
    -------
    loss : function
    """

    @tf.function
    def loss(edge_to, edge_from, recon_to, recon_from, alpha_to, alpha_from):
        # Attention-weighted MSE over each endpoint, averaged over features
        # (axis 1) then over the batch.
        loss1 = tf.reduce_mean(tf.reduce_mean(tf.math.multiply(tf.math.pow((1+alpha_to), beta), tf.math.pow(edge_to - recon_to, 2)), 1))
        loss2 = tf.reduce_mean(tf.reduce_mean(tf.math.multiply(tf.math.pow((1+alpha_from), beta), tf.math.pow(edge_from - recon_from, 2)), 1))
        return (loss1 + loss2)/2

    return loss