# Source: World_Model / URSA / scripts / train_distill_dimo.py
# Provenance note: uploaded by BryanW via upload-large-folder tool, revision 2ee4cd6 (verified).
#!/usr/bin/env python3
# ------------------------------------------------------------------------
# Copyright (c) 2024-present, BAAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------
"""URSA one-step distillation trainer (DiMO-style), 8-GPU distributed.
Verified native inference regime (from A/B testing — ground truth):
height=320, width=512, num_frames=49, guidance_scale=7, teacher_steps=50.
no_cfg (guidance_scale=1) is NOT a valid baseline for this URSA checkpoint.
Defaults in configs/distill_dimo.yaml are aligned to this regime.
Launch command:
accelerate launch --config_file accelerate_configs/deepspeed_zero2.yaml \\
--machine_rank 0 --num_machines 1 --num_processes 8 \\
scripts/train_distill_dimo.py \\
config="./configs/distill_dimo.yaml" \\
experiment.output_dir="./experiments/distill_dimo" \\
distill.teacher_ckpt="/path/to/URSA-1.7B-IBQ1024" \\
distill.prompt_source="/data/Koala_36M_*.csv" \\
distill.batch_size_per_gpu=1
Smoke test (single-GPU, 50 steps):
accelerate launch --num_processes 1 \\
scripts/train_distill_dimo.py \\
config="./configs/distill_dimo.yaml" \\
experiment.output_dir="./experiments/smoke" \\
distill.teacher_ckpt="/path/to/URSA-1.7B-IBQ1024" \\
distill.prompt_source="prompts.txt" \\
training.max_train_steps=50
Algorithm summary (9 stages per iteration)
------------------------------------------
Stage 1 Tokenize → txt_ids [B, L] (CPU in worker, moved to GPU in run_step)
Stage 2 x_init ~ Uniform(K) (+ p_init mixing from x_hat_prev)
Stage 3 no_grad student(x_init) → x_hat [B, N], logp for PG
Stage 4 x_t = scheduler.add_noise(x_hat_4d, t) [B,T,H,W], long
Stage 5 no_grad teacher(x_t) → z_T_cond [B,N,K] (+ uncond if CFG)
Stage 6 aux update × fake_rounds: Jeffrey(z_T_target, z_A_cond).backward()
Stage 7 student KD forward on x_t → z_S_cond [B,N,K]
Stage 8 reward = -KL(z_T_cond, z_S_cond) [detached]; adv = reward - baseline_ema
Stage 9 Two-backward:
9a _no_sync_backward(lambda_kd * loss_kd) [frees KD graph]
9b accelerator.backward(lambda_pg * loss_pg - lambda_ent * H_mean)
opt_student.step()
"""
import collections
import copy
import os
import sys
from typing import Optional
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if _ROOT not in sys.path:
sys.path.insert(0, _ROOT)
from diffnext.engine import engine_utils
from diffnext.engine.lr_scheduler import CosineLR
from diffnext.pipelines.ursa.pipeline_ursa_distill_dimo import (
URSADistillDiMOPipeline,
_get_logits,
_stable_kl,
_stable_jeffrey,
_build_guided_logits,
_cfg_warmup_prob,
_no_sync_backward,
_reset_flex_attn,
VERIFIED_NATIVE_DEFAULTS,
check_verified_regime,
)
from diffnext.utils import accelerate_utils
from diffnext.utils import omegaconf_utils
from diffnext.utils import profiler
from src.distill.prompt_dataset import (
CSVSpec,
InfiniteDataLoader,
PromptDataset,
make_collate_fn,
)
# ---------------------------------------------------------------------------
# DistillTwinModel — single nn.Module wrapping student + aux for DeepSpeed
# ---------------------------------------------------------------------------
class DistillTwinModel(torch.nn.Module):
    """Container module holding the student and aux networks side by side.

    DeepSpeed (through Accelerate) accepts only one model per
    ``accelerator.prepare()`` call.  Packing both trainable networks into
    this single wrapper satisfies that constraint while each sub-module
    remains individually addressable (``.student`` / ``.aux``) and keeps
    its own optimizer param group.
    """

    def __init__(self, student: torch.nn.Module, aux: torch.nn.Module):
        super().__init__()
        self.student = student
        self.aux = aux

    def forward(self, which: str, input_ids, rope_pos=None, **kwargs):
        # Guard clause first: reject anything that is not a known sub-model.
        if which not in ("student", "aux"):
            raise ValueError(f"DistillTwinModel: unknown sub-model '{which}'")
        submodel = self.student if which == "student" else self.aux
        return submodel(input_ids, rope_pos=rope_pos, **kwargs)
