#!/usr/bin/env python3
# ------------------------------------------------------------------------
# Copyright (c) 2024-present, BAAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------
"""URSA one-step distillation trainer (DiMO-style), 8-GPU distributed.

Verified native inference regime (from A/B testing — ground truth):
    height=320, width=512, num_frames=49, guidance_scale=7, teacher_steps=50.
no_cfg (guidance_scale=1) is NOT a valid baseline for this URSA checkpoint.
Defaults in configs/distill_dimo.yaml are aligned to this regime.

Launch command:
    accelerate launch --config_file accelerate_configs/deepspeed_zero2.yaml \\
        --machine_rank 0 --num_machines 1 --num_processes 8 \\
        scripts/train_distill_dimo.py \\
        config="./configs/distill_dimo.yaml" \\
        experiment.output_dir="./experiments/distill_dimo" \\
        distill.teacher_ckpt="/path/to/URSA-1.7B-IBQ1024" \\
        distill.prompt_source="/data/Koala_36M_*.csv" \\
        distill.batch_size_per_gpu=1

Smoke test (single-GPU, 50 steps):
    accelerate launch --num_processes 1 \\
        scripts/train_distill_dimo.py \\
        config="./configs/distill_dimo.yaml" \\
        experiment.output_dir="./experiments/smoke" \\
        distill.teacher_ckpt="/path/to/URSA-1.7B-IBQ1024" \\
        distill.prompt_source="prompts.txt" \\
        training.max_train_steps=50

Algorithm summary (9 stages per iteration)
------------------------------------------
Stage 1  Tokenize → txt_ids [B, L] (CPU in worker, moved to GPU in run_step)
Stage 2  x_init ~ Uniform(K) (+ p_init mixing from x_hat_prev)
Stage 3  no_grad student(x_init) → x_hat [B, N], logp for PG
Stage 4  x_t = scheduler.add_noise(x_hat_4d, t) [B,T,H,W], long
Stage 5  no_grad teacher(x_t) → z_T_cond [B,N,K] (+ uncond if CFG)
Stage 6  aux update × fake_rounds: Jeffrey(z_T_target, z_A_cond).backward()
Stage 7  student KD forward on x_t → z_S_cond [B,N,K]
Stage 8  reward = -KL(z_T_cond, z_S_cond) [detached]; adv = reward - baseline_ema
Stage 9  Two-backward:
    9a _no_sync_backward(lambda_kd * loss_kd)  [frees KD graph]
    9b accelerator.backward(lambda_pg * loss_pg - lambda_ent * H_mean)
    opt_student.step()
"""
import collections
import copy
import os
import sys
from typing import Optional

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

# Make the repository root importable when run as a script.
_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if _ROOT not in sys.path:
    sys.path.insert(0, _ROOT)

from diffnext.engine import engine_utils
from diffnext.engine.lr_scheduler import CosineLR
from diffnext.pipelines.ursa.pipeline_ursa_distill_dimo import (
    URSADistillDiMOPipeline,
    _get_logits,
    _stable_kl,
    _stable_jeffrey,
    _build_guided_logits,
    _cfg_warmup_prob,
    _no_sync_backward,
    _reset_flex_attn,
    VERIFIED_NATIVE_DEFAULTS,
    check_verified_regime,
)
from diffnext.utils import accelerate_utils
from diffnext.utils import omegaconf_utils
from diffnext.utils import profiler
from src.distill.prompt_dataset import (
    CSVSpec,
    InfiniteDataLoader,
    PromptDataset,
    make_collate_fn,
)


# ---------------------------------------------------------------------------
# DistillTwinModel — single nn.Module wrapping student + aux for DeepSpeed
# ---------------------------------------------------------------------------
class DistillTwinModel(torch.nn.Module):
    """Wrapper that holds both student and aux as sub-modules.

    DeepSpeed (via Accelerate) only allows a single model in
    ``accelerator.prepare()``. This container satisfies that constraint
    while keeping student and aux as separately addressable sub-modules
    with independent param groups.
    """

    def __init__(self, student: torch.nn.Module, aux: torch.nn.Module):
        super().__init__()
        self.student = student
        self.aux = aux

    def forward(self, which: str, input_ids, rope_pos=None, **kwargs):
        # Dispatch by name so one prepared module can serve both sub-models.
        if which == "student":
            return self.student(input_ids, rope_pos=rope_pos, **kwargs)
        elif which == "aux":
            return self.aux(input_ids, rope_pos=rope_pos, **kwargs)
        else:
            raise ValueError(f"DistillTwinModel: unknown sub-model '{which}'")


