import torch
import torch.distributed.nn
from torch import distributed as dist, nn as nn
from torch.nn import functional as F
import numpy as np
from sklearn.metrics import average_precision_score, roc_auc_score, accuracy_score

try:
    import horovod.torch as hvd
except ImportError:
    hvd = None


def gather_features(
    audio_features,
    text_features,
    audio_features_mlp=None,
    text_features_mlp=None,
    local_loss=False,
    gather_with_grad=False,
    rank=0,
    world_size=1,
    use_horovod=False,
    mlp_loss=False,
):
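    """Gather audio/text features (and, optionally, their MLP-projected variants)
    from every worker via Horovod or torch.distributed.

    When the gather is done without gradients and ``local_loss`` is False, each
    rank's own tensors are spliced back into the gathered result so that gradients
    still flow to the local features.
    """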
    if use_horovod:
        assert hvd is not None, "Please install horovod"
        if gather_with_grad:
            all_audio_features = hvd.allgather(audio_features)
            all_text_features = hvd.allgather(text_features)
            if mlp_loss:
                all_audio_features_mlp = hvd.allgather(audio_features_mlp)
                all_text_features_mlp = hvd.allgather(text_features_mlp)
        else:
            with torch.no_grad():
                all_audio_features = hvd.allgather(audio_features)
                all_text_features = hvd.allgather(text_features)
                if mlp_loss:
                    all_audio_features_mlp = hvd.allgather(audio_features_mlp)
                    all_text_features_mlp = hvd.allgather(text_features_mlp)
            if not local_loss:
                # gathered under no_grad, so splice the local (grad-carrying)
                # tensors back into this rank's slot
                gathered_audio_features = list(
                    all_audio_features.chunk(world_size, dim=0)
                )
                gathered_text_features = list(
                    all_text_features.chunk(world_size, dim=0)
                )
                gathered_audio_features[rank] = audio_features
                gathered_text_features[rank] = text_features
                all_audio_features = torch.cat(gathered_audio_features, dim=0)
                all_text_features = torch.cat(gathered_text_features, dim=0)
                if mlp_loss:
                    gathered_audio_features_mlp = list(
                        all_audio_features_mlp.chunk(world_size, dim=0)
                    )
                    gathered_text_features_mlp = list(
                        all_text_features_mlp.chunk(world_size, dim=0)
                    )
                    gathered_audio_features_mlp[rank] = audio_features_mlp
                    gathered_text_features_mlp[rank] = text_features_mlp
                    all_audio_features_mlp = torch.cat(
                        gathered_audio_features_mlp, dim=0
                    )
                    all_text_features_mlp = torch.cat(
                        gathered_text_features_mlp, dim=0
                    )
    else:
        # torch.distributed path
        if gather_with_grad:
            all_audio_features = torch.cat(
                torch.distributed.nn.all_gather(audio_features), dim=0
            )
            all_text_features = torch.cat(
                torch.distributed.nn.all_gather(text_features), dim=0
            )
            if mlp_loss:
                all_audio_features_mlp = torch.cat(
                    torch.distributed.nn.all_gather(audio_features_mlp), dim=0
                )
                all_text_features_mlp = torch.cat(
                    torch.distributed.nn.all_gather(text_features_mlp), dim=0
                )
        else:
            gathered_audio_features = [
                torch.zeros_like(audio_features) for _ in range(world_size)
            ]
            gathered_text_features = [
                torch.zeros_like(text_features) for _ in range(world_size)
            ]
            dist.all_gather(gathered_audio_features, audio_features)
            dist.all_gather(gathered_text_features, text_features)
            if mlp_loss:
                gathered_audio_features_mlp = [
                    torch.zeros_like(audio_features_mlp) for _ in range(world_size)
                ]
                gathered_text_features_mlp = [
                    torch.zeros_like(text_features_mlp) for _ in range(world_size)
                ]
                dist.all_gather(gathered_audio_features_mlp, audio_features_mlp)
                dist.all_gather(gathered_text_features_mlp, text_features_mlp)
            if not local_loss:
                # dist.all_gather does not carry gradients, so substitute the
                # local tensors back in to keep a gradient path for this rank
                gathered_audio_features[rank] = audio_features
                gathered_text_features[rank] = text_features
                if mlp_loss:
                    gathered_audio_features_mlp[rank] = audio_features_mlp
                    gathered_text_features_mlp[rank] = text_features_mlp

            all_audio_features = torch.cat(gathered_audio_features, dim=0)
            all_text_features = torch.cat(gathered_text_features, dim=0)
            if mlp_loss:
                all_audio_features_mlp = torch.cat(gathered_audio_features_mlp, dim=0)
                all_text_features_mlp = torch.cat(gathered_text_features_mlp, dim=0)

    if mlp_loss:
        return (
            all_audio_features,
            all_text_features,
            all_audio_features_mlp,
            all_text_features_mlp,
        )
    else:
        return all_audio_features, all_text_features


class ClipLoss(nn.Module):
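    """Contrastive (CLIP-style) loss between audio and text embeddings.

    Supports distributed training (torch.distributed or Horovod), an optional
    MLP-projected two-scale variant (``mlp_loss``), and an optional sample
    re-weighting scheme controlled by ``weight_loss_kappa``.
    """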
    def __init__(
        self,
        local_loss=False,
        gather_with_grad=False,
        cache_labels=False,
        rank=0,
        world_size=1,
        use_horovod=False,
        mlp_loss=False,
        weight_loss_kappa=0,
    ):
        super().__init__()
        self.local_loss = local_loss
        self.gather_with_grad = gather_with_grad
        self.cache_labels = cache_labels
        self.rank = rank
        self.world_size = world_size
        self.use_horovod = use_horovod
        self.mlp_loss = mlp_loss
        self.weighted_loss = bool(weight_loss_kappa != 0)
        self.weight_loss_kappa = weight_loss_kappa

        self.prev_num_logits = 0
        self.labels = {}

    def forward(
        self,
        audio_features,
        text_features,
        logit_scale_a,
        logit_scale_t=None,
        audio_features_mlp=None,
        text_features_mlp=None,
    ):
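        """Compute the contrastive loss for a batch of audio/text features.

        ``logit_scale_t``, ``audio_features_mlp`` and ``text_features_mlp`` are
        only used when the module was constructed with ``mlp_loss=True``.
        """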
        device = audio_features.device
        if self.mlp_loss:
            if self.world_size > 1:
                (
                    all_audio_features,
                    all_text_features,
                    all_audio_features_mlp,
                    all_text_features_mlp,
                ) = gather_features(
                    audio_features=audio_features,
                    text_features=text_features,
                    audio_features_mlp=audio_features_mlp,
                    text_features_mlp=text_features_mlp,
                    local_loss=self.local_loss,
                    gather_with_grad=self.gather_with_grad,
                    rank=self.rank,
                    world_size=self.world_size,
                    use_horovod=self.use_horovod,
                    mlp_loss=self.mlp_loss,
                )
                if self.local_loss:
                    a_logits_per_audio = (
                        logit_scale_a * audio_features @ all_text_features_mlp.T
                    )
                    a_logits_per_text = (
                        logit_scale_a * text_features_mlp @ all_audio_features.T
                    )
                    t_logits_per_audio = (
                        logit_scale_t * audio_features_mlp @ all_text_features.T
                    )
                    t_logits_per_text = (
                        logit_scale_t * text_features @ all_audio_features_mlp.T
                    )
                else:
                    a_logits_per_audio = (
                        logit_scale_a * all_audio_features @ all_text_features_mlp.T
                    )
                    a_logits_per_text = a_logits_per_audio.T
                    t_logits_per_audio = (
                        logit_scale_t * all_audio_features_mlp @ all_text_features.T
                    )
                    t_logits_per_text = t_logits_per_audio.T
            else:
                a_logits_per_audio = (
                    logit_scale_a * audio_features @ text_features_mlp.T
                )
                a_logits_per_text = (
                    logit_scale_a * text_features_mlp @ audio_features.T
                )
                t_logits_per_audio = (
                    logit_scale_t * audio_features_mlp @ text_features.T
                )
                t_logits_per_text = (
                    logit_scale_t * text_features @ audio_features_mlp.T
                )

            # calculate ground-truth labels and cache them if enabled
            num_logits = a_logits_per_audio.shape[0]
            if self.prev_num_logits != num_logits or device not in self.labels:
                labels = torch.arange(num_logits, device=device, dtype=torch.long)
                if self.world_size > 1 and self.local_loss:
                    labels = labels + num_logits * self.rank
                if self.cache_labels:
                    self.labels[device] = labels
                    self.prev_num_logits = num_logits
            else:
                labels = self.labels[device]

            if not self.weighted_loss:
                total_loss = (
                    F.cross_entropy(a_logits_per_audio, labels)
                    + F.cross_entropy(a_logits_per_text, labels)
                    + F.cross_entropy(t_logits_per_audio, labels)
                    + F.cross_entropy(t_logits_per_text, labels)
                ) / 4
            else:
                audio_weight = (audio_features @ audio_features.T).detach()
                audio_weight = (
                    torch.exp(
                        torch.sum(audio_weight, axis=1)
                        / (self.weight_loss_kappa * len(audio_weight))
                    )
                ).detach()
                text_weight = (text_features @ text_features.T).detach()
                text_weight = (
                    torch.exp(
                        torch.sum(text_weight, axis=1)
                        / (self.weight_loss_kappa * len(text_features))
                    )
                ).detach()
                total_loss = (
                    F.cross_entropy(a_logits_per_audio, labels, weight=audio_weight)
                    + F.cross_entropy(a_logits_per_text, labels, weight=audio_weight)
                    + F.cross_entropy(t_logits_per_audio, labels, weight=text_weight)
                    + F.cross_entropy(t_logits_per_text, labels, weight=text_weight)
                ) / 4
        else:
            if self.world_size > 1:
                all_audio_features, all_text_features = gather_features(
                    audio_features=audio_features,
                    text_features=text_features,
                    local_loss=self.local_loss,
                    gather_with_grad=self.gather_with_grad,
                    rank=self.rank,
                    world_size=self.world_size,
                    use_horovod=self.use_horovod,
                    mlp_loss=self.mlp_loss,
                )

                if self.local_loss:
                    logits_per_audio = (
                        logit_scale_a * audio_features @ all_text_features.T
                    )
                    logits_per_text = (
                        logit_scale_a * text_features @ all_audio_features.T
                    )
                else:
                    logits_per_audio = (
                        logit_scale_a * all_audio_features @ all_text_features.T
                    )
                    logits_per_text = logits_per_audio.T
            else:
                logits_per_audio = logit_scale_a * audio_features @ text_features.T
                logits_per_text = logit_scale_a * text_features @ audio_features.T
                # single-process fallback so the weighted-loss branch below has
                # "all_*" tensors to work with
                all_audio_features, all_text_features = audio_features, text_features

            # calculate ground-truth labels and cache them if enabled
            num_logits = logits_per_audio.shape[0]
            if self.prev_num_logits != num_logits or device not in self.labels:
                labels = torch.arange(num_logits, device=device, dtype=torch.long)
                if self.world_size > 1 and self.local_loss:
                    labels = labels + num_logits * self.rank
                if self.cache_labels:
                    self.labels[device] = labels
                    self.prev_num_logits = num_logits
            else:
                labels = self.labels[device]
            if not self.weighted_loss:
                total_loss = (
                    F.cross_entropy(logits_per_audio, labels)
                    + F.cross_entropy(logits_per_text, labels)
                ) / 2
            else:
                audio_weight = (all_audio_features @ all_audio_features.T).detach()
                audio_weight = (
                    torch.exp(
                        torch.sum(audio_weight, axis=1)
                        / (self.weight_loss_kappa * len(all_audio_features))
                    )
                ).detach()
                text_weight = (all_text_features @ all_text_features.T).detach()
                text_weight = (
                    torch.exp(
                        torch.sum(text_weight, axis=1)
                        / (self.weight_loss_kappa * len(all_text_features))
                    )
                ).detach()
                total_loss = (
                    F.cross_entropy(logits_per_audio, labels, weight=text_weight)
                    + F.cross_entropy(logits_per_text, labels, weight=audio_weight)
                ) / 2
        return total_loss


def lp_gather_features(pred, target, world_size=1, use_horovod=False):
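    """Gather linear-probe predictions and targets from all workers (no gradients)."""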
    if use_horovod:
        assert hvd is not None, "Please install horovod"
        with torch.no_grad():
            all_preds = hvd.allgather(pred)
            all_targets = hvd.allgather(target)
    else:
        gathered_preds = [torch.zeros_like(pred) for _ in range(world_size)]
        gathered_targets = [torch.zeros_like(target) for _ in range(world_size)]

        dist.all_gather(gathered_preds, pred)
        dist.all_gather(gathered_targets, target)
        all_preds = torch.cat(gathered_preds, dim=0)
        all_targets = torch.cat(gathered_targets, dim=0)

    return all_preds, all_targets


def get_map(pred, target):
    # mean average precision over classes (multi-label)
    pred = torch.sigmoid(pred).numpy()
    target = target.numpy()
    return np.mean(average_precision_score(target, pred, average=None))


def get_acc(pred, target):
    # top-1 accuracy for one-hot targets
    pred = torch.argmax(pred, 1).numpy()
    target = torch.argmax(target, 1).numpy()
    return accuracy_score(target, pred)


def get_mauc(pred, target):
    # mean ROC AUC over classes (multi-label)
    pred = torch.sigmoid(pred).numpy()
    target = target.numpy()
    return np.mean(roc_auc_score(target, pred, average=None))


class LPMetrics(object):
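    """Bundle of linear-probe evaluation metrics (mAP, accuracy, mean AUC)."""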
    def __init__(self, metric_names=["map", "acc", "mauc"]):
        self.metrics = []
        for name in metric_names:
            self.metrics.append(self.get_metric(name))
        self.metric_names = metric_names

    def get_metric(self, name):
        if name == "map":
            return get_map
        elif name == "acc":
            return get_acc
        elif name == "mauc":
            return get_mauc
        else:
            raise ValueError("the metric should be one of [map, acc, mauc]")

    def evaluate_mertics(self, pred, target):
        metric_dict = {}
        for i in range(len(self.metric_names)):
            metric_dict[self.metric_names[i]] = self.metrics[i](pred, target)
        return metric_dict


def calc_celoss(pred, target):
    # CrossEntropyLoss expects class indices, so convert one-hot targets first
    target = torch.argmax(target, 1).long()
    return nn.CrossEntropyLoss()(pred, target)


class LPLoss(nn.Module):
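    """Linear-probe training loss: BCE-with-logits, cross-entropy, or MSE."""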
    def __init__(self, loss_name):
        super().__init__()
        if loss_name == "bce":
            self.loss_func = nn.BCEWithLogitsLoss()
        elif loss_name == "ce":
            self.loss_func = calc_celoss
        elif loss_name == "mse":
            self.loss_func = nn.MSELoss()
        else:
            raise ValueError("the loss function should be one of [bce, ce, mse]")

    def forward(self, pred, target):
        loss = self.loss_func(pred, target)
        return loss
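

# Minimal usage sketch (illustrative only, not part of the library API): computes
# the single-process ClipLoss on a toy batch of unit-normalized random embeddings.
# The batch size, embedding dim, and logit scale below are arbitrary assumptions.
if __name__ == "__main__":
    torch.manual_seed(0)
    batch_size, embed_dim = 8, 512
    audio = F.normalize(torch.randn(batch_size, embed_dim), dim=-1)
    text = F.normalize(torch.randn(batch_size, embed_dim), dim=-1)
    logit_scale_a = torch.tensor(100.0)  # roughly exp(logit_scale) as used in CLIP/CLAP
    loss = ClipLoss()(audio, text, logit_scale_a)
    print(f"contrastive loss: {loss.item():.4f}")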