# ---------------------------------------------------------------------------
# DistillTrainer
# ---------------------------------------------------------------------------
class DistillTrainer:
"""Training orchestrator for on-policy one-step distillation.
Reuses the same accelerate / logger / checkpoint API as
``diffnext.engine.train_engine.Trainer`` so the distributed setup is
identical to the original training framework.
Key differences from standard Trainer:
- Three models (teacher frozen, student + aux trainable)
- Student and aux are wrapped in a single ``DistillTwinModel`` so that
only one ``accelerator.prepare()`` call is needed (DeepSpeed requirement)
- One optimizer with two param_groups: [0]=student, [1]=aux
- LR schedulers for both param groups
- Two-backward strategy within each step
- PromptDataset (no video latents; prompt-only)
- Stage 6 freezes student / unfreezes aux; Stages 7-9 do the reverse
"""
def __init__(self, config, accelerator, logger):
    """Build pipeline, optimizer, LR schedulers and prompt data loader.

    Args:
        config: OmegaConf-style config exposing ``distill``, ``training``,
            ``experiment``, ``optimizer_student``, ``optimizer_aux``,
            ``lr_scheduler`` and (optionally) ``prompt_dataloader`` sections.
        accelerator: ``accelerate.Accelerator`` instance (DeepSpeed-backed).
        logger: process-aware logger for info/warning output.
    """
    self.config = config
    self.accelerator = accelerator
    self.logger = logger
    cfg = config.distill
    dtype = accelerate_utils.precision_to_dtype(config.training.mixed_precision)
    self.device = accelerator.device
    # -------- Pipeline (teacher + student + aux) ----------------------
    logger.info(f"[init] Loading teacher from {cfg.teacher_ckpt} ...")
    self.pipe = URSADistillDiMOPipeline(
        teacher_ckpt=cfg.teacher_ckpt,
        compute_dtype=dtype,
        aux_noise_std=float(cfg.get("aux_noise_std", 0.0)),
    )
    # Move teacher to GPU (not prepared by accelerate — frozen).
    self.pipe.teacher = self.pipe.teacher.to(self.device)
    self.pipe.scheduler.to(device=self.device)
    # Compute latents shape from video geometry.
    from src.distill.utils_ursa_inputs import compute_latents_shape
    # Read VAE strides from pipeline (falls back to URSA defaults 4/8).
    vae_t = int(getattr(self.pipe, "vae_temporal_stride", 4))
    vae_s = int(getattr(self.pipe, "vae_spatial_stride", 8))
    self.latents_shape = compute_latents_shape(
        cfg.num_frames, cfg.height, cfg.width, vae_t, vae_s
    )
    T, H, W = self.latents_shape
    self.N = T * H * W  # total number of video tokens per sample
    self.K = self.pipe.codebook_size  # discrete codebook size
    logger.info(
        f"[init] latents_shape=({T},{H},{W}) N={self.N} K={self.K} "
        f"CFG={'ON' if cfg.enable_teacher_cfg else 'OFF'}"
    )
    # Pre-compute uncond token IDs (empty string, [1, L]) on CPU.
    self.txt_uncond_base_cpu = self.pipe.tokenizer(
        [""],
        max_length=int(cfg.max_prompt_length),
        padding="max_length",
        padding_side="left",
        truncation=True,
        return_tensors="pt",
    ).input_ids  # [1, L] CPU
    # -------- Optimizers (before accelerate.prepare) ------------------
    # Single optimizer with two param groups:
    #   group[0] = student params, group[1] = aux params
    opt_cls = torch.optim.AdamW
    opt_s_params = dict(
        lr=float(config.optimizer_student.params.lr),
        betas=tuple(config.optimizer_student.params.get("betas", [0.9, 0.95])),
        weight_decay=float(config.optimizer_student.params.get("weight_decay", 0.01)),
    )
    opt_a_params = dict(
        lr=float(config.optimizer_aux.params.lr),
        betas=tuple(config.optimizer_aux.params.get("betas", [0.9, 0.95])),
        weight_decay=float(config.optimizer_aux.params.get("weight_decay", 0.01)),
    )

    def _enable_gcpt(m):
        # m.model.layers is the layer list of the underlying Qwen3Model.
        # NOTE(review): assumes these attributes take effect in the model's
        # forward — confirm against the model implementation.
        for layer in m.model.layers:
            layer.gradient_checkpointing = True
            layer.self_attn.gradient_checkpointing = True
            layer.mlp.gradient_checkpointing = True

    _enable_gcpt(self.pipe.student)
    _enable_gcpt(self.pipe.aux)
    # -------- Resume: load weights BEFORE ZeRO-3 shards parameters ----
    self.global_step = int(config.experiment.get("resume_iter", 0))
    if self.global_step > 0:
        ckpt_dir = os.path.join(
            config.experiment.output_dir, "checkpoints", f"checkpoint-{self.global_step}"
        )
        if os.path.exists(ckpt_dir):
            logger.info(f"[Resume] 正在从 {ckpt_dir} 恢复 Student 和 Aux 的权重...")
            # Load on CPU (map_location="cpu") to avoid GPU OOM; the
            # subsequent accelerate.prepare() places/shards the tensors.
            self.pipe.student.load_state_dict(torch.load(os.path.join(ckpt_dir, "student.pt"), map_location="cpu"))
            self.pipe.aux.load_state_dict(torch.load(os.path.join(ckpt_dir, "aux.pt"), map_location="cpu"))
        else:
            logger.warning(f"[Resume] 找不到检查点 {ckpt_dir},将从随机初始状态起步!")
    # -------- Wrap student + aux into a single DistillTwinModel --------
    twin_model = DistillTwinModel(self.pipe.student, self.pipe.aux)
    opt_raw = opt_cls([
        {"params": list(self.pipe.student.parameters()), **opt_s_params},
        {"params": list(self.pipe.aux.parameters()), **opt_a_params},
    ])
    # -------- accelerate.prepare: single model + single optimizer ------
    # Teacher is NOT prepared (frozen; no grad sync needed).
    self.model, self.optimizer = accelerator.prepare(twin_model, opt_raw)
    # LR schedulers (step() called manually at end of each step).
    self.scheduler_s = CosineLR(
        lr_max=float(config.optimizer_student.params.lr),
        lr_min=float(config.lr_scheduler.params.get("lr_min", 1e-6)),
        max_steps=int(config.training.max_train_steps),
        warmup_steps=int(config.lr_scheduler.params.get("warmup_steps", 500)),
    )
    self.scheduler_a = CosineLR(
        lr_max=float(config.optimizer_aux.params.lr),
        lr_min=float(config.lr_scheduler.params.get("lr_min", 1e-6)),
        max_steps=int(config.training.max_train_steps),
        warmup_steps=int(config.lr_scheduler.params.get("warmup_steps", 500)),
    )
    # -------- Dataset / DataLoader ------------------------------------
    dataloader_cfg = config.get("prompt_dataloader", {})
    dataset = PromptDataset(
        prompt_source=str(cfg.prompt_source),
        shuffle_files=bool(dataloader_cfg.get("shuffle_files", True)),
        shuffle_buffer=int(dataloader_cfg.get("shuffle_buffer", 0)),
        seed=int(config.training.seed),
        infinite=True,
        csv=CSVSpec(caption_field=str(dataloader_cfg.get("caption_field", "caption"))),
    )
    # collate_fn: tokenize on CPU (no CUDA in workers).
    collate_fn = make_collate_fn(
        self.pipe.tokenizer,
        max_prompt_length=int(cfg.max_prompt_length),
        device=torch.device("cpu"),  # CPU output — moved to GPU in run_step
    )
    loader = DataLoader(
        dataset,
        batch_size=int(cfg.batch_size_per_gpu),
        shuffle=False,  # IterableDataset: no shuffle flag
        drop_last=True,
        num_workers=int(dataloader_cfg.get("num_workers", 2)),
        collate_fn=collate_fn,
        pin_memory=True,
    )
    # DataLoader is NOT prepared by accelerate because PromptDataset
    # handles per-rank file sharding internally via torch.distributed.
    self._inf_loader = InfiniteDataLoader(loader)
    # -------- Training state ------------------------------------------
    # baseline_ema: EMA baseline for the (currently disabled) PG reward.
    self.baseline_ema: float = 0.0
    # x_hat_prev: last step's sampled tokens, used for p_init mixing.
    self.x_hat_prev: Optional[torch.Tensor] = None
    self.metrics = collections.OrderedDict()
    # -------- Verified regime validation --------------------------------
    native = VERIFIED_NATIVE_DEFAULTS
    is_native = check_verified_regime(
        height=int(cfg.height),
        width=int(cfg.width),
        num_frames=int(cfg.num_frames),
        guidance_scale=float(cfg.teacher_cfg_scale) if cfg.enable_teacher_cfg else None,
        label="train",
    )
    logger.info(
        f"[init] verified_native_regime={is_native} "
        f"geometry=({cfg.num_frames}×{cfg.height}×{cfg.width}) "
        f"teacher_cfg_scale={cfg.teacher_cfg_scale if cfg.enable_teacher_cfg else 'OFF'}"
    )
    if not cfg.enable_teacher_cfg:
        logger.warning(
            "[WARN] Teacher CFG is DISABLED. no_cfg is known to produce "
            "blank/blurry output for this URSA checkpoint. "
            "Distillation without CFG is unlikely to produce useful results."
        )
    elif float(cfg.teacher_cfg_scale) != native["guidance_scale"]:
        logger.warning(
            f"[WARN] teacher_cfg_scale={cfg.teacher_cfg_scale} differs from "
            f"the verified working value ({native['guidance_scale']}). "
            "Outputs may deviate from the official inference working point."
        )
    logger.info(
        f"[init] student params: {engine_utils.count_params(self.pipe.student):.2f}M"
    )
    logger.info(
        f"[init] max_train_steps={config.training.max_train_steps} "
        f"batch_size_per_gpu={cfg.batch_size_per_gpu} "
        f"num_processes={accelerator.num_processes}"
    )