# ---------------------------------------------------------------------------
# DistillTrainer
# ---------------------------------------------------------------------------
class DistillTrainer:
    """Training orchestrator for on-policy one-step distillation.

    Reuses the same accelerate / logger / checkpoint API as
    ``diffnext.engine.train_engine.Trainer`` so the distributed setup
    is identical to the original training framework.
    Key differences from standard Trainer:
      - Three models (teacher frozen, student + aux trainable)
      - Student and aux are wrapped in a single ``DistillTwinModel`` so that
        only one ``accelerator.prepare()`` call is needed (DeepSpeed requirement)
      - One optimizer with two param_groups: [0]=student, [1]=aux
      - LR schedulers for both param groups
      - Two-backward strategy within each step
      - PromptDataset (no video latents; prompt-only)
      - Stage 6 freezes student / unfreezes aux; Stages 7-9 do the reverse
    """

    def __init__(self, config, accelerator, logger):
        self.config = config
        self.accelerator = accelerator
        self.logger = logger
        cfg = config.distill
        dtype = accelerate_utils.precision_to_dtype(config.training.mixed_precision)
        self.device = accelerator.device

        # -------- Pipeline (teacher + student + aux) ----------------------
        logger.info(f"[init] Loading teacher from {cfg.teacher_ckpt} ...")
        self.pipe = URSADistillDiMOPipeline(
            teacher_ckpt=cfg.teacher_ckpt,
            compute_dtype=dtype,
            aux_noise_std=float(cfg.get("aux_noise_std", 0.0)),
        )
        # Move teacher to GPU (not prepared by accelerate — frozen).
        self.pipe.teacher = self.pipe.teacher.to(self.device)
        self.pipe.scheduler.to(device=self.device)

        # Compute latents shape from video geometry.
        from src.distill.utils_ursa_inputs import compute_latents_shape

        # Read VAE strides from pipeline (falls back to URSA defaults 4/8).
        vae_t = int(getattr(self.pipe, "vae_temporal_stride", 4))
        vae_s = int(getattr(self.pipe, "vae_spatial_stride", 8))
        self.latents_shape = compute_latents_shape(
            cfg.num_frames, cfg.height, cfg.width, vae_t, vae_s
        )
        T, H, W = self.latents_shape
        self.N = T * H * W
        self.K = self.pipe.codebook_size
        logger.info(
            f"[init] latents_shape=({T},{H},{W}) N={self.N} K={self.K} "
            f"CFG={'ON' if cfg.enable_teacher_cfg else 'OFF'}"
        )

        # Pre-compute uncond token IDs (empty string, [1, L]) on CPU.
        self.txt_uncond_base_cpu = self.pipe.tokenizer(
            [""],
            max_length=int(cfg.max_prompt_length),
            padding="max_length",
            padding_side="left",
            truncation=True,
            return_tensors="pt",
        ).input_ids  # [1, L] CPU

        # -------- Optimizers (before accelerate.prepare) ------------------
        # Single optimizer with two param groups:
        #   group[0] = student params, group[1] = aux params
        opt_cls = torch.optim.AdamW
        opt_s_params = dict(
            lr=float(config.optimizer_student.params.lr),
            betas=tuple(config.optimizer_student.params.get("betas", [0.9, 0.95])),
            weight_decay=float(config.optimizer_student.params.get("weight_decay", 0.01)),
        )
        opt_a_params = dict(
            lr=float(config.optimizer_aux.params.lr),
            betas=tuple(config.optimizer_aux.params.get("betas", [0.9, 0.95])),
            weight_decay=float(config.optimizer_aux.params.get("weight_decay", 0.01)),
        )

        def _enable_gcpt(m):
            # m.model.layers is the Qwen3Model layer list; turn on gradient
            # checkpointing on each layer and its attention / MLP sub-modules.
            for layer in m.model.layers:
                layer.gradient_checkpointing = True
                layer.self_attn.gradient_checkpointing = True
                layer.mlp.gradient_checkpointing = True

        _enable_gcpt(self.pipe.student)
        _enable_gcpt(self.pipe.aux)

        # -------- Resume: load weights BEFORE ZeRO-3 param sharding --------
        self.global_step = int(config.experiment.get("resume_iter", 0))
        if self.global_step > 0:
            ckpt_dir = os.path.join(
                config.experiment.output_dir, "checkpoints", f"checkpoint-{self.global_step}"
            )
            if os.path.exists(ckpt_dir):
                logger.info(f"[Resume] 正在从 {ckpt_dir} 恢复 Student 和 Aux 的权重...")
                # Must load with map_location="cpu" to avoid GPU OOM; the
                # subsequent prepare() call places/shards parameters itself.
                self.pipe.student.load_state_dict(torch.load(os.path.join(ckpt_dir, "student.pt"), map_location="cpu"))
                self.pipe.aux.load_state_dict(torch.load(os.path.join(ckpt_dir, "aux.pt"), map_location="cpu"))
            else:
                logger.warning(f"[Resume] 找不到检查点 {ckpt_dir},将从随机初始状态起步!")

        # -------- Wrap student + aux into a single DistillTwinModel --------
        twin_model = DistillTwinModel(self.pipe.student, self.pipe.aux)
        opt_raw = opt_cls([
            {"params": list(self.pipe.student.parameters()), **opt_s_params},
            {"params": list(self.pipe.aux.parameters()), **opt_a_params},
        ])

        # -------- accelerate.prepare: single model + single optimizer ------
        # Teacher is NOT prepared (frozen; no grad sync needed).
        self.model, self.optimizer = accelerator.prepare(twin_model, opt_raw)

        # LR schedulers (step() called manually at end of each step).
        self.scheduler_s = CosineLR(
            lr_max=float(config.optimizer_student.params.lr),
            lr_min=float(config.lr_scheduler.params.get("lr_min", 1e-6)),
            max_steps=int(config.training.max_train_steps),
            warmup_steps=int(config.lr_scheduler.params.get("warmup_steps", 500)),
        )
        self.scheduler_a = CosineLR(
            lr_max=float(config.optimizer_aux.params.lr),
            lr_min=float(config.lr_scheduler.params.get("lr_min", 1e-6)),
            max_steps=int(config.training.max_train_steps),
            warmup_steps=int(config.lr_scheduler.params.get("warmup_steps", 500)),
        )

        # -------- Dataset / DataLoader ------------------------------------
        dataloader_cfg = config.get("prompt_dataloader", {})
        dataset = PromptDataset(
            prompt_source=str(cfg.prompt_source),
            shuffle_files=bool(dataloader_cfg.get("shuffle_files", True)),
            shuffle_buffer=int(dataloader_cfg.get("shuffle_buffer", 0)),
            seed=int(config.training.seed),
            infinite=True,
            csv=CSVSpec(caption_field=str(dataloader_cfg.get("caption_field", "caption"))),
        )
        # collate_fn: tokenize on CPU (no CUDA in workers).
        collate_fn = make_collate_fn(
            self.pipe.tokenizer,
            max_prompt_length=int(cfg.max_prompt_length),
            device=torch.device("cpu"),  # CPU output — moved to GPU in run_step
        )
        loader = DataLoader(
            dataset,
            batch_size=int(cfg.batch_size_per_gpu),
            shuffle=False,  # IterableDataset: no shuffle flag
            drop_last=True,
            num_workers=int(dataloader_cfg.get("num_workers", 2)),
            collate_fn=collate_fn,
            pin_memory=True,
        )
        # DataLoader is NOT prepared by accelerate because PromptDataset
        # handles per-rank file sharding internally via torch.distributed.
        self._inf_loader = InfiniteDataLoader(loader)

        # -------- Training state ------------------------------------------
        # self.global_step = int(config.experiment.get("resume_iter", 0))
        self.baseline_ema: float = 0.0
        self.x_hat_prev: Optional[torch.Tensor] = None
        self.metrics = collections.OrderedDict()

        # -------- Verified regime validation --------------------------------
        native = VERIFIED_NATIVE_DEFAULTS
        is_native = check_verified_regime(
            height=int(cfg.height),
            width=int(cfg.width),
            num_frames=int(cfg.num_frames),
            guidance_scale=float(cfg.teacher_cfg_scale) if cfg.enable_teacher_cfg else None,
            label="train",
        )
        logger.info(
            f"[init] verified_native_regime={is_native} "
            f"geometry=({cfg.num_frames}×{cfg.height}×{cfg.width}) "
            f"teacher_cfg_scale={cfg.teacher_cfg_scale if cfg.enable_teacher_cfg else 'OFF'}"
        )
        if not cfg.enable_teacher_cfg:
            logger.warning(
                "[WARN] Teacher CFG is DISABLED. no_cfg is known to produce "
                "blank/blurry output for this URSA checkpoint. "
                "Distillation without CFG is unlikely to produce useful results."
            )
        elif float(cfg.teacher_cfg_scale) != native["guidance_scale"]:
            logger.warning(
                f"[WARN] teacher_cfg_scale={cfg.teacher_cfg_scale} differs from "
                f"the verified working value ({native['guidance_scale']}). "
                "Outputs may deviate from the official inference working point."
            )
        logger.info(
            f"[init] student params: {engine_utils.count_params(self.pipe.student):.2f}M"
        )
        logger.info(
            f"[init] max_train_steps={config.training.max_train_steps} "
            f"batch_size_per_gpu={cfg.batch_size_per_gpu} "
            f"num_processes={accelerator.num_processes}"
        )

    # -----------------------------------------------------------------------
    # run_step: Stages 1-9
    # -----------------------------------------------------------------------
    def run_step(self, step: int) -> dict:
        """Execute one distillation step (Stages 1-9)."""
        cfg = self.config.distill
        T, H, W = self.latents_shape
        N, K = self.N, self.K
        device = self.device
        stats = {"step": step}
        timer = profiler.Timer().tic()

        # Update LR from cosine schedulers.
        # param_groups[0] = student, param_groups[1] = aux
        lr_s = self.scheduler_s.get_lr()
        lr_a = self.scheduler_a.get_lr()
        stats["lr_student"] = lr_s
        stats["lr_aux"] = lr_a
        self.optimizer.param_groups[0]["lr"] = lr_s
        self.optimizer.param_groups[1]["lr"] = lr_a

        # ----------------------------------------------------------------
        # Stage 1: Get tokenised batch (CPU → GPU)
        # ----------------------------------------------------------------
        txt_ids = next(self._inf_loader)  # [B, L] CPU tensor
        txt_ids = txt_ids.to(device, non_blocking=True)
        B = txt_ids.size(0)
        txt_uncond = None
        if cfg.enable_teacher_cfg:
            txt_uncond = self.txt_uncond_base_cpu.expand(B, -1).to(device)

        # # ----------------------------------------------------------------
        # # Stage 2: Sample x_init ~ Uniform(K) with optional p_init mixing
        # # ----------------------------------------------------------------
        # x_init = torch.randint(0, K, (B, T, H, W), device=device, dtype=torch.long)
        # if self.x_hat_prev is not None and float(cfg.p_init_mix_ratio) > 0:
        #     n_mix = max(1, int(B * float(cfg.p_init_mix_ratio)))
        #     x_init[:n_mix] = self.pipe.corrupt_tokens(
        #         self.x_hat_prev[:n_mix], r=float(cfg.p_mix_corrupt_frac)
        #     )
        # ----------------------------------------------------------------
        # Stage 2: Sample x_init ~ Uniform(K) with
