| | import torch
|
| | import torch.nn as nn
|
| | import torch.nn.functional as F
|
| |
|
def infoNCE_loss1(mol_features, ms_features, temperature=0.1, norm=True):
    """Symmetric InfoNCE loss between paired molecule and spectrum embeddings.

    Row i of ``mol_features`` and row i of ``ms_features`` form the positive
    pair; every other row in the batch serves as an in-batch negative.

    Args:
        mol_features: (B, D) molecule embeddings.
        ms_features: (B, D) mass-spectrum embeddings.
        temperature: softmax temperature dividing the similarity logits.
        norm: if True, L2-normalize both feature sets along dim 1 first.

    Returns:
        Scalar tensor: mean of the mol->ms and ms->mol cross-entropy terms.
    """
    if norm:
        mol_features = F.normalize(mol_features, p=2, dim=1)
        ms_features = F.normalize(ms_features, p=2, dim=1)

    # Pairwise similarity matrix, scaled by temperature.
    scores = torch.mm(mol_features, ms_features.T) / temperature

    # Matching pairs sit on the diagonal, so the target class for row i is i.
    targets = torch.arange(scores.size(0), device=mol_features.device)

    # Average both retrieval directions (transpose flips mol<->ms).
    forward_loss = F.cross_entropy(scores, targets)
    backward_loss = F.cross_entropy(scores.T, targets)
    return (forward_loss + backward_loss) / 2
|
| |
|
def infoNCE_loss2(mol_features, ms_features, temperature=0.1, alpha=0.75, norm=True):
    """Alpha-weighted bidirectional InfoNCE loss.

    Like the symmetric variant, but the two retrieval directions are blended
    with weight ``alpha`` instead of a fixed 50/50 average:
    ``alpha * loss(mol->ms) + (1 - alpha) * loss(ms->mol)``.

    Args:
        mol_features: (B, D) molecule embeddings.
        ms_features: (B, D) mass-spectrum embeddings.
        temperature: softmax temperature dividing the similarity logits.
        alpha: weight given to the mol->ms direction.
        norm: if True, L2-normalize both feature sets along dim 1 first.

    Returns:
        Scalar tensor: weighted sum of the two directional losses.
    """
    if norm:
        mol_features = F.normalize(mol_features, p=2, dim=1)
        ms_features = F.normalize(ms_features, p=2, dim=1)

    n = mol_features.size(0)
    targets = torch.arange(n, device=mol_features.device)

    # Explicit similarity matrix per direction (each row i's positive is i).
    sim_ab = torch.matmul(mol_features, ms_features.T) / temperature
    sim_ba = torch.matmul(ms_features, mol_features.T) / temperature

    loss_ab = F.cross_entropy(sim_ab, targets)
    loss_ba = F.cross_entropy(sim_ba, targets)
    return alpha * loss_ab + (1 - alpha) * loss_ba
|
| |
|
| |
|
def contrastive_loss_with_hard_negatives(features1, features2, margin=1.0, hard_negative_ratio=0.3):
    """Contrastive loss with hard-negative mining.

    Pulls each positive pair (row i of both inputs) toward similarity 1 and
    pushes the top-k most similar negatives in the batch below ``margin``.
    Requires ``batch_size >= 2`` so at least one negative exists per row.

    Args:
        features1: (B, D) embeddings (assumed roughly unit-norm — the
            positive term is ``1 - sim``; verify against callers).
        features2: (B, D) embeddings paired row-wise with ``features1``.
        margin: negatives with similarity above this are penalized.
        hard_negative_ratio: fraction of the batch mined as hard negatives.

    Returns:
        Scalar tensor: per-sample (positive + hard-negative) loss, averaged
        over the batch.
    """
    batch_size = features1.shape[0]

    similarity = torch.matmul(features1, features2.t())
    positive_similarity = torch.diag(similarity)

    # Off-diagonal entries are negatives. Build the mask on the same device
    # as the features: a CPU mask would raise on CUDA inputs.
    mask = ~torch.eye(batch_size, dtype=torch.bool, device=similarity.device)
    negative_similarities = similarity[mask].view(batch_size, batch_size - 1)

    # Clamp k into [1, batch_size - 1]: k == 0 (small batch / small ratio)
    # would take the mean of an empty tensor -> NaN, and k > batch_size - 1
    # (large ratio) would make topk raise.
    k = int(batch_size * hard_negative_ratio)
    k = max(1, min(k, batch_size - 1))
    hard_negatives, _ = torch.topk(negative_similarities, k=k, dim=1)

    # Vectorized equivalent of the original per-sample loop:
    # positives toward similarity 1, hard negatives pushed below the margin.
    pos_loss = 1 - positive_similarity
    neg_loss = torch.clamp(hard_negatives - margin, min=0).mean(dim=1)
    return (pos_loss + neg_loss).mean()