# -----------------------------------------------------------------------
# run_step: Stages 1-9
# -----------------------------------------------------------------------
def run_step(self, step: int) -> dict:
    """Execute one distillation step.

    Active algorithm (the original DiMO two-backward PG variant is
    disabled in this revision; placeholders keep the log schema stable):
      1. Fetch tokenized prompts (CPU → GPU).
      2. Sample x_init ~ Uniform(K), occasionally mixing in a corrupted
         copy of the previous step's x_hat.
      3. Student one-step forward (no grad) → sampled tokens x_hat.
      4. Pseudo-intermediate x_t = scheduler.add_noise(x_hat, t).
      5. Teacher forward (dual-path when CFG enabled) → detached z_T_target.
      6. Aux update × fake_rounds: cross-entropy fit of x_hat.
      7. Student update: KD on x_t plus an (p_T − p_A) gradient bridge
         injected through an MSE trick on the x_init forward.

    Args:
        step: 1-based global step index.

    Returns:
        dict with "step", "lr_student", "lr_aux", "time" and a sorted
        "metrics" OrderedDict.
    """
    cfg = self.config.distill
    T, H, W = self.latents_shape
    N, K = self.N, self.K
    device = self.device
    stats = {"step": step}
    timer = profiler.Timer().tic()
    # Update LR from cosine schedulers.
    # param_groups[0] = student, param_groups[1] = aux
    lr_s = self.scheduler_s.get_lr()
    lr_a = self.scheduler_a.get_lr()
    stats["lr_student"] = lr_s
    stats["lr_aux"] = lr_a
    self.optimizer.param_groups[0]["lr"] = lr_s
    self.optimizer.param_groups[1]["lr"] = lr_a
    # ----------------------------------------------------------------
    # Stage 1: Get tokenised batch (CPU → GPU)
    # ----------------------------------------------------------------
    txt_ids = next(self._inf_loader)  # [B, L] CPU tensor
    txt_ids = txt_ids.to(device, non_blocking=True)
    B = txt_ids.size(0)
    txt_uncond = None
    if cfg.enable_teacher_cfg:
        txt_uncond = self.txt_uncond_base_cpu.expand(B, -1).to(device)
    # ----------------------------------------------------------------
    # Stage 2: Sample x_init ~ Uniform(K) with optional p_init mixing
    # ----------------------------------------------------------------
    x_init = torch.randint(0, K, (B, T, H, W), device=device, dtype=torch.long)
    # Probabilistic trigger: even with a small batch the model still sees
    # pure-noise inputs most of the time; on trigger, only the first
    # sample in the batch is replaced with a corrupted previous output.
    if self.x_hat_prev is not None and float(cfg.p_init_mix_ratio) > 0:
        if torch.rand(1).item() < float(cfg.p_init_mix_ratio):
            x_init[0] = self.pipe.corrupt_tokens(
                self.x_hat_prev[0:1], r=float(cfg.p_mix_corrupt_frac)
            ).squeeze(0)
    # ----------------------------------------------------------------
    # Stage 3: Student 1-step on x_init — no_grad (only sample x_hat)
    # ----------------------------------------------------------------
    with torch.no_grad():
        ids_init, rpos_init, _ = self.pipe.build_inputs(
            txt_ids, x_init, self.latents_shape
        )
        logits_s_init = _get_logits(
            self.model("student", ids_init, rope_pos=rpos_init)
        )
        z_s = self.pipe.extract_logits(logits_s_init, N)  # [B, N, K]
        p_s = F.softmax(z_s / float(cfg.tau), dim=-1)  # [B, N, K]
        x_hat = torch.multinomial(p_s.view(-1, K), 1).view(B, N)  # [B, N]
        del p_s, z_s, logits_s_init
    x_hat_4d = x_hat.view(B, T, H, W)
    # ----------------------------------------------------------------
    # Stage 4: Pseudo-intermediate x_t = add_noise(x_hat, t)
    # ----------------------------------------------------------------
    t = self.pipe.sample_t_curriculum(
        B, device, step, int(cfg.t_curriculum_steps)
    )  # [B] float ∈ (0.05, 0.995)
    with torch.no_grad():
        x_t = self.pipe.scheduler.add_noise(x_hat_4d, t)  # [B,T,H,W] long
    # ----------------------------------------------------------------
    # Stage 5: Teacher forward — clone targets out of the dual buffer
    # so the big [2B, ...] activations can be freed right away.
    # ----------------------------------------------------------------
    with torch.no_grad():
        if cfg.enable_teacher_cfg:
            txt_dual = torch.cat([txt_ids, txt_uncond], dim=0)  # [2B, L]
            x_t_dual = torch.cat([x_t, x_t], dim=0)  # [2B,T,H,W]
            ids_dual, rpos_dual, _ = self.pipe.build_inputs(
                txt_dual, x_t_dual, self.latents_shape
            )
            logits_T_dual = _get_logits(
                self.pipe.teacher(ids_dual, rope_pos=rpos_dual)
            )
            z_T_dual = self.pipe.extract_logits(logits_T_dual, N)  # [2B,N,K]
            # .clone() breaks the view dependency so the dual buffer can
            # actually be reclaimed by the allocator.
            # FIX: slice by B — the previous [0:1]/[1:2] slices silently
            # dropped every sample beyond the first when B > 1.
            z_T_cond = z_T_dual[:B].clone()  # [B,N,K]
            z_T_uncond = z_T_dual[B:].clone()  # [B,N,K]
            ids_t, rpos_t = ids_dual[:B], rpos_dual[:B]
            # Free the dual-path buffers immediately.
            del logits_T_dual, z_T_dual
            torch.cuda.empty_cache()
        else:
            ids_t, rpos_t, _ = self.pipe.build_inputs(txt_ids, x_t, self.latents_shape)
            logits_T = _get_logits(self.pipe.teacher(ids_t, rope_pos=rpos_t))
            z_T_cond = self.pipe.extract_logits(logits_T, N)
            z_T_uncond = None
    # CFG guided target with per-sample Bernoulli warmup.
    z_T_guided = None
    use_guided_ratio = 0.0
    if cfg.enable_teacher_cfg:
        p_guided = _cfg_warmup_prob(step, float(cfg.teacher_cfg_prob), int(cfg.teacher_cfg_warmup_steps))
        use_guided = torch.rand(B, device=device) < p_guided
        use_guided_ratio = float(use_guided.float().mean().item())
        z_T_guided = _build_guided_logits(
            z_T_cond, z_T_uncond,
            t, float(cfg.teacher_cfg_scale), float(cfg.teacher_cfg_trunc),
        )
        mask = use_guided.view(-1, 1, 1).expand_as(z_T_cond)
        # Keep the teacher dtype (bf16) to avoid a float32 memory blow-up.
        z_T_target = torch.where(mask, z_T_guided, z_T_cond).to(dtype=z_T_cond.dtype).detach()
        # Drop every intermediate tensor right away.
        del z_T_cond, z_T_uncond, z_T_guided
        torch.cuda.empty_cache()
    else:
        z_T_target = z_T_cond.detach()  # NO grad path to teacher
    # ----------------------------------------------------------------
    # Stage 6: Aux update — single-path forward only (no dual buffer),
    # which halves aux activation memory. Freeze student so only aux
    # receives gradients through the shared optimizer.
    # ----------------------------------------------------------------
    raw_twin = self.accelerator.unwrap_model(self.model)
    raw_twin.student.requires_grad_(False)
    raw_twin.aux.requires_grad_(True)
    target_tokens = x_hat.detach()  # student's Stage-3 sample as CE target
    # FIX: pre-initialize so the metrics dict below does not raise
    # NameError when cfg.fake_rounds == 0.
    loss_aux_cond_last = torch.tensor(0.0, device=device)
    for _fr in range(int(cfg.fake_rounds)):
        self.optimizer.zero_grad(set_to_none=True)
        logits_A = _get_logits(
            self.model("aux", ids_t.detach(), rope_pos=rpos_t.detach())
        )
        z_A_cond = self.pipe.extract_logits(logits_A, N)
        loss_aux_cond_v = F.cross_entropy(
            z_A_cond.reshape(B * N, K),
            target_tokens.reshape(B * N),
            reduction="mean",
        )
        self.accelerator.backward(loss_aux_cond_v)
        if float(cfg.grad_clip) > 0:
            torch.nn.utils.clip_grad_norm_(raw_twin.aux.parameters(), float(cfg.grad_clip))
        self.optimizer.step()
        # Release immediately to keep peak memory down.
        del logits_A, z_A_cond
        torch.cuda.empty_cache()
        loss_aux_cond_last = loss_aux_cond_v.detach()
    # ----------------------------------------------------------------
    # Stage 7 & 8: Student KD update & aux gradient bridge
    # ----------------------------------------------------------------
    raw_twin.student.requires_grad_(True)
    raw_twin.aux.requires_grad_(False)
    self.optimizer.zero_grad(set_to_none=True)
    # 7a. Student KD forward on x_t (single path).
    logits_S = _get_logits(
        self.model("student", ids_t.detach(), rope_pos=rpos_t.detach())
    )
    z_S_cond = self.pipe.extract_logits(logits_S, N)
    # Small chunk size (128) keeps the KL computation memory-safe.
    loss_kd = _stable_kl(
        z_T_target, z_S_cond, float(cfg.tau_kd), chunk_size=128
    ).mean()
    del logits_S, z_S_cond
    torch.cuda.empty_cache()
    # 7b. Aux prediction (no grad) used as the gradient bridge.
    with torch.no_grad():
        logits_A = _get_logits(
            self.model("aux", ids_t.detach(), rope_pos=rpos_t.detach())
        )
        z_A_cond = self.pipe.extract_logits(logits_A, N)
    # 8. Student forward on x_init WITH grad.
    ids_init, rpos_init, _ = self.pipe.build_inputs(txt_ids, x_init, self.latents_shape)
    logits_s_pol = _get_logits(
        self.model("student", ids_init, rope_pos=rpos_init)
    )
    z_s_pol = self.pipe.extract_logits(logits_s_pol, N)
    # Softmax in the student's (bf16) dtype to avoid a float32 blow-up.
    p_T = F.softmax(z_T_target / float(cfg.tau_kd), dim=-1).to(z_s_pol.dtype)
    p_A = F.softmax(z_A_cond / float(cfg.tau_kd), dim=-1).to(z_s_pol.dtype)
    bridge_target = (p_T - p_A).detach()
    # bridge_target captured — free the large upstream tensors now.
    del p_T, p_A, logits_A, z_A_cond, z_T_target
    torch.cuda.empty_cache()
    # MSE trick: d(loss_bridge)/d(z_s_pol) == -K * 0.5 * 2 * bridge_target
    # direction, injecting the (p_T − p_A) pseudo-gradient into the student.
    loss_bridge = 0.5 * K * F.mse_loss(
        z_s_pol.float(),
        (z_s_pol.float() + bridge_target.float()).detach()
    )
    # 9. Single combined backward (lambda_pg re-used as bridge weight).
    loss_student = float(cfg.lambda_kd) * loss_kd + float(cfg.lambda_pg) * loss_bridge
    self.accelerator.backward(loss_student)
    if float(cfg.grad_clip) > 0:
        torch.nn.utils.clip_grad_norm_(raw_twin.student.parameters(), float(cfg.grad_clip))
    self.optimizer.step()
    # Restore both sub-modules to trainable for the next step.
    raw_twin.student.requires_grad_(True)
    raw_twin.aux.requires_grad_(True)
    # Final cleanup.
    del logits_s_pol, z_s_pol, bridge_target
    torch.cuda.empty_cache()
    # Placeholders so the log schema matches the original PG variant.
    H_mean = torch.tensor(0.0, device=device)
    loss_pg = loss_bridge.detach()  # report the bridge loss in the pg slot
    logp = torch.tensor(0.0, device=device)
    self.baseline_ema = 0.0
    # Advance LR schedulers.
    self.scheduler_s.step()
    self.scheduler_a.step()
    # ----------------------------------------------------------------
    # Token-level collapse detection
    # ----------------------------------------------------------------
    tok_entropy = self._token_entropy(x_hat)
    if not hasattr(self, "_init_tok_entropy"):
        self._init_tok_entropy = tok_entropy
    collapse_frac = float(cfg.get("collapse_warn_frac", 0.2))
    if tok_entropy < collapse_frac * self._init_tok_entropy:
        self.logger.warning(
            f"[COLLAPSE] step={step} tok_H={tok_entropy:.3f} "
            f"init={self._init_tok_entropy:.3f} "
            f"ratio={tok_entropy / max(self._init_tok_entropy, 1e-8):.2f} "
            f"< {collapse_frac}. Try increasing lambda_ent."
        )
    stats["time"] = timer.toc()
    stats["metrics"] = collections.OrderedDict(
        sorted(
            {
                "loss_aux_cond": float(loss_aux_cond_last.item()),
                "loss_kd_cond": float(loss_kd.item()),
                "loss_pg": float(loss_pg.item()),
                "H_mean": float(H_mean.item()),
                "tok_entropy": float(tok_entropy),
                "mean_logp_tok": float(logp.mean().item()),
                "baseline_ema": float(self.baseline_ema),
                "use_guided_ratio": float(use_guided_ratio),
            }.items()
        )
    )
    return stats