optional p_init mixing # ---------------------------------------------------------------- x_init = torch.randint(0, K, (B, T, H, W), device=device, dtype=torch.long) # 修复:使用概率触发,确保小 Batch 时模型依然能充分学习处理纯噪声 if self.x_hat_prev is not None and float(cfg.p_init_mix_ratio) > 0: if torch.rand(1).item() < float(cfg.p_init_mix_ratio): # 如果触发,只混合 batch 里的第一个样本 x_init[0] = self.pipe.corrupt_tokens( self.x_hat_prev[0:1], r=float(cfg.p_mix_corrupt_frac) ).squeeze(0) # ---------------------------------------------------------------- # Stage 3: Student 1-step on x_init — no_grad (only sample x_hat) # # Gradient-enabled forward on x_init is deferred to Stage 9b so # the KD computation graph (Stage 7, x_t) can be freed first. # ---------------------------------------------------------------- with torch.no_grad(): ids_init, rpos_init, _ = self.pipe.build_inputs( txt_ids, x_init, self.latents_shape ) logits_s_init = _get_logits( self.model("student", ids_init, rope_pos=rpos_init) ) z_s = self.pipe.extract_logits(logits_s_init, N) # [B, N, K] p_s = F.softmax(z_s / float(cfg.tau), dim=-1) # [B, N, K] x_hat = torch.multinomial(p_s.view(-1, K), 1).view(B, N) # [B, N] # if step == 1: # # 只抽 8 个 token 做 sum=1 检查,别全量 # idx = torch.randint(0, N, (8,), device=device) # p_err = (p_s[:, idx].sum(-1) - 1).abs().max().item() # assert p_err < 1e-3, f"p_s subset not normalised: {p_err}" del p_s, z_s, logits_s_init x_hat_4d = x_hat.view(B, T, H, W) # ---------------------------------------------------------------- # Stage 4: Pseudo-intermediate x_t = add_noise(x_hat, t) # ---------------------------------------------------------------- t = self.pipe.sample_t_curriculum( B, device, step, int(cfg.t_curriculum_steps) ) # [B] float ∈ (0.05, 0.995) with torch.no_grad(): x_t = self.pipe.scheduler.add_noise(x_hat_4d, t) # [B,T,H,W] long # # ---------------------------------------------------------------- # # Stage 5: Teacher forward — single [2B] forward when CFG enabled # # 
---------------------------------------------------------------- # with torch.no_grad(): # if cfg.enable_teacher_cfg: # txt_dual = torch.cat([txt_ids, txt_uncond], dim=0) # [2B, L] # x_t_dual = torch.cat([x_t, x_t], dim=0) # [2B,T,H,W] # ids_dual, rpos_dual, _ = self.pipe.build_inputs( # txt_dual, x_t_dual, self.latents_shape # ) # logits_T_dual = _get_logits( # self.pipe.teacher(ids_dual, rope_pos=rpos_dual) # ) # z_T_dual = self.pipe.extract_logits(logits_T_dual, N) # [2B,N,K] # z_T_cond, z_T_uncond = z_T_dual.chunk(2, dim=0) # [B,N,K] # del logits_T_dual, z_T_dual # torch.cuda.empty_cache() # ids_t, rpos_t = ids_dual[:B], rpos_dual[:B] # else: # ids_t, rpos_t, _ = self.pipe.build_inputs( # txt_ids, x_t, self.latents_shape # ) # logits_T = _get_logits( # self.pipe.teacher(ids_t, rope_pos=rpos_t) # ) # z_T_cond = self.pipe.extract_logits(logits_T, N) # [B,N,K] # z_T_uncond = None # ids_dual, rpos_dual = ids_t, rpos_t # # CFG guided target with per-sample Bernoulli warmup. # z_T_guided = None # use_guided_ratio = 0.0 # if cfg.enable_teacher_cfg: # p_guided = _cfg_warmup_prob( # step, # float(cfg.teacher_cfg_prob), # int(cfg.teacher_cfg_warmup_steps), # ) # use_guided = torch.rand(B, device=device) < p_guided # [B] bool # use_guided_ratio = float(use_guided.float().mean().item()) # z_T_guided = _build_guided_logits( # z_T_cond, z_T_uncond, # t, float(cfg.teacher_cfg_scale), float(cfg.teacher_cfg_trunc), # ) # mask = use_guided.view(-1, 1, 1).expand_as(z_T_cond) # z_T_target = torch.where(mask, z_T_guided, z_T_cond.float()) # else: # z_T_target = z_T_cond # z_T_target = z_T_target.detach() # NO grad path to teacher # # # ---------------------------------------------------------------- # # # Stage 6: Aux update — fake_rounds iterations # # # # # # Freeze student so only aux gets gradients. With a single # # # DeepSpeed-wrapped optimizer this is the cleanest way to ensure # # # only aux params are updated. 
# # # ---------------------------------------------------------------- # # raw_twin = self.accelerator.unwrap_model(self.model) # # raw_twin.student.requires_grad_(False) # # raw_twin.aux.requires_grad_(True) # # loss_aux_cond_last = torch.tensor(0.0, device=device) # # loss_aux_uncond_last = torch.tensor(0.0, device=device) # # loss_aux_cond_sample_last = None # # for _fr in range(int(cfg.fake_rounds)): # # self.optimizer.zero_grad(set_to_none=True) # # if cfg.enable_teacher_cfg: # # logits_A_dual = _get_logits( # # self.model("aux", ids_dual.detach(), rope_pos=rpos_dual.detach()) # # ) # # z_A_dual = self.pipe.extract_logits(logits_A_dual, N) # [2B,N,K] # # z_A_cond, z_A_uncond = z_A_dual.chunk(2, dim=0) # # loss_aux_cond_sample = _stable_jeffrey( # # z_T_target, z_A_cond, float(cfg.tau_kd),chunk_size=1024 # # ) # [B] # # loss_aux_cond_v = loss_aux_cond_sample.mean() # # loss_aux_uncond_v = _stable_jeffrey( # # z_T_uncond.float().detach(), z_A_uncond, float(cfg.tau_kd),chunk_size=1024 # # ).mean() # # loss_aux_v = ( # # loss_aux_cond_v # # + float(cfg.lambda_kd_uncond) * loss_aux_uncond_v # # ) # # else: # # logits_A = _get_logits( # # self.model("aux", ids_t.detach(), rope_pos=rpos_t.detach()) # # ) # # z_A_cond = self.pipe.extract_logits(logits_A, N) # # loss_aux_cond_sample = _stable_jeffrey( # # z_T_target, z_A_cond, float(cfg.tau_kd),chunk_size=1024 # # ) # # loss_aux_cond_v = loss_aux_cond_sample.mean() # # loss_aux_uncond_v = torch.tensor(0.0, device=device) # # loss_aux_v = loss_aux_cond_v # # self.accelerator.backward(loss_aux_v) # # if float(cfg.grad_clip) > 0: # # torch.nn.utils.clip_grad_norm_( # # raw_twin.aux.parameters(), float(cfg.grad_clip) # # ) # # self.optimizer.step() # # self.optimizer.zero_grad(set_to_none=True) # # loss_aux_cond_last = loss_aux_cond_v.detach() # # loss_aux_uncond_last = loss_aux_uncond_v.detach() # # loss_aux_cond_sample_last = loss_aux_cond_sample.detach() # [B] # # # 
---------------------------------------------------------------- # # # Stage 7: Student KD forward on x_t (with grad) # # # # # # Switch: freeze aux, unfreeze student for Stages 7-9. # # # ---------------------------------------------------------------- # # raw_twin.student.requires_grad_(True) # # raw_twin.aux.requires_grad_(False) # # self.optimizer.zero_grad(set_to_none=True) # # if cfg.enable_teacher_cfg: # # logits_S_dual = _get_logits( # # self.model("student", ids_dual.detach(), rope_pos=rpos_dual.detach()) # # ) # # z_S_dual = self.pipe.extract_logits(logits_S_dual, N) # # z_S_cond, z_S_uncond = z_S_dual.chunk(2, dim=0) # # loss_kd_cond = _stable_kl( # # z_T_target, z_S_cond, float(cfg.tau_kd), chunk_size=2048 # # ).mean() # # loss_kd_uncond = _stable_kl( # # z_T_uncond.float().detach(), z_S_uncond, float(cfg.tau_kd), chunk_size=2048 # # ).mean() # # loss_kd = loss_kd_cond + float(cfg.lambda_kd_uncond) * loss_kd_uncond # # else: # # logits_S = _get_logits( # # self.model("student", ids_t.detach(), rope_pos=rpos_t.detach()) # # ) # # z_S_cond = self.pipe.extract_logits(logits_S, N) # # loss_kd_cond = _stable_kl( # # z_T_target, z_S_cond, float(cfg.tau_kd), chunk_size=2048 # # ).mean() # # loss_kd_uncond = torch.tensor(0.0, device=device) # # loss_kd = loss_kd_cond # # # ---------------------------------------------------------------- # # # Stage 8: Reward + advantage (fully detached — no student grad) # # # # # # INVARIANT: reward and adv must never carry student gradients. 
# # # ---------------------------------------------------------------- # # if cfg.enable_teacher_cfg and cfg.reward_use_guided: # # z_T_for_rew = z_T_target # already detached # # else: # # z_T_for_rew = z_T_cond.detach() # # # reward[b] = -KL(z_T_cond || z_S_cond) with BOTH inputs detached # # with torch.no_grad(): # # reward = -_stable_kl( # # z_T_for_rew.detach(), z_S_cond.detach(), float(cfg.tau), chunk_size=1024 # # ) # [B] # # assert not reward.requires_grad, ( # # "[BUG] reward.requires_grad=True — student grad leaked into reward. " # # "z_S_cond must be detached before KL for reward." # # ) # # self.baseline_ema = ( # # 0.99 * self.baseline_ema + 0.01 * float(reward.mean().item()) # # ) # # adv = (reward - self.baseline_ema).detach() # [B] # # assert not adv.requires_grad, "[BUG] adv.requires_grad=True" # # # ---------------------------------------------------------------- # # # Stage 9: Two-backward student update # # # # # # 9a) KD backward first — frees the KD graph to save memory. # # # Uses no_sync() (no DDP all-reduce) so gradients are not # # # double-reduced when the PG backward syncs in 9b. # # # 9b) Fresh forward on x_init WITH grad → PG + entropy backward. # # # DDP all-reduce happens here (normal backward). 
# # # ---------------------------------------------------------------- # # # 9a: KD backward (no sync — first of two backwards) # # _no_sync_backward( # # self.accelerator, self.model, float(cfg.lambda_kd) * loss_kd # # ) # # # 9b: Policy + entropy — fresh forward on x_init WITH grad # # ids_init, rpos_init, _ = self.pipe.build_inputs( # # txt_ids, x_init, self.latents_shape # # ) # # logits_s_pol = _get_logits( # # self.model("student", ids_init, rope_pos=rpos_init) # # ) # # z_s_pol = self.pipe.extract_logits(logits_s_pol, N) # [B, N, K] # # logp_tok = F.log_softmax(z_s_pol / float(cfg.tau), dim=-1) # [B, N, K] # # p_s_pol = logp_tok.exp() # # # per-token average log-prob (recommended over log-prob sum) # # logp_sum = ( # # logp_tok.gather(-1, x_hat.unsqueeze(-1)).squeeze(-1).sum(-1) # # ) # [B] # # logp = logp_sum / N # [B] per-token logp # # H_mean = -(p_s_pol * logp_tok).sum(-1).mean() # # loss_pg = -(adv * logp).mean() # # lambda_ent_eff = float(cfg.lambda_ent) * (1.0 + 2.0 * use_guided_ratio) # # # Second backward: DDP all-reduce happens here. # # self.accelerator.backward( # # float(cfg.lambda_pg) * loss_pg - lambda_ent_eff * H_mean # # ) # # if float(cfg.grad_clip) > 0: # # torch.nn.utils.clip_grad_norm_( # # raw_twin.student.parameters(), float(cfg.grad_clip) # # ) # # self.optimizer.step() # # # Restore both sub-modules to trainable for next step. # # raw_twin.student.requires_grad_(True) # # raw_twin.aux.requires_grad_(True) # # # p_init mixing: store x_hat_4d (detached) for next step. 
# # self.x_hat_prev = x_hat_4d.detach() # # ---------------------------------------------------------------- # # Stage 6: Aux update — Fit sampled pseudo-target (x_hat) from student # # ---------------------------------------------------------------- # raw_twin = self.accelerator.unwrap_model(self.model) # raw_twin.student.requires_grad_(False) # raw_twin.aux.requires_grad_(True) # target_tokens = x_hat.detach() # [B, N] - 学生在 Stage 3 盲猜出来的画面 # for _fr in range(int(cfg.fake_rounds)): # self.optimizer.zero_grad(set_to_none=True) # if cfg.enable_teacher_cfg: # logits_A_dual = _get_logits( # self.model("aux", ids_dual.detach(), rope_pos=rpos_dual.detach()) # ) # z_A_dual = self.pipe.extract_logits(logits_A_dual, N) # [2B,N,K] # z_A_cond, z_A_uncond = z_A_dual.chunk(2, dim=0) # # Aux 拟合学生的假 token (Cross Entropy) # loss_aux_cond_v = F.cross_entropy( # z_A_cond.reshape(B * N, K), # target_tokens.reshape(B * N), # reduction="mean", # ) # loss_aux_uncond_v = F.cross_entropy( # z_A_uncond.reshape(B * N, K), # target_tokens.reshape(B * N), # reduction="mean", # ) # loss_aux_v = loss_aux_cond_v + float(cfg.lambda_kd_uncond) * loss_aux_uncond_v # else: # logits_A = _get_logits( # self.model("aux", ids_t.detach(), rope_pos=rpos_t.detach()) # ) # z_A_cond = self.pipe.extract_logits(logits_A, N) # loss_aux_cond_v = F.cross_entropy( # z_A_cond.reshape(B * N, K), # target_tokens.reshape(B * N), # reduction="mean", # ) # loss_aux_uncond_v = torch.tensor(0.0, device=device) # loss_aux_v = loss_aux_cond_v # self.accelerator.backward(loss_aux_v) # if float(cfg.grad_clip) > 0: # torch.nn.utils.clip_grad_norm_( # raw_twin.aux.parameters(), float(cfg.grad_clip) # ) # self.optimizer.step() # loss_aux_cond_last = loss_aux_cond_v.detach() # # ---------------------------------------------------------------- # # Stage 7 & 8: Student KD update & Aux Bridge (Gradient Injection) # # ---------------------------------------------------------------- # raw_twin.student.requires_grad_(True) # 
raw_twin.aux.requires_grad_(False) # self.optimizer.zero_grad(set_to_none=True) # # 7a. Student KD forward on x_t (保持原样) # if cfg.enable_teacher_cfg: # logits_S_dual = _get_logits( # self.model("student", ids_dual.detach(), rope_pos=rpos_dual.detach()) # ) # z_S_dual = self.pipe.extract_logits(logits_S_dual, N) # z_S_cond, z_S_uncond = z_S_dual.chunk(2, dim=0) # # --- [新增] 立刻释放显存 --- # del logits_S_dual, z_S_dual # loss_kd_cond = _stable_kl( # z_T_target, z_S_cond, float(cfg.tau_kd), chunk_size=256 #2048 # ).mean() # loss_kd_uncond = _stable_kl( # z_T_uncond.float().detach(), z_S_uncond, float(cfg.tau_kd), chunk_size=256 #2048 # ).mean() # loss_kd = loss_kd_cond + float(cfg.lambda_kd_uncond) * loss_kd_uncond # else: # logits_S = _get_logits( # self.model("student", ids_t.detach(), rope_pos=rpos_t.detach()) # ) # z_S_cond = self.pipe.extract_logits(logits_S, N) # loss_kd_cond = _stable_kl( # z_T_target, z_S_cond, float(cfg.tau_kd), chunk_size=256 #2048 # ).mean() # loss_kd_uncond = torch.tensor(0.0, device=device) # loss_kd = loss_kd_cond # # 7b. 获取 Aux 的预测 (无梯度) 作为计算桥梁 # with torch.no_grad(): # if cfg.enable_teacher_cfg: # logits_A_dual = _get_logits( # self.model("aux", ids_dual.detach(), rope_pos=rpos_dual.detach()) # ) # z_A_dual = self.pipe.extract_logits(logits_A_dual, N) # z_A_cond, _ = z_A_dual.chunk(2, dim=0) # # --- [新增] 立刻释放显存 --- # del logits_A_dual, z_A_dual # else: # logits_A = _get_logits( # self.model("aux", ids_t.detach(), rpos_t.detach()) # ) # z_A_cond = self.pipe.extract_logits(logits_A, N) # # 8. 
Student 对初始噪声 x_init 进行带梯度的前向传播 # ids_init, rpos_init, _ = self.pipe.build_inputs( # txt_ids, x_init, self.latents_shape # ) # logits_s_pol = _get_logits( # self.model("student", ids_init, rope_pos=rpos_init) # ) # z_s_pol = self.pipe.extract_logits(logits_s_pol, N) # # --- 核心数学修正:将 Logits 转换为概率,防止梯度爆炸 --- # p_T = F.softmax(z_T_target / float(cfg.tau_kd), dim=-1) # p_A = F.softmax(z_A_cond / float(cfg.tau_kd), dim=-1) # # 目标方向:Teacher 概率 - Aux 概率 (遵循论文公式推导) # bridge_target = (p_T - p_A).detach() # # 利用 MSE Trick 强制注入梯度 # loss_bridge = 0.5 * F.mse_loss( # z_s_pol.float(), # (z_s_pol.float() + bridge_target).detach() # ) # # 9. 单次反向传播 (合并 KD 和 Bridge) # # 借用原来的 lambda_pg 参数来控制 bridge 损失的权重 # loss_student = float(cfg.lambda_kd) * loss_kd + float(cfg.lambda_pg) * loss_bridge # self.accelerator.backward(loss_student) # if float(cfg.grad_clip) > 0: # torch.nn.utils.clip_grad_norm_( # raw_twin.student.parameters(), float(cfg.grad_clip) # ) # self.optimizer.step() # # 恢复两者的可训练状态 # raw_twin.student.requires_grad_(True) # raw_twin.aux.requires_grad_(True) # # --- 兼容原始日志输出的占位符 --- # H_mean = torch.tensor(0.0, device=device) # loss_pg = loss_bridge.detach() # 将 bridge 损失映射给 pg 显示 # logp = torch.tensor(0.0, device=device) # self.baseline_ema = 0.0 # ---------------------------------------------------------------- # Stage 5: Teacher forward — 破除视图死锁,生成目标后立刻释放 # ---------------------------------------------------------------- with torch.no_grad(): if cfg.enable_teacher_cfg: txt_dual = torch.cat([txt_ids, txt_uncond], dim=0) # [2B, L] x_t_dual = torch.cat([x_t, x_t], dim=0) # [2B,T,H,W] ids_dual, rpos_dual, _ = self.pipe.build_inputs( txt_dual, x_t_dual, self.latents_shape ) logits_T_dual = _get_logits( self.pipe.teacher(ids_dual, rope_pos=rpos_dual) ) z_T_dual = self.pipe.extract_logits(logits_T_dual, N) # [2B,N,K] # 【显存救星 1】使用 .clone() 打断视图依赖,使得原始巨型张量可以被回收 z_T_cond = z_T_dual[0:1].clone() # [1,N,K] z_T_uncond = z_T_dual[1:2].clone() # [1,N,K] ids_t, rpos_t = ids_dual[:B], 
rpos_dual[:B]  # immediately free the 17 GB dual-path buffers
            del logits_T_dual, z_T_dual
            torch.cuda.empty_cache()
        else:
            # Single-path teacher forward (no dual cond/uncond batch).
            ids_t, rpos_t, _ = self.pipe.build_inputs(txt_ids, x_t, self.latents_shape)
            logits_T = _get_logits(self.pipe.teacher(ids_t, rope_pos=rpos_t))
            z_T_cond = self.pipe.extract_logits(logits_T, N)
            z_T_uncond = None

        # Build the CFG-guided teacher target.
        z_T_guided = None
        use_guided_ratio = 0.0
        if cfg.enable_teacher_cfg:
            # Per-sample Bernoulli gate whose probability warms up over steps.
            p_guided = _cfg_warmup_prob(step, float(cfg.teacher_cfg_prob), int(cfg.teacher_cfg_warmup_steps))
            use_guided = torch.rand(B, device=device) < p_guided
            use_guided_ratio = float(use_guided.float().mean().item())
            # NOTE(review): in the single-path branch above z_T_uncond is None —
            # presumably that branch is only taken when teacher CFG is disabled;
            # confirm _build_guided_logits tolerates None otherwise.
            z_T_guided = _build_guided_logits(
                z_T_cond, z_T_uncond, t, float(cfg.teacher_cfg_scale), float(cfg.teacher_cfg_trunc),
            )
            mask = use_guided.view(-1, 1, 1).expand_as(z_T_cond)
            # [memory saver 2] keep the teacher dtype (bf16) to avoid ballooning
            # the [B, N, K] target to ~8.5 GB in float32.
            z_T_target = torch.where(mask, z_T_guided, z_T_cond).to(dtype=z_T_cond.dtype).detach()
            # Drop all intermediate tensors right away.
            del z_T_cond, z_T_uncond, z_T_guided
            torch.cuda.empty_cache()
        else:
            z_T_target = z_T_cond.detach()

        # ----------------------------------------------------------------
        # Stage 6: Aux update — [memory saver 3] force single-path forward
        # (effective batch = 1) instead of the dual cond/uncond pass.
        # ----------------------------------------------------------------
        raw_twin = self.accelerator.unwrap_model(self.model)
        # Freeze the student so the shared optimizer only updates aux here.
        raw_twin.student.requires_grad_(False)
        raw_twin.aux.requires_grad_(True)
        target_tokens = x_hat.detach()
        for _fr in range(int(cfg.fake_rounds)):
            self.optimizer.zero_grad(set_to_none=True)
            # Only the single-path ids_t (no dual) — halves aux-update memory.
            logits_A = _get_logits(
                self.model("aux", ids_t.detach(), rope_pos=rpos_t.detach())
            )
            z_A_cond = self.pipe.extract_logits(logits_A, N)
            # Cross-entropy of aux logits against the student's sampled tokens.
            loss_aux_cond_v = F.cross_entropy(
                z_A_cond.reshape(B * N, K),
                target_tokens.reshape(B * N),
                reduction="mean",
            )
            self.accelerator.backward(loss_aux_cond_v)
            if float(cfg.grad_clip) > 0:
                torch.nn.utils.clip_grad_norm_(raw_twin.aux.parameters(), float(cfg.grad_clip))
            self.optimizer.step()
            # Must free immediately before the next round's forward.
            del logits_A, z_A_cond
            torch.cuda.empty_cache()
        loss_aux_cond_last = loss_aux_cond_v.detach()

        # ----------------------------------------------------------------
        # Stage 7 & 8: Student KD update & Aux bridge
        # ----------------------------------------------------------------
        # Flip the trainable half: now only the student receives gradients.
        raw_twin.student.requires_grad_(True)
        raw_twin.aux.requires_grad_(False)
        self.optimizer.zero_grad(set_to_none=True)

        # 7a. Student KD (forced single-path forward, effective batch = 1).
        logits_S = _get_logits(
            self.model("student", ids_t.detach(), rope_pos=rpos_t.detach())
        )
        z_S_cond = self.pipe.extract_logits(logits_S, N)
        # Chunked KL with chunk_size=128 keeps the temporary buffers small.
        loss_kd = _stable_kl(
            z_T_target, z_S_cond, float(cfg.tau_kd), chunk_size=128
        ).mean()
        del logits_S, z_S_cond
        torch.cuda.empty_cache()

        # 7b. Aux prediction serves as the bridge (still single-path, no grad).
        with torch.no_grad():
            logits_A = _get_logits(
                self.model("aux", ids_t.detach(), rope_pos=rpos_t.detach())
            )
            z_A_cond = self.pipe.extract_logits(logits_A, N)

        # 8. Student forward pass on x_init (the policy input).
        ids_init, rpos_init, _ = self.pipe.build_inputs(txt_ids, x_init, self.latents_shape)
        logits_s_pol = _get_logits(
            self.model("student", ids_init, rope_pos=rpos_init)
        )
        z_s_pol = self.pipe.extract_logits(logits_s_pol, N)
        # [memory saver 4] compute softmax probabilities in the logits' dtype
        # (bf16) instead of float32 to avoid a memory blow-up.
        p_T = F.softmax(z_T_target / float(cfg.tau_kd), dim=-1).to(z_s_pol.dtype)
        p_A = F.softmax(z_A_cond / float(cfg.tau_kd), dim=-1).to(z_s_pol.dtype)
        bridge_target = (p_T - p_A).detach()
        # Once bridge_target exists, drop every remaining large tensor.
        del p_T, p_A, logits_A, z_A_cond, z_T_target
        torch.cuda.empty_cache()

        # Pseudo-gradient injection: MSE of z against (z + target).detach()
        # yields d(loss)/dz proportional to -bridge_target, steering the
        # student's policy logits without a separate backward path.
        loss_bridge = 0.5 * K * F.mse_loss(
            z_s_pol.float(), (z_s_pol.float() + bridge_target.float()).detach()
        )

        # 9. Single unified backward pass over KD + bridge terms.
        loss_student = float(cfg.lambda_kd) * loss_kd + float(cfg.lambda_pg) * loss_bridge
        self.accelerator.backward(loss_student)
        if float(cfg.grad_clip) > 0:
            torch.nn.utils.clip_grad_norm_(raw_twin.student.parameters(), float(cfg.grad_clip))
        self.optimizer.step()

        # Restore both halves to trainable for the next iteration.
        raw_twin.student.requires_grad_(True)
        raw_twin.aux.requires_grad_(True)

        # Final cleanup.
        del logits_s_pol, z_s_pol, bridge_target
        torch.cuda.empty_cache()

        # Placeholders so the metrics dict below has a uniform schema even
        # though this memory-saving path computes no entropy / PG terms.
        H_mean = torch.tensor(0.0, device=device)
        loss_pg = loss_bridge.detach()
        logp = torch.tensor(0.0, device=device)
        self.baseline_ema = 0.0

        # Advance LR schedulers.
        # Advance the per-step LR schedulers (student + aux).
        self.scheduler_s.step()
        self.scheduler_a.step()

        # ----------------------------------------------------------------
        # Step 1 sanity assertions (lightweight; runs only at step 1)
        # ----------------------------------------------------------------
        # if step == 1:
        #     self._step1_assertions(
        #         x_init, ids_init, rpos_init, z_s, p_s, logp,
        #         z_T_cond, z_S_cond, x_t, B, T, H, W,
        #     )

        # ----------------------------------------------------------------
        # Token-level collapse detection
        # ----------------------------------------------------------------
        # Warn when the histogram entropy of sampled tokens drops far below
        # its value at the first observed step (mode collapse symptom).
        tok_entropy = self._token_entropy(x_hat)
        if not hasattr(self, "_init_tok_entropy"):
            self._init_tok_entropy = tok_entropy
        collapse_frac = float(cfg.get("collapse_warn_frac", 0.2))
        if tok_entropy < collapse_frac * self._init_tok_entropy:
            self.logger.warning(
                f"[COLLAPSE] step={step} tok_H={tok_entropy:.3f} "
                f"init={self._init_tok_entropy:.3f} "
                f"ratio={tok_entropy / max(self._init_tok_entropy, 1e-8):.2f} "
                f"< {collapse_frac}. Try increasing lambda_ent."
            )

        stats["time"] = timer.toc()
        # Sorted OrderedDict keeps the log line ordering stable across steps.
        stats["metrics"] = collections.OrderedDict(
            sorted(
                {
                    "loss_aux_cond": float(loss_aux_cond_last.item()),
                    "loss_kd_cond": float(loss_kd.item()),
                    "loss_pg": float(loss_pg.item()),
                    "H_mean": float(H_mean.item()),
                    "tok_entropy": float(tok_entropy),
                    "mean_logp_tok": float(logp.mean().item()),
                    "baseline_ema": float(self.baseline_ema),
                    "use_guided_ratio": float(use_guided_ratio),
                }.items()
            )
        )
        return stats

    # -----------------------------------------------------------------------
    # Train loop
    # -----------------------------------------------------------------------
    def train_loop(self):
        """Main training loop (mirrors diffnext.engine.train_engine.Trainer)."""
        cfg_exp = self.config.experiment
        max_steps = int(self.config.training.max_train_steps)
        log_every = int(cfg_exp.log_every)
        save_every = int(cfg_exp.save_every)
        self.global_step = int(self.config.experiment.get("resume_iter", 0))
        # Sync LR schedulers to resume step (set _step_count directly;
        # CosineLR uses _step_count internally in get_decay()).
        self.scheduler_s._step_count = self.global_step
        self.scheduler_a._step_count = self.global_step
        # Optional: when resuming, let accelerator restore the sharded
        # (ZeRO-partitioned) optimizer and other training state automatically.
        if self.global_step > 0:
            ckpt_dir = os.path.join(self.config.experiment.output_dir, "checkpoints", f"checkpoint-{self.global_step}")
            if os.path.exists(ckpt_dir):
                self.accelerator.load_state(ckpt_dir)
                self.logger.info(f"✅ ZeRO-3 完整状态 (包含 Optimizer) 已从 {ckpt_dir} 恢复")
        timer = profiler.Timer()
        self.logger.info(
            f"[train] Starting from step {self.global_step} / {max_steps}"
        )
        while self.global_step < max_steps:
            self.global_step += 1
            with timer.tic_and_toc():
                stats = self.run_step(self.global_step)
            self._add_metrics(stats)
            if self.global_step % log_every == 0:
                self._log_metrics(stats)
            if self.global_step % (10 * log_every) == 0:
                self.logger.info(
                    profiler.get_progress(timer, self.global_step, max_steps)
                )
            if self.global_step % save_every == 0:
                self.save(self.global_step)
        # Final log + save (only when loop ran at least one step).
        # `stats` is only bound if the while loop executed — guarded by the
        # resume_iter comparison, hence the noqa below.
        if self.global_step > int(self.config.experiment.get("resume_iter", 0)):
            self._log_metrics({**stats, "step": self.global_step})  # noqa: F821
            self.accelerator.wait_for_everyone()
            self.save(self.global_step, suffix="final")
        self.accelerator.end_training()

    # -----------------------------------------------------------------------
    # Checkpoint helpers
    # -----------------------------------------------------------------------
    # def save(self, step: int, suffix: str = None) -> None:
    #     """Save student + aux state_dicts (rank0 only).
    #     Saved as:
    #         /checkpoints/checkpoint-/student.pt
    #         /checkpoints/checkpoint-/aux.pt
    #     The student.pt can be used for inference by replacing the
    #     transformer weights in a URSAPipeline (see README).
