| #!/usr/bin/env python3 | |
| # ------------------------------------------------------------------------ | |
| # Copyright (c) 2024-present, BAAI. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # ------------------------------------------------------------------------ | |
| """URSA one-step distillation trainer (DiMO-style), 8-GPU distributed. | |
| Verified native inference regime (from A/B testing — ground truth): | |
| height=320, width=512, num_frames=49, guidance_scale=7, teacher_steps=50. | |
| no_cfg (guidance_scale=1) is NOT a valid baseline for this URSA checkpoint. | |
| Defaults in configs/distill_dimo.yaml are aligned to this regime. | |
| Launch command: | |
| accelerate launch --config_file accelerate_configs/deepspeed_zero2.yaml \\ | |
| --machine_rank 0 --num_machines 1 --num_processes 8 \\ | |
| scripts/train_distill_dimo.py \\ | |
| config="./configs/distill_dimo.yaml" \\ | |
| experiment.output_dir="./experiments/distill_dimo" \\ | |
| distill.teacher_ckpt="/path/to/URSA-1.7B-IBQ1024" \\ | |
| distill.prompt_source="/data/Koala_36M_*.csv" \\ | |
| distill.batch_size_per_gpu=1 | |
| Smoke test (single-GPU, 50 steps): | |
| accelerate launch --num_processes 1 \\ | |
| scripts/train_distill_dimo.py \\ | |
| config="./configs/distill_dimo.yaml" \\ | |
| experiment.output_dir="./experiments/smoke" \\ | |
| distill.teacher_ckpt="/path/to/URSA-1.7B-IBQ1024" \\ | |
| distill.prompt_source="prompts.txt" \\ | |
| training.max_train_steps=50 | |
Algorithm summary (stages per iteration)
----------------------------------------
Stage 1  Tokenize → txt_ids [B, L] (CPU in worker, moved to GPU in run_step)
Stage 2  x_init ~ Uniform(K) (+ probabilistic p_init mixing from x_hat_prev)
Stage 3  no_grad student(x_init) → sample x_hat [B, N]
Stage 4  x_t = scheduler.add_noise(x_hat_4d, t)          [B,T,H,W], long
Stage 5  no_grad teacher(x_t) → z_T_cond [B,N,K] (+ uncond if CFG) → z_T_target
Stage 6  aux update × fake_rounds: cross_entropy(z_A_cond, x_hat).backward()
Stage 7  student KD forward on x_t → z_S_cond; loss_kd = KL(z_T_target, z_S_cond)
Stage 8  no_grad aux(x_t) → z_A_cond; grad-enabled student(x_init) → z_s_pol;
         bridge_target = softmax(z_T_target) - softmax(z_A_cond)
Stage 9  Single backward on lambda_kd * loss_kd + lambda_pg * loss_bridge,
         where loss_bridge injects bridge_target into z_s_pol (MSE trick),
         then opt_student.step()
| """ | |
import collections
import os
import sys
from typing import Optional
| import torch | |
| import torch.nn.functional as F | |
| from torch.utils.data import DataLoader | |
| _ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) | |
| if _ROOT not in sys.path: | |
| sys.path.insert(0, _ROOT) | |
| from diffnext.engine import engine_utils | |
| from diffnext.engine.lr_scheduler import CosineLR | |
from diffnext.pipelines.ursa.pipeline_ursa_distill_dimo import (
    URSADistillDiMOPipeline,
    _get_logits,
    _stable_kl,
    _build_guided_logits,
    _cfg_warmup_prob,
    VERIFIED_NATIVE_DEFAULTS,
    check_verified_regime,
)
| from diffnext.utils import accelerate_utils | |
| from diffnext.utils import omegaconf_utils | |
| from diffnext.utils import profiler | |
| from src.distill.prompt_dataset import ( | |
| CSVSpec, | |
| InfiniteDataLoader, | |
| PromptDataset, | |
| make_collate_fn, | |
| ) | |
| # --------------------------------------------------------------------------- | |
| # DistillTwinModel — single nn.Module wrapping student + aux for DeepSpeed | |
| # --------------------------------------------------------------------------- | |
| class DistillTwinModel(torch.nn.Module): | |
| """Wrapper that holds both student and aux as sub-modules. | |
| DeepSpeed (via Accelerate) only allows a single model in | |
| ``accelerator.prepare()``. This container satisfies that constraint | |
| while keeping student and aux as separately addressable sub-modules | |
| with independent param groups. | |
| """ | |
| def __init__(self, student: torch.nn.Module, aux: torch.nn.Module): | |
| super().__init__() | |
| self.student = student | |
| self.aux = aux | |
| def forward(self, which: str, input_ids, rope_pos=None, **kwargs): | |
| if which == "student": | |
| return self.student(input_ids, rope_pos=rope_pos, **kwargs) | |
| elif which == "aux": | |
| return self.aux(input_ids, rope_pos=rope_pos, **kwargs) | |
| else: | |
| raise ValueError(f"DistillTwinModel: unknown sub-model '{which}'") | |
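# Usage sketch (illustrative; mirrors how run_step calls the prepared model):
#   twin = DistillTwinModel(student, aux)
#   logits = twin("student", input_ids, rope_pos=rope_pos)  # routes to student
#   logits = twin("aux", input_ids, rope_pos=rope_pos)      # routes to aux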
| # --------------------------------------------------------------------------- | |
| # DistillTrainer | |
| # --------------------------------------------------------------------------- | |
| class DistillTrainer: | |
| """Training orchestrator for on-policy one-step distillation. | |
| Reuses the same accelerate / logger / checkpoint API as | |
| ``diffnext.engine.train_engine.Trainer`` so the distributed setup is | |
| identical to the original training framework. | |
| Key differences from standard Trainer: | |
| - Three models (teacher frozen, student + aux trainable) | |
| - Student and aux are wrapped in a single ``DistillTwinModel`` so that | |
| only one ``accelerator.prepare()`` call is needed (DeepSpeed requirement) | |
| - One optimizer with two param_groups: [0]=student, [1]=aux | |
| - LR schedulers for both param groups | |
    - Aux inner loop (fake_rounds) followed by a single combined student backward (KD + bridge) per step
| - PromptDataset (no video latents; prompt-only) | |
| - Stage 6 freezes student / unfreezes aux; Stages 7-9 do the reverse | |
| """ | |
| def __init__(self, config, accelerator, logger): | |
| self.config = config | |
| self.accelerator = accelerator | |
| self.logger = logger | |
| cfg = config.distill | |
| dtype = accelerate_utils.precision_to_dtype(config.training.mixed_precision) | |
| self.device = accelerator.device | |
| # -------- Pipeline (teacher + student + aux) ---------------------- | |
| logger.info(f"[init] Loading teacher from {cfg.teacher_ckpt} ...") | |
| self.pipe = URSADistillDiMOPipeline( | |
| teacher_ckpt=cfg.teacher_ckpt, | |
| compute_dtype=dtype, | |
| aux_noise_std=float(cfg.get("aux_noise_std", 0.0)), | |
| ) | |
| # Move teacher to GPU (not prepared by accelerate — frozen). | |
| self.pipe.teacher = self.pipe.teacher.to(self.device) | |
| self.pipe.scheduler.to(device=self.device) | |
| # Compute latents shape from video geometry. | |
| from src.distill.utils_ursa_inputs import compute_latents_shape | |
| # Read VAE strides from pipeline (falls back to URSA defaults 4/8). | |
| vae_t = int(getattr(self.pipe, "vae_temporal_stride", 4)) | |
| vae_s = int(getattr(self.pipe, "vae_spatial_stride", 8)) | |
| self.latents_shape = compute_latents_shape( | |
| cfg.num_frames, cfg.height, cfg.width, vae_t, vae_s | |
| ) | |
| T, H, W = self.latents_shape | |
| self.N = T * H * W | |
| self.K = self.pipe.codebook_size | |
| logger.info( | |
| f"[init] latents_shape=({T},{H},{W}) N={self.N} K={self.K} " | |
| f"CFG={'ON' if cfg.enable_teacher_cfg else 'OFF'}" | |
| ) | |
| # Pre-compute uncond token IDs (empty string, [1, L]) on CPU. | |
| self.txt_uncond_base_cpu = self.pipe.tokenizer( | |
| [""], | |
| max_length=int(cfg.max_prompt_length), | |
| padding="max_length", | |
| padding_side="left", | |
| truncation=True, | |
| return_tensors="pt", | |
| ).input_ids # [1, L] CPU | |
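        # padding_side="left" is assumed to keep the real prompt tokens at the
        # right edge of the text segment, adjacent to the visual tokens that
        # build_inputs appends later.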
| # -------- Optimizers (before accelerate.prepare) ------------------ | |
| # Single optimizer with two param groups: | |
| # group[0] = student params, group[1] = aux params | |
| opt_cls = torch.optim.AdamW | |
| opt_s_params = dict( | |
| lr=float(config.optimizer_student.params.lr), | |
| betas=tuple(config.optimizer_student.params.get("betas", [0.9, 0.95])), | |
| weight_decay=float(config.optimizer_student.params.get("weight_decay", 0.01)), | |
| ) | |
| opt_a_params = dict( | |
| lr=float(config.optimizer_aux.params.lr), | |
| betas=tuple(config.optimizer_aux.params.get("betas", [0.9, 0.95])), | |
| weight_decay=float(config.optimizer_aux.params.get("weight_decay", 0.01)), | |
| ) | |
| def _enable_gcpt(m): | |
            # m.model.layers is the list of Qwen3Model transformer layers.
| for layer in m.model.layers: | |
| layer.gradient_checkpointing = True | |
| layer.self_attn.gradient_checkpointing = True | |
| layer.mlp.gradient_checkpointing = True | |
| _enable_gcpt(self.pipe.student) | |
| _enable_gcpt(self.pipe.aux) | |
        # -------- Resume: load weights BEFORE ZeRO-3 partitions them ------
| self.global_step = int(config.experiment.get("resume_iter", 0)) | |
| if self.global_step > 0: | |
| ckpt_dir = os.path.join( | |
| config.experiment.output_dir, "checkpoints", f"checkpoint-{self.global_step}" | |
| ) | |
| if os.path.exists(ckpt_dir): | |
| logger.info(f"[Resume] 正在从 {ckpt_dir} 恢复 Student 和 Aux 的权重...") | |
| # 必须在 map_location="cpu" 下加载,防止爆显存,随后 prepare 会自动分配 | |
| self.pipe.student.load_state_dict(torch.load(os.path.join(ckpt_dir, "student.pt"), map_location="cpu")) | |
| self.pipe.aux.load_state_dict(torch.load(os.path.join(ckpt_dir, "aux.pt"), map_location="cpu")) | |
| else: | |
| logger.warning(f"[Resume] 找不到检查点 {ckpt_dir},将从随机初始状态起步!") | |
| # -------- Wrap student + aux into a single DistillTwinModel -------- | |
| twin_model = DistillTwinModel(self.pipe.student, self.pipe.aux) | |
| opt_raw = opt_cls([ | |
| {"params": list(self.pipe.student.parameters()), **opt_s_params}, | |
| {"params": list(self.pipe.aux.parameters()), **opt_a_params}, | |
| ]) | |
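        # run_step overwrites these group LRs every step from the two cosine
        # schedulers, i.e.:
        #   optimizer.param_groups[0]["lr"] = scheduler_s.get_lr()  # student
        #   optimizer.param_groups[1]["lr"] = scheduler_a.get_lr()  # aux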
| # -------- accelerate.prepare: single model + single optimizer ------ | |
| # Teacher is NOT prepared (frozen; no grad sync needed). | |
| self.model, self.optimizer = accelerator.prepare(twin_model, opt_raw) | |
| # LR schedulers (step() called manually at end of each step). | |
| self.scheduler_s = CosineLR( | |
| lr_max=float(config.optimizer_student.params.lr), | |
| lr_min=float(config.lr_scheduler.params.get("lr_min", 1e-6)), | |
| max_steps=int(config.training.max_train_steps), | |
| warmup_steps=int(config.lr_scheduler.params.get("warmup_steps", 500)), | |
| ) | |
| self.scheduler_a = CosineLR( | |
| lr_max=float(config.optimizer_aux.params.lr), | |
| lr_min=float(config.lr_scheduler.params.get("lr_min", 1e-6)), | |
| max_steps=int(config.training.max_train_steps), | |
| warmup_steps=int(config.lr_scheduler.params.get("warmup_steps", 500)), | |
| ) | |
| # -------- Dataset / DataLoader ------------------------------------ | |
| dataloader_cfg = config.get("prompt_dataloader", {}) | |
| dataset = PromptDataset( | |
| prompt_source=str(cfg.prompt_source), | |
| shuffle_files=bool(dataloader_cfg.get("shuffle_files", True)), | |
| shuffle_buffer=int(dataloader_cfg.get("shuffle_buffer", 0)), | |
| seed=int(config.training.seed), | |
| infinite=True, | |
| csv=CSVSpec(caption_field=str(dataloader_cfg.get("caption_field", "caption"))), | |
| ) | |
| # collate_fn: tokenize on CPU (no CUDA in workers). | |
| collate_fn = make_collate_fn( | |
| self.pipe.tokenizer, | |
| max_prompt_length=int(cfg.max_prompt_length), | |
| device=torch.device("cpu"), # CPU output — moved to GPU in run_step | |
| ) | |
| loader = DataLoader( | |
| dataset, | |
| batch_size=int(cfg.batch_size_per_gpu), | |
| shuffle=False, # IterableDataset: no shuffle flag | |
| drop_last=True, | |
| num_workers=int(dataloader_cfg.get("num_workers", 2)), | |
| collate_fn=collate_fn, | |
| pin_memory=True, | |
| ) | |
| # DataLoader is NOT prepared by accelerate because PromptDataset | |
| # handles per-rank file sharding internally via torch.distributed. | |
| self._inf_loader = InfiniteDataLoader(loader) | |
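        # InfiniteDataLoader is assumed to re-create the underlying iterator
        # on exhaustion, so next(self._inf_loader) never raises StopIteration.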
| # -------- Training state ------------------------------------------ | |
| self.baseline_ema: float = 0.0 | |
| self.x_hat_prev: Optional[torch.Tensor] = None | |
| self.metrics = collections.OrderedDict() | |
| # -------- Verified regime validation -------------------------------- | |
| native = VERIFIED_NATIVE_DEFAULTS | |
| is_native = check_verified_regime( | |
| height=int(cfg.height), | |
| width=int(cfg.width), | |
| num_frames=int(cfg.num_frames), | |
| guidance_scale=float(cfg.teacher_cfg_scale) if cfg.enable_teacher_cfg else None, | |
| label="train", | |
| ) | |
| logger.info( | |
| f"[init] verified_native_regime={is_native} " | |
| f"geometry=({cfg.num_frames}×{cfg.height}×{cfg.width}) " | |
| f"teacher_cfg_scale={cfg.teacher_cfg_scale if cfg.enable_teacher_cfg else 'OFF'}" | |
| ) | |
| if not cfg.enable_teacher_cfg: | |
| logger.warning( | |
| "[WARN] Teacher CFG is DISABLED. no_cfg is known to produce " | |
| "blank/blurry output for this URSA checkpoint. " | |
| "Distillation without CFG is unlikely to produce useful results." | |
| ) | |
| elif float(cfg.teacher_cfg_scale) != native["guidance_scale"]: | |
| logger.warning( | |
| f"[WARN] teacher_cfg_scale={cfg.teacher_cfg_scale} differs from " | |
| f"the verified working value ({native['guidance_scale']}). " | |
| "Outputs may deviate from the official inference working point." | |
| ) | |
| logger.info( | |
| f"[init] student params: {engine_utils.count_params(self.pipe.student):.2f}M" | |
| ) | |
| logger.info( | |
| f"[init] max_train_steps={config.training.max_train_steps} " | |
| f"batch_size_per_gpu={cfg.batch_size_per_gpu} " | |
| f"num_processes={accelerator.num_processes}" | |
| ) | |
| # ----------------------------------------------------------------------- | |
| # run_step: Stages 1-9 | |
| # ----------------------------------------------------------------------- | |
| def run_step(self, step: int) -> dict: | |
| """Execute one distillation step (Stages 1-9).""" | |
| cfg = self.config.distill | |
| T, H, W = self.latents_shape | |
| N, K = self.N, self.K | |
| device = self.device | |
| stats = {"step": step} | |
| timer = profiler.Timer().tic() | |
| # Update LR from cosine schedulers. | |
| # param_groups[0] = student, param_groups[1] = aux | |
| lr_s = self.scheduler_s.get_lr() | |
| lr_a = self.scheduler_a.get_lr() | |
| stats["lr_student"] = lr_s | |
| stats["lr_aux"] = lr_a | |
| self.optimizer.param_groups[0]["lr"] = lr_s | |
| self.optimizer.param_groups[1]["lr"] = lr_a | |
| # ---------------------------------------------------------------- | |
| # Stage 1: Get tokenised batch (CPU → GPU) | |
| # ---------------------------------------------------------------- | |
| txt_ids = next(self._inf_loader) # [B, L] CPU tensor | |
| txt_ids = txt_ids.to(device, non_blocking=True) | |
| B = txt_ids.size(0) | |
| txt_uncond = None | |
| if cfg.enable_teacher_cfg: | |
| txt_uncond = self.txt_uncond_base_cpu.expand(B, -1).to(device) | |
| # ---------------------------------------------------------------- | |
| # Stage 2: Sample x_init ~ Uniform(K) with optional p_init mixing | |
| # ---------------------------------------------------------------- | |
| x_init = torch.randint(0, K, (B, T, H, W), device=device, dtype=torch.long) | |
        # Fix: trigger mixing probabilistically so that, even with a small
        # batch, the model still trains on pure noise most of the time.
| if self.x_hat_prev is not None and float(cfg.p_init_mix_ratio) > 0: | |
| if torch.rand(1).item() < float(cfg.p_init_mix_ratio): | |
                # When triggered, mix only the first sample in the batch.
| x_init[0] = self.pipe.corrupt_tokens( | |
| self.x_hat_prev[0:1], r=float(cfg.p_mix_corrupt_frac) | |
| ).squeeze(0) | |
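        # Illustrative: with p_init_mix_ratio=0.25 the branch fires on ~1 in
        # 4 steps, and only x_init[0] is replaced by a corrupted copy of the
        # previous step's sample; the other inits stay pure noise.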
| # ---------------------------------------------------------------- | |
| # Stage 3: Student 1-step on x_init — no_grad (only sample x_hat) | |
| # | |
        # The gradient-enabled forward on x_init is deferred to Stage 8 so
        # the KD tensors built from x_t can be freed first.
| # ---------------------------------------------------------------- | |
| with torch.no_grad(): | |
| ids_init, rpos_init, _ = self.pipe.build_inputs( | |
| txt_ids, x_init, self.latents_shape | |
| ) | |
| logits_s_init = _get_logits( | |
| self.model("student", ids_init, rope_pos=rpos_init) | |
| ) | |
| z_s = self.pipe.extract_logits(logits_s_init, N) # [B, N, K] | |
| p_s = F.softmax(z_s / float(cfg.tau), dim=-1) # [B, N, K] | |
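            # multinomial over the flattened [B*N, K] probability rows draws
            # one code per token position; view(B, N) restores the layout.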
| x_hat = torch.multinomial(p_s.view(-1, K), 1).view(B, N) # [B, N] | |
| del p_s, z_s, logits_s_init | |
| x_hat_4d = x_hat.view(B, T, H, W) | |
| # ---------------------------------------------------------------- | |
| # Stage 4: Pseudo-intermediate x_t = add_noise(x_hat, t) | |
| # ---------------------------------------------------------------- | |
| t = self.pipe.sample_t_curriculum( | |
| B, device, step, int(cfg.t_curriculum_steps) | |
| ) # [B] float ∈ (0.05, 0.995) | |
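        # scheduler.add_noise on discrete tokens is assumed to re-randomise a
        # t-dependent fraction of the token grid (uniform/masked corruption),
        # returning long indices rather than continuous latents.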
| with torch.no_grad(): | |
| x_t = self.pipe.scheduler.add_noise(x_hat_4d, t) # [B,T,H,W] long | |
        # ----------------------------------------------------------------
        # Stage 5: Teacher forward — break view aliasing so the dual-batch
        # buffers can be freed as soon as the target logits are built
        # ----------------------------------------------------------------
| with torch.no_grad(): | |
| if cfg.enable_teacher_cfg: | |
| txt_dual = torch.cat([txt_ids, txt_uncond], dim=0) # [2B, L] | |
| x_t_dual = torch.cat([x_t, x_t], dim=0) # [2B,T,H,W] | |
| ids_dual, rpos_dual, _ = self.pipe.build_inputs( | |
| txt_dual, x_t_dual, self.latents_shape | |
| ) | |
| logits_T_dual = _get_logits( | |
| self.pipe.teacher(ids_dual, rope_pos=rpos_dual) | |
| ) | |
| z_T_dual = self.pipe.extract_logits(logits_T_dual, N) # [2B,N,K] | |
                # Memory saver 1: .clone() breaks the view dependency so the
                # large parent tensor can be garbage-collected.
                z_T_cond = z_T_dual[:B].clone()  # [B,N,K]
                z_T_uncond = z_T_dual[B:].clone()  # [B,N,K]
                ids_t, rpos_t = ids_dual[:B], rpos_dual[:B]
                # Free the dual-batch buffers immediately.
| del logits_T_dual, z_T_dual | |
| torch.cuda.empty_cache() | |
| else: | |
| ids_t, rpos_t, _ = self.pipe.build_inputs(txt_ids, x_t, self.latents_shape) | |
| logits_T = _get_logits(self.pipe.teacher(ids_t, rope_pos=rpos_t)) | |
| z_T_cond = self.pipe.extract_logits(logits_T, N) | |
| z_T_uncond = None | |
        # CFG guided target with per-sample Bernoulli warmup.
| z_T_guided = None | |
| use_guided_ratio = 0.0 | |
| if cfg.enable_teacher_cfg: | |
| p_guided = _cfg_warmup_prob(step, float(cfg.teacher_cfg_prob), int(cfg.teacher_cfg_warmup_steps)) | |
| use_guided = torch.rand(B, device=device) < p_guided | |
| use_guided_ratio = float(use_guided.float().mean().item()) | |
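            # _build_guided_logits is assumed to apply standard CFG in logit
            # space, roughly z_uncond + s * (z_cond - z_uncond), truncated as
            # a function of t via teacher_cfg_trunc.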
| z_T_guided = _build_guided_logits( | |
| z_T_cond, z_T_uncond, | |
| t, float(cfg.teacher_cfg_scale), float(cfg.teacher_cfg_trunc), | |
| ) | |
| mask = use_guided.view(-1, 1, 1).expand_as(z_T_cond) | |
            # Memory saver 2: keep the bf16 dtype to avoid a float32 blow-up
            # of the [B,N,K] target.
| z_T_target = torch.where(mask, z_T_guided, z_T_cond).to(dtype=z_T_cond.dtype).detach() | |
            # Free all CFG intermediates immediately.
| del z_T_cond, z_T_uncond, z_T_guided | |
| torch.cuda.empty_cache() | |
| else: | |
| z_T_target = z_T_cond.detach() | |
        # ----------------------------------------------------------------
        # Stage 6: Aux update — memory saver 3: conditional-only forward
        # (the aux target needs no uncond branch)
        # ----------------------------------------------------------------
| raw_twin = self.accelerator.unwrap_model(self.model) | |
| raw_twin.student.requires_grad_(False) | |
| raw_twin.aux.requires_grad_(True) | |
        target_tokens = x_hat.detach()  # [B, N] - the student's Stage-3 one-step sample
        loss_aux_cond_last = torch.tensor(0.0, device=device)  # defined even if fake_rounds == 0
| for _fr in range(int(cfg.fake_rounds)): | |
| self.optimizer.zero_grad(set_to_none=True) | |
            # Conditional-only forward on ids_t (no dual batch); halves the
            # aux activation memory.
| logits_A = _get_logits( | |
| self.model("aux", ids_t.detach(), rope_pos=rpos_t.detach()) | |
| ) | |
| z_A_cond = self.pipe.extract_logits(logits_A, N) | |
| loss_aux_cond_v = F.cross_entropy( | |
| z_A_cond.reshape(B * N, K), | |
| target_tokens.reshape(B * N), | |
| reduction="mean", | |
| ) | |
| self.accelerator.backward(loss_aux_cond_v) | |
| if float(cfg.grad_clip) > 0: | |
| torch.nn.utils.clip_grad_norm_(raw_twin.aux.parameters(), float(cfg.grad_clip)) | |
| self.optimizer.step() | |
            # Free immediately.
| del logits_A, z_A_cond | |
| torch.cuda.empty_cache() | |
| loss_aux_cond_last = loss_aux_cond_v.detach() | |
| # ---------------------------------------------------------------- | |
| # Stage 7 & 8: Student KD update & Aux Bridge | |
| # ---------------------------------------------------------------- | |
| raw_twin.student.requires_grad_(True) | |
| raw_twin.aux.requires_grad_(False) | |
| self.optimizer.zero_grad(set_to_none=True) | |
        # 7a. Student KD forward on x_t (conditional-only single pass).
| logits_S = _get_logits( | |
| self.model("student", ids_t.detach(), rope_pos=rpos_t.detach()) | |
| ) | |
| z_S_cond = self.pipe.extract_logits(logits_S, N) | |
        # chunk_size=128 keeps the chunked KL computation memory-safe.
| loss_kd = _stable_kl( | |
| z_T_target, z_S_cond, float(cfg.tau_kd), chunk_size=128 | |
| ).mean() | |
| del logits_S, z_S_cond | |
| torch.cuda.empty_cache() | |
        # 7b. Aux prediction (no grad) used as the gradient bridge (conditional-only).
| with torch.no_grad(): | |
| logits_A = _get_logits( | |
| self.model("aux", ids_t.detach(), rope_pos=rpos_t.detach()) | |
| ) | |
| z_A_cond = self.pipe.extract_logits(logits_A, N) | |
        # 8. Grad-enabled student forward on x_init (the one-step policy path).
| ids_init, rpos_init, _ = self.pipe.build_inputs(txt_ids, x_init, self.latents_shape) | |
| logits_s_pol = _get_logits( | |
| self.model("student", ids_init, rope_pos=rpos_init) | |
| ) | |
| z_s_pol = self.pipe.extract_logits(logits_s_pol, N) | |
        # Memory saver 4: compute the softmax probabilities in bf16 to avoid
        # a float32 blow-up.
| p_T = F.softmax(z_T_target / float(cfg.tau_kd), dim=-1).to(z_s_pol.dtype) | |
| p_A = F.softmax(z_A_cond / float(cfg.tau_kd), dim=-1).to(z_s_pol.dtype) | |
| bridge_target = (p_T - p_A).detach() | |
        # Once bridge_target is built, free every large intermediate.
| del p_T, p_A, logits_A, z_A_cond, z_T_target | |
| torch.cuda.empty_cache() | |
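        # Why the MSE trick works (derivation; mse_loss uses its default
        # reduction="mean" over all B*N*K elements):
        #   loss = 0.5 * K * mean((z - (z + d).detach())**2)
        #   d(loss)/dz = 0.5 * K * 2 * (z - z - d) / (B*N*K) = -d / (B*N)
        # so backward pushes z_s_pol along d = bridge_target = p_T - p_A,
        # while the scalar loss value stays at 0.5 * K * mean(d**2).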
        # Pseudo-gradient injection via the MSE trick.
        loss_bridge = 0.5 * K * F.mse_loss(
            z_s_pol.float(),
            (z_s_pol.float() + bridge_target.float()).detach()
        )
        # 9. Single combined backward (KD + bridge); lambda_pg is reused as
        # the bridge-loss weight.
        loss_student = float(cfg.lambda_kd) * loss_kd + float(cfg.lambda_pg) * loss_bridge
| self.accelerator.backward(loss_student) | |
| if float(cfg.grad_clip) > 0: | |
| torch.nn.utils.clip_grad_norm_(raw_twin.student.parameters(), float(cfg.grad_clip)) | |
| self.optimizer.step() | |
        # Restore both sub-modules to trainable for the next step.
        raw_twin.student.requires_grad_(True)
        raw_twin.aux.requires_grad_(True)
        # p_init mixing: store x_hat_4d (detached) for the next step.
        self.x_hat_prev = x_hat_4d.detach()
        # Final cleanup.
        del logits_s_pol, z_s_pol, bridge_target
        torch.cuda.empty_cache()
        # Placeholders keep the log schema compatible with the PG variant.
        H_mean = torch.tensor(0.0, device=device)
        loss_pg = loss_bridge.detach()  # report the bridge loss under loss_pg
        logp = torch.tensor(0.0, device=device)
        self.baseline_ema = 0.0
| # Advance LR schedulers. | |
| self.scheduler_s.step() | |
| self.scheduler_a.step() | |
        # ----------------------------------------------------------------
        # Step 1 sanity assertions (disabled): the tensors they inspect
        # (z_s, p_s, z_T_cond, z_S_cond) are now freed early to save memory.
        # See _step1_assertions for the checks they would run.
        # ----------------------------------------------------------------
        # if step == 1:
        #     self._step1_assertions(
        #         x_init, ids_init, rpos_init, z_s, p_s, logp,
        #         z_T_cond, z_S_cond, x_t, B, T, H, W,
        #     )
| # ---------------------------------------------------------------- | |
| # Token-level collapse detection | |
| # ---------------------------------------------------------------- | |
| tok_entropy = self._token_entropy(x_hat) | |
| if not hasattr(self, "_init_tok_entropy"): | |
| self._init_tok_entropy = tok_entropy | |
| collapse_frac = float(cfg.get("collapse_warn_frac", 0.2)) | |
        if tok_entropy < collapse_frac * self._init_tok_entropy:
            self.logger.warning(
                f"[COLLAPSE] step={step} tok_H={tok_entropy:.3f} "
                f"init={self._init_tok_entropy:.3f} "
                f"ratio={tok_entropy / max(self._init_tok_entropy, 1e-8):.2f} "
                f"< {collapse_frac}. Try increasing lambda_ent (note: the "
                "entropy term is only active in the PG variant)."
            )
| stats["time"] = timer.toc() | |
| stats["metrics"] = collections.OrderedDict( | |
| sorted( | |
| { | |
| "loss_aux_cond": float(loss_aux_cond_last.item()), | |
| "loss_kd_cond": float(loss_kd.item()), | |
| "loss_pg": float(loss_pg.item()), | |
| "H_mean": float(H_mean.item()), | |
| "tok_entropy": float(tok_entropy), | |
| "mean_logp_tok": float(logp.mean().item()), | |
| "baseline_ema": float(self.baseline_ema), | |
| "use_guided_ratio": float(use_guided_ratio), | |
| }.items() | |
| ) | |
| ) | |
| return stats | |
| # ----------------------------------------------------------------------- | |
| # Train loop | |
| # ----------------------------------------------------------------------- | |
| def train_loop(self): | |
| """Main training loop (mirrors diffnext.engine.train_engine.Trainer).""" | |
| cfg_exp = self.config.experiment | |
| max_steps = int(self.config.training.max_train_steps) | |
| log_every = int(cfg_exp.log_every) | |
| save_every = int(cfg_exp.save_every) | |
| self.global_step = int(self.config.experiment.get("resume_iter", 0)) | |
| # Sync LR schedulers to resume step (set _step_count directly; | |
| # CosineLR uses _step_count internally in get_decay()). | |
| self.scheduler_s._step_count = self.global_step | |
| self.scheduler_a._step_count = self.global_step | |
        # [Optional] On resume, let accelerator restore the sharded optimizer
        # and other ZeRO-partitioned state.
| if self.global_step > 0: | |
| ckpt_dir = os.path.join(self.config.experiment.output_dir, "checkpoints", f"checkpoint-{self.global_step}") | |
| if os.path.exists(ckpt_dir): | |
| self.accelerator.load_state(ckpt_dir) | |
| self.logger.info(f"✅ ZeRO-3 完整状态 (包含 Optimizer) 已从 {ckpt_dir} 恢复") | |
| timer = profiler.Timer() | |
| self.logger.info( | |
| f"[train] Starting from step {self.global_step} / {max_steps}" | |
| ) | |
| while self.global_step < max_steps: | |
| self.global_step += 1 | |
| with timer.tic_and_toc(): | |
| stats = self.run_step(self.global_step) | |
| self._add_metrics(stats) | |
| if self.global_step % log_every == 0: | |
| self._log_metrics(stats) | |
| if self.global_step % (10 * log_every) == 0: | |
| self.logger.info( | |
| profiler.get_progress(timer, self.global_step, max_steps) | |
| ) | |
| if self.global_step % save_every == 0: | |
| self.save(self.global_step) | |
| # Final log + save (only when loop ran at least one step). | |
| if self.global_step > int(self.config.experiment.get("resume_iter", 0)): | |
| self._log_metrics({**stats, "step": self.global_step}) # noqa: F821 | |
| self.accelerator.wait_for_everyone() | |
| self.save(self.global_step, suffix="final") | |
| self.accelerator.end_training() | |
| # ----------------------------------------------------------------------- | |
| # Checkpoint helpers | |
| # ----------------------------------------------------------------------- | |
    def save(self, step: int, suffix: Optional[str] = None) -> None:
        """Save student + aux state_dicts (DeepSpeed ZeRO-3 gather-aware).

        Saved as:
            <output_dir>/checkpoints/checkpoint-<step>/student.pt
            <output_dir>/checkpoints/checkpoint-<step>/aux.pt

        student.pt can be used for inference by replacing the transformer
        weights in a URSAPipeline (see README).
        """
        # IMPORTANT: get_state_dict gathers the ZeRO-3 shards and must be
        # executed by ALL ranks. Never guard it with is_main_process, or
        # the collective gather deadlocks across ranks.
        full_state_dict = self.accelerator.get_state_dict(self.model)
        # Only the main process writes the gathered state dict to disk.
        if not self.accelerator.is_main_process:
            return
| folder = f"checkpoint-{suffix}" if suffix else f"checkpoint-{step}" | |
| ckpt_dir = os.path.join( | |
| self.config.experiment.output_dir, "checkpoints", folder | |
| ) | |
| os.makedirs(ckpt_dir, exist_ok=True) | |
        # Split the TwinModel state dict into student / aux by key prefix.
| student_state = {k.replace("student.", ""): v for k, v in full_state_dict.items() if k.startswith("student.")} | |
| aux_state = {k.replace("aux.", ""): v for k, v in full_state_dict.items() if k.startswith("aux.")} | |
| student_path = os.path.join(ckpt_dir, "student.pt") | |
| aux_path = os.path.join(ckpt_dir, "aux.pt") | |
| torch.save(student_state, student_path) | |
| torch.save(aux_state, aux_path) | |
        # Save lightweight training state for resuming.
| state = { | |
| "global_step": step, | |
| "baseline_ema": self.baseline_ema, | |
| } | |
| torch.save(state, os.path.join(ckpt_dir, "train_state.pt")) | |
| self.logger.info(f"[save] step={step} → {ckpt_dir} (ZeRO-3 Gathered)") | |
| # ----------------------------------------------------------------------- | |
| # Logging helpers (same API as original Trainer) | |
| # ----------------------------------------------------------------------- | |
| def _add_metrics(self, stats: dict) -> None: | |
| for k, v in stats["metrics"].items(): | |
| if k not in self.metrics: | |
| self.metrics[k] = profiler.SmoothedValue() | |
| self.metrics[k].update(v) | |
| def _log_metrics(self, stats: dict) -> None: | |
| iter_template = "Iteration %d, lr_s=%.2e lr_a=%.2e, time=%.2fs" | |
| self.logger.info( | |
| iter_template | |
| % ( | |
| stats["step"], | |
| stats.get("lr_student", 0.0), | |
| stats.get("lr_aux", 0.0), | |
| stats.get("time", 0.0), | |
| ) | |
| ) | |
| metric_template = " Train %s: %s" | |
| for k, v in self.metrics.items(): | |
| self.logger.info(metric_template % (k, v)) | |
| tracker_logs = {k: v.median for k, v in self.metrics.items()} | |
| tracker_logs.update( | |
| { | |
| "lr_student": stats.get("lr_student", 0.0), | |
| "time": stats.get("time", 0.0), | |
| } | |
| ) | |
| self.accelerator.log(tracker_logs, step=stats["step"]) | |
| self.metrics.clear() | |
| # ----------------------------------------------------------------------- | |
| # Sanity checks (step 1 only) | |
| # ----------------------------------------------------------------------- | |
| def _step1_assertions( | |
| self, x_init, ids_init, rpos_init, z_s, p_s, logp, | |
| z_T_cond, z_S_cond, x_t, B, T, H, W, | |
| ) -> None: | |
| """Shape / value-domain assertions (mirrors single-card script).""" | |
| N, K = self.N, self.K | |
| lm_vocab = self.pipe.teacher.config.lm_vocab_size | |
| L_plus_N1 = ids_init.size(1) | |
| txt_len = L_plus_N1 - (N + 1) | |
| assert x_init.dtype == torch.long | |
| assert x_init.min() >= 0 and x_init.max() < K | |
| assert ids_init.shape == (B, L_plus_N1), ids_init.shape | |
| txt_part = ids_init[:, :txt_len] | |
| vis_part = ids_init[:, -N:] | |
| assert (txt_part < lm_vocab).all(), "text tokens in visual range" | |
| assert (vis_part >= lm_vocab).all(), "visual tokens not shifted" | |
| assert (vis_part < lm_vocab + K).all(), "visual tokens exceed lm_vocab+K" | |
| assert rpos_init.shape == (B, L_plus_N1, 3), rpos_init.shape | |
| assert z_s.shape == (B, N, K), z_s.shape | |
| p_err = float((p_s.sum(-1) - 1).abs().max().item()) | |
| assert p_err < 1e-3, f"p_s not normalised: max_dev={p_err:.2e}" | |
| assert not torch.isnan(logp).any(), "logp has NaN" | |
| assert not torch.isinf(logp).any(), "logp has Inf" | |
| assert x_t.min() >= 0 and x_t.max() < K | |
| assert z_T_cond.shape == z_S_cond.shape == (B, N, K), ( | |
| f"z_T_cond={z_T_cond.shape} z_S_cond={z_S_cond.shape}" | |
| ) | |
| # Teacher has no grad. | |
| teacher_grads = [ | |
| p for p in self.pipe.teacher.parameters() if p.grad is not None | |
| ] | |
| assert len(teacher_grads) == 0, "teacher has grads — not frozen" | |
| # Student has grad (from PG backward). | |
| raw_s = self.accelerator.unwrap_model(self.model).student | |
| student_grad_norms = [ | |
| float(p.grad.norm().item()) | |
| for p in raw_s.parameters() | |
| if p.grad is not None | |
| ] | |
| assert len(student_grad_norms) > 0, "student has NO grads — grad flow broken" | |
| self.logger.info("[assert] Step-1 shape/grad assertions PASSED ✓") | |
| self.logger.info( | |
| f"[assert] z_T_cond shape={z_T_cond.shape} " | |
| f"min={float(z_T_cond.min().item()):.3f} " | |
| f"max={float(z_T_cond.max().item()):.3f}" | |
| ) | |
| self.logger.info( | |
| f"[assert] z_S_cond shape={z_S_cond.shape} " | |
| f"min={float(z_S_cond.min().item()):.3f} " | |
| f"max={float(z_S_cond.max().item()):.3f}" | |
| ) | |
    def _token_entropy(self, x_hat: torch.Tensor) -> float:
| """Histogram entropy of sampled token indices (collapse detection).""" | |
| counts = x_hat.flatten().bincount(minlength=1).float() | |
| p = counts / counts.sum() | |
| p = p[p > 0] | |
| return float(-(p * p.log()).sum().item()) | |
| def main(): | |
| """Entry point — identical pattern to scripts/train.py.""" | |
| config = omegaconf_utils.get_config() | |
| os.makedirs(config.experiment.output_dir, exist_ok=True) | |
| accelerator = accelerate_utils.build_accelerator(config) | |
| accelerate_utils.build_wandb(config, accelerator=accelerator) | |
| logger = accelerate_utils.set_logger( | |
| config.experiment.output_dir, accelerator=accelerator | |
| ) | |
| device_seed = int(config.training.seed) + accelerator.process_index | |
| engine_utils.manual_seed(device_seed, (accelerator.device.index, device_seed)) | |
| if accelerator.is_main_process: | |
| config_path = os.path.join(config.experiment.output_dir, "config.yaml") | |
| omegaconf_utils.save_config(config, config_path) | |
| logger.info(f"Config:\n{omegaconf_utils.config_to_yaml(config)}") | |
| trainer = DistillTrainer(config, accelerator, logger) | |
| trainer.train_loop() | |
| if __name__ == "__main__": | |
| main() | |