# -----------------------------------------------------------------------
# Train loop
# -----------------------------------------------------------------------
def train_loop(self):
    """Main training loop (mirrors diffnext.engine.train_engine.Trainer).

    Runs ``run_step`` until ``training.max_train_steps`` is reached,
    logging every ``log_every`` steps, printing ETA every
    ``10 * log_every`` steps, and checkpointing every ``save_every``
    steps. Ends with a final log + save and accelerator shutdown.
    """
    cfg_exp = self.config.experiment
    max_steps = int(self.config.training.max_train_steps)
    log_every = int(cfg_exp.log_every)
    save_every = int(cfg_exp.save_every)
    # resume_iter defaults to 0 (fresh run); non-zero means resume.
    self.global_step = int(self.config.experiment.get("resume_iter", 0))
    # Sync LR schedulers to resume step (set _step_count directly;
    # CosineLR uses _step_count internally in get_decay()).
    self.scheduler_s._step_count = self.global_step
    self.scheduler_a._step_count = self.global_step
    # On resume, let the accelerator restore the full (sharded) training
    # state — including the optimizer partitions — from the checkpoint
    # directory, if it exists.
    if self.global_step > 0:
        ckpt_dir = os.path.join(self.config.experiment.output_dir, "checkpoints", f"checkpoint-{self.global_step}")
        if os.path.exists(ckpt_dir):
            self.accelerator.load_state(ckpt_dir)
            self.logger.info(f"✅ ZeRO-3 完整状态 (包含 Optimizer) 已从 {ckpt_dir} 恢复")
    timer = profiler.Timer()
    self.logger.info(
        f"[train] Starting from step {self.global_step} / {max_steps}"
    )
    while self.global_step < max_steps:
        self.global_step += 1
        # Time only the optimization step itself; logging/saving are
        # excluded from the ETA estimate.
        with timer.tic_and_toc():
            stats = self.run_step(self.global_step)
        self._add_metrics(stats)
        if self.global_step % log_every == 0:
            self._log_metrics(stats)
        if self.global_step % (10 * log_every) == 0:
            self.logger.info(
                profiler.get_progress(timer, self.global_step, max_steps)
            )
        if self.global_step % save_every == 0:
            self.save(self.global_step)
    # Final log + save (only when loop ran at least one step).
    if self.global_step > int(self.config.experiment.get("resume_iter", 0)):
        # `stats` is bound by the loop above whenever this guard holds.
        self._log_metrics({**stats, "step": self.global_step})  # noqa: F821
        self.accelerator.wait_for_everyone()
        self.save(self.global_step, suffix="final")
    self.accelerator.end_training()
