import torch
import torch.nn.functional as F
import torch.distributed as dist
from torch.utils.data import DataLoader
from transformers import Trainer

from src.utils import batch_to_device
from src.classifier_utils import HomogeneousBatchSampler


class EarlyExitTrainer(Trainer):
    def __init__(self, backbone_model, target_layer_idx, model_args, *args, **kwargs):
        self.max_length = kwargs.pop('max_length', 512)
        super().__init__(*args, **kwargs)
        self.backbone = backbone_model.to(self.args.device)
        self.backbone.eval()
        self.target_layer_idx = target_layer_idx
        self.model_args = model_args
        # Flag so the gradient probe below only runs during the first few steps.
        self._grad_check_done = False

    def get_train_dataloader(self) -> DataLoader:
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")
        train_sampler = HomogeneousBatchSampler(
            self.train_dataset,
            batch_size=self._train_batch_size,
            drop_last=self.args.dataloader_drop_last,
        )
        return DataLoader(
            self.train_dataset,
            batch_sampler=train_sampler,
            collate_fn=self.data_collator,
            num_workers=self.args.dataloader_num_workers,
            pin_memory=self.args.dataloader_pin_memory,
        )

    def create_optimizer(self):
        if self.optimizer is None:
            print(f"\n[Debug Rank {self.args.local_rank}] Creating Optimizer...")
            decay_parameters = []
            no_decay_parameters = []
            trainable_count = 0
            # Note: self.model may be DDP-wrapped, so iterate over
            # self.model.named_parameters() rather than the inner module's.
            for name, param in self.model.named_parameters():
                if not param.requires_grad:
                    continue
                trainable_count += 1
                if "bias" in name or "LayerNorm" in name or "BatchNorm" in name:
                    no_decay_parameters.append(param)
                else:
                    decay_parameters.append(param)
            print(f"[Debug] Found {trainable_count} trainable parameters.")
            self.optimizer = torch.optim.AdamW(
                [
                    {"params": decay_parameters, "weight_decay": self.args.weight_decay},
                    {"params": no_decay_parameters, "weight_decay": 0.0},
                ],
                lr=self.args.learning_rate,
                eps=self.args.adam_epsilon,
            )
        return self.optimizer

    def _perform_pooling(self, hidden_state, attention_mask):
        pooling_method = self.model_args.pooling
        batch_size = hidden_state.shape[0]
        if pooling_method in ('last', 'eos'):
            # With left padding, every sequence ends at the final position.
            left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
            if left_padding:
                reps = hidden_state[torch.arange(batch_size), -1, :]
            else:
                # Right padding: index the last non-padding token per row.
                eos_indices = attention_mask.sum(dim=1) - 1
                reps = hidden_state[torch.arange(batch_size, device=hidden_state.device), eos_indices]
        else:
            reps = hidden_state[:, -1, :]
        if self.model_args.normalize:
            reps = F.normalize(reps, p=2, dim=-1)
        return reps
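    # A minimal sanity sketch, not called anywhere during training, showing the
    # index arithmetic that `_perform_pooling` relies on for right-padded
    # batches. All tensors here are fabricated for illustration.
    @staticmethod
    def _pooling_index_example():
        hidden_state = torch.arange(2 * 3 * 4, dtype=torch.float).reshape(2, 3, 4)
        attention_mask = torch.tensor([[1, 1, 0],   # sequence of length 2
                                       [1, 1, 1]])  # sequence of length 3
        # Last real (non-padding) token per row: positions 1 and 2 respectively.
        eos_indices = attention_mask.sum(dim=1) - 1        # tensor([1, 2])
        reps = hidden_state[torch.arange(2), eos_indices]  # shape (2, 4)
        return reps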
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """
        Override Trainer.compute_loss.
        This is the entry point the Trainer actually calls to compute the loss.
        """
        loss = self._compute_early_exit_loss(model, inputs)
        return (loss, None) if return_outputs else loss

    def _compute_early_exit_loss(self, model, inputs) -> torch.Tensor:
        """
        Compute the loss for the early-exit classifier.
        """
        self.backbone.eval()
        model.train()
        device = self.args.device
        qry_inputs, tgt_inputs = inputs
        qry_inputs = batch_to_device(qry_inputs, device)
        tgt_inputs = batch_to_device(tgt_inputs, device)

        with torch.no_grad():
            # Backbone forward passes (frozen; no gradients needed).
            tgt_outputs = self.backbone.encoder(**tgt_inputs, return_dict=True, output_hidden_states=True)
            tgt_reps = self._perform_pooling(tgt_outputs.hidden_states[-1], tgt_inputs['attention_mask'])
            qry_outputs = self.backbone.encoder(**qry_inputs, return_dict=True, output_hidden_states=True)
            q_hidden_mid = qry_outputs.hidden_states[self.target_layer_idx]
            qry_reps_mid = self._perform_pooling(q_hidden_mid, qry_inputs['attention_mask'])

            batch_size = qry_reps_mid.size(0)

            # Feature engineering on the in-batch similarity matrix.
            backbone_ptr = self.backbone.module if hasattr(self.backbone, 'module') else self.backbone
            temp = getattr(backbone_ptr, 'temperature', 0.02)
            sim_matrix = torch.matmul(qry_reps_mid, tgt_reps.T) / temp

            all_topk_vals, _ = torch.topk(sim_matrix, k=2, dim=1)
            feat_s1 = torch.diag(sim_matrix)  # diagonal entries are the positive-pair scores
            # Key fix: margin = top-1 minus top-2, not top-1 minus top-1.
            # all_topk_vals[:, 0] is the maximum (usually the diagonal);
            # all_topk_vals[:, 1] is the runner-up (the hard negative).
            feat_margin = feat_s1 - all_topk_vals[:, 1]
            probs = torch.softmax(sim_matrix, dim=1)
            feat_entropy = -(probs * torch.log(probs + 1e-6)).sum(dim=1)

            # Norm/variance proxies from the raw mid-layer hidden state.
            left_padding = (qry_inputs['attention_mask'][:, -1].sum() == batch_size)
            if left_padding:
                q_raw_pooled = q_hidden_mid[:, -1, :]
            else:
                eos_indices = qry_inputs['attention_mask'].sum(dim=1) - 1
                q_raw_pooled = q_hidden_mid[torch.arange(batch_size, device=device), eos_indices]
            feat_norm = torch.norm(q_raw_pooled, p=2, dim=1)
            feat_var = torch.var(q_raw_pooled, dim=1)

            scalar_inputs = torch.stack([feat_s1, feat_margin, feat_entropy, feat_norm, feat_var], dim=1)

            # Modality detection: 0 = text-only query, 1 = query with an image.
            modality_idx = torch.zeros(batch_size, dtype=torch.long, device=device)
            if 'pixel_values' in qry_inputs and qry_inputs['pixel_values'] is not None:
                pv = qry_inputs['pixel_values']
                if isinstance(pv, list):
                    for i, item in enumerate(pv):
                        if item is not None:
                            modality_idx[i] = 1
                elif isinstance(pv, torch.Tensor) and pv.numel() > 0:
                    modality_idx.fill_(1)

            # Label generation: measure retrieval quality at the intermediate layer.
            # Method 1: check whether top-1 is correct (i.e. the diagonal).
            preds = torch.argmax(sim_matrix, dim=1)
            ground_truth = torch.arange(batch_size, device=device)
            is_correct = (preds == ground_truth).float()

            # Method 2: use a normalized similarity score as a continuous label.
            diag_scores = torch.diag(sim_matrix)   # positive-pair score
            max_scores = sim_matrix.max(dim=1)[0]  # highest score per row

            # Blended labels: 1.0 when the prediction is correct, otherwise a
            # relative score, giving the classifier a richer supervision signal.
            relative_scores = torch.sigmoid((diag_scores - max_scores) * 2)  # amplify the gap
            labels = torch.where(
                is_correct.bool(),
                torch.ones_like(is_correct),
                relative_scores,
            ).unsqueeze(1)

        # Classifier forward pass, outside torch.no_grad():
        # scalar_inputs is already detached (computed under no_grad), and the
        # classifier's own parameters carry requires_grad=True.
        pred_probs = model(scalar_inputs, modality_idx)
        loss = F.binary_cross_entropy(pred_probs, labels)

        # ========================================================
        # Gradient probe: inspect gradient flow during the first 10 steps only.
        # ========================================================
        if self.state.global_step < 10 and self.args.local_rank == 0:
            print(f"\n[Probe Step {self.state.global_step}] Loss: {loss.item():.4f}")
            print(f" - Loss has grad_fn: {loss.grad_fn is not None}")
            print(f" - Pred Probs: Mean={pred_probs.mean().item():.4f}, Std={pred_probs.std().item():.4f}")
            print(f" - Pred Probs Range: [{pred_probs.min().item():.4f}, {pred_probs.max().item():.4f}]")
            print(f" - Labels: Mean={labels.mean().item():.4f}, Std={labels.std().item():.4f}")
            print(f" - Labels Range: [{labels.min().item():.4f}, {labels.max().item():.4f}]")
            print(f" - Correct Rate: {is_correct.mean().item():.4f}")
            # Input feature statistics.
            print(f" - Scalar Inputs: Mean={scalar_inputs.mean().item():.4f}, Std={scalar_inputs.std().item():.4f}")
            print(f" - Modality: Text={(modality_idx == 0).sum().item()}, Image={(modality_idx == 1).sum().item()}")
        # ========================================================

        return loss
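    # A toy walkthrough, not called during training, of the confidence features
    # fed to the early-exit head. The 3x3 similarity matrix is fabricated so
    # that row 0 is an easy query, row 1 a hard one, and row 2 a wrong one.
    @staticmethod
    def _scalar_feature_example():
        sim_matrix = torch.tensor([[5.0, 1.0, 0.5],
                                   [0.2, 3.0, 2.9],
                                   [4.0, 2.0, 1.0]])
        feat_s1 = torch.diag(sim_matrix)          # positive scores: [5.0, 3.0, 1.0]
        topk_vals, _ = torch.topk(sim_matrix, k=2, dim=1)
        feat_margin = feat_s1 - topk_vals[:, 1]   # [4.0, 0.1, -1.0]
        probs = torch.softmax(sim_matrix, dim=1)
        feat_entropy = -(probs * torch.log(probs + 1e-6)).sum(dim=1)
        # Easy row: large margin, low entropy. Wrong row: negative margin,
        # because the diagonal is no longer the top-1 score.
        return torch.stack([feat_s1, feat_margin, feat_entropy], dim=1)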
    def training_step(self, model, inputs, num_items_in_batch=None) -> torch.Tensor:
        """
        Override Trainer.training_step to add gradient monitoring.
        Note: in recent Transformers versions this method is responsible for
        the full forward + backward pass.

        Args:
            model: the model being trained
            inputs: the input batch
            num_items_in_batch: number of samples in the batch (passed by
                recent Transformers versions)
        """
        model.train()
        inputs = self._prepare_inputs(inputs)

        # Compute the loss.
        with self.compute_loss_context_manager():
            loss = self.compute_loss(model, inputs)

        if self.args.n_gpu > 1:
            loss = loss.mean()

        # Backward pass (via Accelerator).
        self.accelerator.backward(loss)

        # Inspect gradients during the first few steps only.
        if not self._grad_check_done and self.args.local_rank == 0:
            print(f"\n[Gradient Check After Backward - Step {self.state.global_step}]")
            inner_model = model.module if hasattr(model, 'module') else model
            has_grad = False
            total_grad_norm = 0.0
            for name, param in inner_model.named_parameters():
                if param.grad is not None:
                    has_grad = True
                    grad_norm = param.grad.norm().item()
                    total_grad_norm += grad_norm ** 2
                    if self.state.global_step < 3:
                        print(f" - {name}: grad_norm={grad_norm:.6f}")
            total_grad_norm = total_grad_norm ** 0.5
            print(f" - Total Grad Norm: {total_grad_norm:.6f}")
            print(f" - Has Gradient: {has_grad}")
            if self.state.global_step >= 2:
                self._grad_check_done = True

        return loss.detach() / self.args.gradient_accumulation_steps
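
# A minimal wiring sketch, assuming a frozen backbone that exposes `.encoder`
# and `.temperature`, a small classifier head with signature
# `classifier(scalar_inputs, modality_idx)`, and a collator that yields
# `(qry_inputs, tgt_inputs)` pairs. Every name below is a placeholder for
# illustration, not part of this module.
def build_early_exit_trainer(backbone, classifier, model_args, train_dataset, pair_collator):
    from transformers import TrainingArguments

    args = TrainingArguments(
        output_dir="early_exit_out",
        per_device_train_batch_size=32,
        learning_rate=1e-3,
        remove_unused_columns=False,  # keep raw columns for the pair collator
    )
    return EarlyExitTrainer(
        backbone_model=backbone,      # frozen encoder, never updated
        target_layer_idx=12,          # example: probe the 12th layer
        model_args=model_args,        # must expose .pooling and .normalize
        model=classifier,             # the early-exit head being trained
        args=args,
        train_dataset=train_dataset,
        data_collator=pair_collator,
    )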