from torch import Tensor
import torch.distributed as dist
import torch
import torch.nn.functional as F
import os  # added


class SimpleContrastiveLoss:
    """In-batch InfoNCE contrastive loss with optional dual-view inputs.

    Shapes supported by ``__call__``:
      * x=[B, D],    y=[B, D]    -> single-direction CE (InfoNCE)
      * x=[B, 2, D], y=[B, D]    -> each query view scored against y, alpha-weighted
      * x=[B, D],    y=[B, 2, D] -> x scored against each passage view, alpha-weighted
      * x=[B, 2, D], y=[B, 2, D] -> matched views (0<->0, 1<->1), alpha-weighted

    View 0 (e.g. layer 20) is weighted by ``alpha``; view 1 (e.g. the last
    layer) by ``1 - alpha``.
    """

    def __init__(self, temperature: float = 0.02, alpha: float = 0.05,
                 # whether to also compute the reverse (p -> q) direction; off by default
                 symmetric: bool = False,
                 # label smoothing, mitigates the impact of false negatives
                 label_smoothing: float = 0.0):
        self.temperature = temperature
        self.alpha = alpha  # weight of view 0 (layer 20); view 1 (last layer) gets 1 - alpha
        self.symmetric = symmetric
        self.label_smoothing = label_smoothing

    def _make_labels(self, n_rows: int, n_cols: int, device) -> Tensor:
        """Build target indices, supporting n_cols != n_rows (e.g. hard
        negatives packed after each positive): 0, factor, 2*factor, ..."""
        factor = max(1, n_cols // n_rows)
        return torch.arange(0, n_rows * factor, factor, device=device, dtype=torch.long)

    def _ce(self, logits: Tensor, labels: Tensor, reduction: str = 'mean') -> Tensor:
        """Single cross-entropy pass with the configured label smoothing."""
        return F.cross_entropy(logits, labels, reduction=reduction,
                               label_smoothing=self.label_smoothing)

    def _pair_views(self, x: Tensor, y: Tensor):
        """Normalize the four supported shape combinations into a list of
        matched (query, passage) 2-D view pairs plus their loss weights.

        Raises:
            ValueError: if neither tensor matches a supported layout.
        """
        w0, w1 = self.alpha, 1.0 - self.alpha
        if x.dim() == 2 and y.dim() == 2:
            return [(x, y)], [1.0]
        if x.dim() == 3 and y.dim() == 2:
            assert x.size(1) == 2, f"Expect x=[B,2,D], got {tuple(x.size())}"
            return [(x[:, 0, :], y), (x[:, 1, :], y)], [w0, w1]
        if x.dim() == 2 and y.dim() == 3:
            assert y.size(1) == 2, f"Expect y=[B,2,D], got {tuple(y.size())}"
            return [(x, y[:, 0, :]), (x, y[:, 1, :])], [w0, w1]
        if x.dim() == 3 and y.dim() == 3:
            assert x.size(1) == y.size(1) == 2, \
                f"Expect x,y=[B,2,D], got {tuple(x.size())}, {tuple(y.size())}"
            return [(x[:, 0, :], y[:, 0, :]), (x[:, 1, :], y[:, 1, :])], [w0, w1]
        raise ValueError(f"Unsupported shapes: x {tuple(x.size())}, y {tuple(y.size())}")

    def __call__(self, x: Tensor, y: Tensor, target: Tensor = None,
                 reduction: str = 'mean') -> Tensor:
        """Compute the (optionally symmetric, optionally dual-view) loss.

        Args:
            x: queries, [B, D] or [B, 2, D].
            y: passages, [B', D] or [B', 2, D] (B' may be a multiple of B).
            target: optional explicit labels; built via ``_make_labels`` if None.
            reduction: passed through to ``F.cross_entropy``.

        Returns:
            Scalar loss (or per-sample losses when reduction='none').
        """
        n_q, n_p = x.size(0), y.size(0)
        pairs, weights = self._pair_views(x, y)

        if target is None:
            target = self._make_labels(n_q, n_p, x.device)

        # Forward direction (q -> p), weighted sum over views.
        loss = sum(
            w * self._ce((q @ p.transpose(0, 1)) / self.temperature, target, reduction=reduction)
            for (q, p), w in zip(pairs, weights)
        )

        if self.symmetric:
            # Reverse direction (p -> q); average the two directions.
            labels_rev = self._make_labels(n_p, n_q, x.device)
            loss_rev = sum(
                w * self._ce((p @ q.transpose(0, 1)) / self.temperature, labels_rev, reduction=reduction)
                for (q, p), w in zip(pairs, weights)
            )
            loss = 0.5 * (loss + loss_rev)
        return loss


class DistributedContrastiveLoss(SimpleContrastiveLoss):
    """Contrastive loss over the global batch: all-gathers x/y across ranks
    so every device contrasts against every other device's samples."""

    def __init__(self, n_target: int = 0, scale_loss: bool = True,
                 temperature: float = 0.02, alpha: float = 0.05,
                 # pass through symmetric / label_smoothing
                 symmetric: bool = False, label_smoothing: float = 0.0):
        assert dist.is_initialized(), "Distributed training has not been properly initialized."
        super().__init__(temperature=temperature, alpha=alpha,
                         symmetric=symmetric, label_smoothing=label_smoothing)
        # [fix] was misspelled "word_size"; alias kept for backward compatibility.
        self.world_size = dist.get_world_size()
        self.word_size = self.world_size
        self.rank = dist.get_rank()
        self.scale_loss = scale_loss

    def __call__(self, x: Tensor, y: Tensor, **kwargs):
        dist_x = self.gather_tensor(x)
        dist_y = self.gather_tensor(y)
        loss = super().__call__(dist_x, dist_y, **kwargs)
        if self.scale_loss:
            # NOTE(review): presumably compensates for gradient averaging in
            # DDP's all-reduce — confirm against the training loop.
            loss = loss * self.world_size
        return loss

    def gather_tensor(self, t: Tensor) -> Tensor:
        """All-gather ``t`` across ranks and concatenate along dim 0."""
        gathered = [torch.empty_like(t) for _ in range(self.world_size)]
        dist.all_gather(gathered, t)
        # all_gather returns detached tensors; re-insert the local shard so
        # its autograd graph (and gradient) is preserved.
        gathered[self.rank] = t
        return torch.cat(gathered, dim=0)


class InExampleContrastiveLoss:
    """Per-example contrastive loss (unchanged contract).

    x.shape=[bsz, hdim], y.shape=[bsz, num_label, hdim]; the positive label
    is at index 0 of each example's candidate set.
    """

    def __init__(self, n_hard_negatives: int = 0, temperature: float = 1.0,
                 ndim: int = None, *args, **kwargs):
        self.target_per_qry = n_hard_negatives + 1
        self.temperature = temperature
        self.ndim = ndim  # optional truncation of the embedding dimension

    @staticmethod
    def _dist_gather(t: Tensor) -> Tensor:
        # [fix] original called ``dist_utils.dist_gather`` but ``dist_utils``
        # was never imported (NameError at runtime). This reimplements the
        # same grad-preserving all-gather used by DistributedContrastiveLoss;
        # verify it matches the semantics of the original dist_utils helper.
        gathered = [torch.empty_like(t) for _ in range(dist.get_world_size())]
        dist.all_gather(gathered, t)
        gathered[dist.get_rank()] = t  # keep local shard's autograd graph
        return torch.cat(gathered, dim=0)

    def __call__(self, x: Tensor, y: Tensor, reduction: str = 'mean'):
        if torch.distributed.is_initialized():
            x = self._dist_gather(x)
            y = self._dist_gather(y)
        bsz, ndim = x.size(0), x.size(1)
        # Positive candidate sits at index 0 for every example.
        target = torch.zeros(bsz, dtype=torch.long, device=x.device)
        if self.ndim:
            ndim = self.ndim
            x = x[:, :ndim]
            # [fix] was ``y[:, :ndim]``, which truncated the num_label axis of
            # the [bsz, num_label, hdim] tensor instead of the hidden dim and
            # then scrambled data in the .view() below.
            y = y[..., :ndim]
        # NOTE(review): logits are *multiplied* by temperature here, unlike the
        # division in SimpleContrastiveLoss — likely intentional (default 1.0),
        # but confirm before unifying.
        logits = torch.einsum('bod,bsd->bs',
                              x.view(bsz, 1, ndim),
                              y.view(bsz, -1, ndim)) * self.temperature
        preds = torch.argmax(logits, dim=-1)
        loss = F.cross_entropy(logits, target, reduction=reduction)
        loss_detail = {"logits": logits, "labels": target, "preds": preds}
        return loss, loss_detail