# -----------------------------------------------------------------------
# Checkpoint helpers
# -----------------------------------------------------------------------
# def save(self, step: int, suffix: str = None) -> None:
# """Save student + aux state_dicts (rank0 only).
# Saved as:
# <output_dir>/checkpoints/checkpoint-<step>/student.pt
# <output_dir>/checkpoints/checkpoint-<step>/aux.pt
# The student.pt can be used for inference by replacing the
# transformer weights in a URSAPipeline (see README).
# """
# if not self.accelerator.is_main_process:
# return
# folder = f"checkpoint-{suffix}" if suffix else f"checkpoint-{step}"
# ckpt_dir = os.path.join(
# self.config.experiment.output_dir, "checkpoints", folder
# )
# os.makedirs(ckpt_dir, exist_ok=True)
# raw_student = self.accelerator.unwrap_model(self.model).student
# raw_aux = self.accelerator.unwrap_model(self.model).aux
# student_path = os.path.join(ckpt_dir, "student.pt")
# aux_path = os.path.join(ckpt_dir, "aux.pt")
# torch.save(raw_student.state_dict(), student_path)
# torch.save(raw_aux.state_dict(), aux_path)
# # Also save training state for resuming.
# state = {
# "global_step": step,
# "baseline_ema": self.baseline_ema,
# "optimizer": self.optimizer.state_dict(),
# }
# torch.save(state, os.path.join(ckpt_dir, "train_state.pt"))
# self.logger.info(f"[save] step={step} → {ckpt_dir}")
def save(self, step: int, suffix: str = None) -> None:
"""Save student + aux state_dicts (支持 DeepSpeed ZeRO-3 自动聚合)."""
# ⚠️ 【极其重要】:get_state_dict 必须由所有 8 张卡共同执行!
# 绝对不能把它放在 is_main_process 判断的里面,否则会触发跨卡死锁!
full_state_dict = self.accelerator.get_state_dict(self.model)
# 只有主进程(0号卡)负责把聚合好的完整参数写进硬盘
if not self.accelerator.is_main_process:
return
folder = f"checkpoint-{suffix}" if suffix else f"checkpoint-{step}"
ckpt_dir = os.path.join(
self.config.experiment.output_dir, "checkpoints", folder
)
os.makedirs(ckpt_dir, exist_ok=True)
# 从 TwinModel 的完整字典中,根据前缀拆分出 student 和 aux 的独立权重
student_state = {k.replace("student.", ""): v for k, v in full_state_dict.items() if k.startswith("student.")}
aux_state = {k.replace("aux.", ""): v for k, v in full_state_dict.items() if k.startswith("aux.")}
student_path = os.path.join(ckpt_dir, "student.pt")
aux_path = os.path.join(ckpt_dir, "aux.pt")
torch.save(student_state, student_path)
torch.save(aux_state, aux_path)
# 保存辅助状态
state = {
"global_step": step,
"baseline_ema": self.baseline_ema,
}
torch.save(state, os.path.join(ckpt_dir, "train_state.pt"))
self.logger.info(f"[save] step={step}{ckpt_dir} (ZeRO-3 Gathered)")
# -----------------------------------------------------------------------
# Logging helpers (same API as original Trainer)
# -----------------------------------------------------------------------
def _add_metrics(self, stats: dict) -> None:
for k, v in stats["metrics"].items():
if k not in self.metrics:
self.metrics[k] = profiler.SmoothedValue()
self.metrics[k].update(v)
def _log_metrics(self, stats: dict) -> None:
iter_template = "Iteration %d, lr_s=%.2e lr_a=%.2e, time=%.2fs"
self.logger.info(
iter_template
% (
stats["step"],
stats.get("lr_student", 0.0),
stats.get("lr_aux", 0.0),
stats.get("time", 0.0),
)
)
metric_template = " Train %s: %s"
for k, v in self.metrics.items():
self.logger.info(metric_template % (k, v))
tracker_logs = {k: v.median for k, v in self.metrics.items()}
tracker_logs.update(
{
"lr_student": stats.get("lr_student", 0.0),
"time": stats.get("time", 0.0),
}
)
self.accelerator.log(tracker_logs, step=stats["step"])
self.metrics.clear()
# -----------------------------------------------------------------------
# Sanity checks (step 1 only)
# -----------------------------------------------------------------------
def _step1_assertions(
    self, x_init, ids_init, rpos_init, z_s, p_s, logp,
    z_T_cond, z_S_cond, x_t, B, T, H, W,
) -> None:
    """Shape / value-domain assertions (mirrors single-card script).

    Intended to run once at step 1 (the call site is currently
    disabled) to validate token id ranges, tensor shapes, probability
    normalisation, and gradient flow (teacher frozen, student
    receiving grads). Raises AssertionError on the first violation.
    """
    # N = visual tokens per sample, K = codebook size — presumably set
    # in the trainer's __init__; TODO confirm.
    N, K = self.N, self.K
    lm_vocab = self.pipe.teacher.config.lm_vocab_size
    L_plus_N1 = ids_init.size(1)
    txt_len = L_plus_N1 - (N + 1)
    assert x_init.dtype == torch.long
    assert x_init.min() >= 0 and x_init.max() < K
    assert ids_init.shape == (B, L_plus_N1), ids_init.shape
    # Text ids live in [0, lm_vocab); visual ids are shifted into
    # [lm_vocab, lm_vocab + K).
    txt_part = ids_init[:, :txt_len]
    vis_part = ids_init[:, -N:]
    assert (txt_part < lm_vocab).all(), "text tokens in visual range"
    assert (vis_part >= lm_vocab).all(), "visual tokens not shifted"
    assert (vis_part < lm_vocab + K).all(), "visual tokens exceed lm_vocab+K"
    assert rpos_init.shape == (B, L_plus_N1, 3), rpos_init.shape
    assert z_s.shape == (B, N, K), z_s.shape
    # Student distribution must sum to 1 along the vocabulary axis.
    p_err = float((p_s.sum(-1) - 1).abs().max().item())
    assert p_err < 1e-3, f"p_s not normalised: max_dev={p_err:.2e}"
    assert not torch.isnan(logp).any(), "logp has NaN"
    assert not torch.isinf(logp).any(), "logp has Inf"
    assert x_t.min() >= 0 and x_t.max() < K
    assert z_T_cond.shape == z_S_cond.shape == (B, N, K), (
        f"z_T_cond={z_T_cond.shape} z_S_cond={z_S_cond.shape}"
    )
    # Teacher has no grad.
    teacher_grads = [
        p for p in self.pipe.teacher.parameters() if p.grad is not None
    ]
    assert len(teacher_grads) == 0, "teacher has grads — not frozen"
    # Student has grad (from PG backward).
    raw_s = self.accelerator.unwrap_model(self.model).student
    student_grad_norms = [
        float(p.grad.norm().item())
        for p in raw_s.parameters()
        if p.grad is not None
    ]
    assert len(student_grad_norms) > 0, "student has NO grads — grad flow broken"
    # ##########################
    # Disabled teacher/student divergence probes (kept for debugging):
    # raw_t = self.pipe.teacher
    # raw_s = self.accelerator.unwrap_model(self.model).student
    # # (a) Shared-storage check: zero cost.
    # pt0 = next(raw_t.parameters())
    # ps0 = next(raw_s.parameters())
    # self.logger.info(f"[assert] shared_storage={pt0.data_ptr() == ps0.data_ptr()}")
    # # (b) Parameter diff: sample only the first 4096 elements to
    # # avoid allocating a huge temporary tensor.
    # with torch.no_grad():
    #     a = pt0.view(-1)[:4096].float()
    #     b = ps0.view(-1)[:4096].float()
    #     self.logger.info(f"[assert] param_delta_sample_max={float((a-b).abs().max().item()):.3e}")
    # # (c) Logits diff: sample a small sub-block (64 tokens x 256 vocab).
    # with torch.no_grad():
    #     idx_n = torch.randint(0, self.N, (64,), device=z_T_cond.device)
    #     idx_k = torch.randint(0, self.K, (256,), device=z_T_cond.device)
    #     subT = z_T_cond[0, idx_n][:, idx_k].float()
    #     subS = z_S_cond[0, idx_n][:, idx_k].float()
    #     self.logger.info(f"[assert] logits_delta_sub_max={float((subT-subS).abs().max().item()):.3e}")
    # ###########################
    self.logger.info("[assert] Step-1 shape/grad assertions PASSED ✓")
    self.logger.info(
        f"[assert] z_T_cond shape={z_T_cond.shape} "
        f"min={float(z_T_cond.min().item()):.3f} "
        f"max={float(z_T_cond.max().item()):.3f}"
    )
    self.logger.info(
        f"[assert] z_S_cond shape={z_S_cond.shape} "
        f"min={float(z_S_cond.min().item()):.3f} "
        f"max={float(z_S_cond.max().item()):.3f}"
    )