# """ # if not self.accelerator.is_main_process: # return # folder = f"checkpoint-{suffix}" if suffix else f"checkpoint-{step}" # ckpt_dir = os.path.join( # self.config.experiment.output_dir, "checkpoints", folder # ) # os.makedirs(ckpt_dir, exist_ok=True) # raw_student = self.accelerator.unwrap_model(self.model).student # raw_aux = self.accelerator.unwrap_model(self.model).aux # student_path = os.path.join(ckpt_dir, "student.pt") # aux_path = os.path.join(ckpt_dir, "aux.pt") # torch.save(raw_student.state_dict(), student_path) # torch.save(raw_aux.state_dict(), aux_path) # # Also save training state for resuming. # state = { # "global_step": step, # "baseline_ema": self.baseline_ema, # "optimizer": self.optimizer.state_dict(), # } # torch.save(state, os.path.join(ckpt_dir, "train_state.pt")) # self.logger.info(f"[save] step={step} → {ckpt_dir}") def save(self, step: int, suffix: str = None) -> None: """Save student + aux state_dicts (支持 DeepSpeed ZeRO-3 自动聚合).""" # ⚠️ 【极其重要】:get_state_dict 必须由所有 8 张卡共同执行! # 绝对不能把它放在 is_main_process 判断的里面,否则会触发跨卡死锁! 
full_state_dict = self.accelerator.get_state_dict(self.model) # 只有主进程(0号卡)负责把聚合好的完整参数写进硬盘 if not self.accelerator.is_main_process: return folder = f"checkpoint-{suffix}" if suffix else f"checkpoint-{step}" ckpt_dir = os.path.join( self.config.experiment.output_dir, "checkpoints", folder ) os.makedirs(ckpt_dir, exist_ok=True) # 从 TwinModel 的完整字典中,根据前缀拆分出 student 和 aux 的独立权重 student_state = {k.replace("student.", ""): v for k, v in full_state_dict.items() if k.startswith("student.")} aux_state = {k.replace("aux.", ""): v for k, v in full_state_dict.items() if k.startswith("aux.")} student_path = os.path.join(ckpt_dir, "student.pt") aux_path = os.path.join(ckpt_dir, "aux.pt") torch.save(student_state, student_path) torch.save(aux_state, aux_path) # 保存辅助状态 state = { "global_step": step, "baseline_ema": self.baseline_ema, } torch.save(state, os.path.join(ckpt_dir, "train_state.pt")) self.logger.info(f"[save] step={step} → {ckpt_dir} (ZeRO-3 Gathered)") # ----------------------------------------------------------------------- # Logging helpers (same API as original Trainer) # ----------------------------------------------------------------------- def _add_metrics(self, stats: dict) -> None: for k, v in stats["metrics"].items(): if k not in self.metrics: self.metrics[k] = profiler.SmoothedValue() self.metrics[k].update(v) def _log_metrics(self, stats: dict) -> None: iter_template = "Iteration %d, lr_s=%.2e lr_a=%.2e, time=%.2fs" self.logger.info( iter_template % ( stats["step"], stats.get("lr_student", 0.0), stats.get("lr_aux", 0.0), stats.get("time", 0.0), ) ) metric_template = " Train %s: %s" for k, v in self.metrics.items(): self.logger.info(metric_template % (k, v)) tracker_logs = {k: v.median for k, v in self.metrics.items()} tracker_logs.update( { "lr_student": stats.get("lr_student", 0.0), "time": stats.get("time", 0.0), } ) self.accelerator.log(tracker_logs, step=stats["step"]) self.metrics.clear() # 
----------------------------------------------------------------------- # Sanity checks (step 1 only) # ----------------------------------------------------------------------- def _step1_assertions( self, x_init, ids_init, rpos_init, z_s, p_s, logp, z_T_cond, z_S_cond, x_t, B, T, H, W, ) -> None: """Shape / value-domain assertions (mirrors single-card script).""" N, K = self.N, self.K lm_vocab = self.pipe.teacher.config.lm_vocab_size L_plus_N1 = ids_init.size(1) txt_len = L_plus_N1 - (N + 1) assert x_init.dtype == torch.long assert x_init.min() >= 0 and x_init.max() < K assert ids_init.shape == (B, L_plus_N1), ids_init.shape txt_part = ids_init[:, :txt_len] vis_part = ids_init[:, -N:] assert (txt_part < lm_vocab).all(), "text tokens in visual range" assert (vis_part >= lm_vocab).all(), "visual tokens not shifted" assert (vis_part < lm_vocab + K).all(), "visual tokens exceed lm_vocab+K" assert rpos_init.shape == (B, L_plus_N1, 3), rpos_init.shape assert z_s.shape == (B, N, K), z_s.shape p_err = float((p_s.sum(-1) - 1).abs().max().item()) assert p_err < 1e-3, f"p_s not normalised: max_dev={p_err:.2e}" assert not torch.isnan(logp).any(), "logp has NaN" assert not torch.isinf(logp).any(), "logp has Inf" assert x_t.min() >= 0 and x_t.max() < K assert z_T_cond.shape == z_S_cond.shape == (B, N, K), ( f"z_T_cond={z_T_cond.shape} z_S_cond={z_S_cond.shape}" ) # Teacher has no grad. teacher_grads = [ p for p in self.pipe.teacher.parameters() if p.grad is not None ] assert len(teacher_grads) == 0, "teacher has grads — not frozen" # Student has grad (from PG backward). 
        raw_s = self.accelerator.unwrap_model(self.model).student
        student_grad_norms = [
            float(p.grad.norm().item())
            for p in raw_s.parameters()
            if p.grad is not None
        ]
        assert len(student_grad_norms) > 0, "student has NO grads — grad flow broken"
        # ##########################
        # raw_t = self.pipe.teacher
        # raw_s = self.accelerator.unwrap_model(self.model).student
        # # (a) Shared-storage check: zero overhead.
        # pt0 = next(raw_t.parameters())
        # ps0 = next(raw_s.parameters())
        # self.logger.info(f"[assert] shared_storage={pt0.data_ptr() == ps0.data_ptr()}")
        # # (b) Parameter delta: sample only the first 4096 elements to avoid a huge temp tensor.
        # with torch.no_grad():
        #     a = pt0.view(-1)[:4096].float()
        #     b = ps0.view(-1)[:4096].float()
        #     self.logger.info(f"[assert] param_delta_sample_max={float((a-b).abs().max().item()):.3e}")
        # # (c) Logits delta: sample a small sub-block (64 token × 256 vocab).
        # with torch.no_grad():
        #     idx_n = torch.randint(0, self.N, (64,), device=z_T_cond.device)
        #     idx_k = torch.randint(0, self.K, (256,), device=z_T_cond.device)
        #     subT = z_T_cond[0, idx_n][:, idx_k].float()
        #     subS = z_S_cond[0, idx_n][:, idx_k].float()
        #     self.logger.info(f"[assert] logits_delta_sub_max={float((subT-subS).abs().max().item()):.3e}")
        # ###########################
        self.logger.info("[assert] Step-1 shape/grad assertions PASSED ✓")
        self.logger.info(
            f"[assert] z_T_cond shape={z_T_cond.shape} "
            f"min={float(z_T_cond.min().item()):.3f} "
            f"max={float(z_T_cond.max().item()):.3f}"
        )
        self.logger.info(
            f"[assert] z_S_cond shape={z_S_cond.shape} "
            f"min={float(z_S_cond.min().item()):.3f} "
            f"max={float(z_S_cond.max().item()):.3f}"
        )

    @staticmethod
    def _token_entropy(x_hat: torch.Tensor) -> float:
        """Histogram entropy of sampled token indices (collapse detection)."""
        counts = x_hat.flatten().bincount(minlength=1).float()
        # NOTE(review): an empty x_hat would give counts.sum() == 0 → NaN;
        # assumed non-empty by all callers.
        p = counts / counts.sum()
        p = p[p > 0]
        return float(-(p * p.log()).sum().item())


def main():
    """Entry point — identical pattern to scripts/train.py."""
    config = omegaconf_utils.get_config()
    os.makedirs(config.experiment.output_dir, exist_ok=True)
    accelerator = accelerate_utils.build_accelerator(config)
    accelerate_utils.build_wandb(config, accelerator=accelerator)
    logger = accelerate_utils.set_logger(
        config.experiment.output_dir, accelerator=accelerator
    )
    # Per-rank seed offset so each GPU draws different noise / batches.
    device_seed = int(config.training.seed) + accelerator.process_index
    engine_utils.manual_seed(device_seed, (accelerator.device.index, device_seed))
    # Rank 0 snapshots the resolved config next to the outputs.
    if accelerator.is_main_process:
        config_path = os.path.join(config.experiment.output_dir, "config.yaml")
        omegaconf_utils.save_config(config, config_path)
        logger.info(f"Config:\n{omegaconf_utils.config_to_yaml(config)}")
    trainer = DistillTrainer(config, accelerator, logger)
    trainer.train_loop()


if __name__ == "__main__":
    main()