@staticmethod
def _token_entropy(x_hat: torch.Tensor) -> float:
"""Histogram entropy of sampled token indices (collapse detection)."""
counts = x_hat.flatten().bincount(minlength=1).float()
p = counts / counts.sum()
p = p[p > 0]
return float(-(p * p.log()).sum().item())
def main():
    """Entry point — identical pattern to scripts/train.py.

    Builds config, accelerator, wandb tracker, and logger, seeds each
    rank, then runs the distillation training loop.
    """
    config = omegaconf_utils.get_config()
    os.makedirs(config.experiment.output_dir, exist_ok=True)
    accelerator = accelerate_utils.build_accelerator(config)
    accelerate_utils.build_wandb(config, accelerator=accelerator)
    logger = accelerate_utils.set_logger(
        config.experiment.output_dir, accelerator=accelerator
    )
    # Per-rank seed offset so each process draws distinct randomness.
    device_seed = int(config.training.seed) + accelerator.process_index
    engine_utils.manual_seed(device_seed, (accelerator.device.index, device_seed))
    if accelerator.is_main_process:
        # Persist the resolved config next to the run for reproducibility.
        config_path = os.path.join(config.experiment.output_dir, "config.yaml")
        omegaconf_utils.save_config(config, config_path)
        logger.info(f"Config:\n{omegaconf_utils.config_to_yaml(config)}")
    trainer = DistillTrainer(config, accelerator, logger)
    trainer.train_loop()


if __name__ == "__main__":
    main()