# ============================================================================
# URSA one-step distillation — DiMO-style distributed training config
# ============================================================================
# Verified native inference regime (from A/B testing — ground truth):
#   height=320, width=512, num_frames=49, guidance_scale=7, teacher_steps=50.
#   no_cfg (guidance_scale=1) does NOT produce valid output.
# All defaults below align to this verified regime.
#
# Launch (8-GPU, single node):
#
#   accelerate launch --config_file accelerate_configs/deepspeed_zero2.yaml \
#       --machine_rank 0 --num_machines 1 --num_processes 8 \
#       scripts/train_distill_dimo.py \
#       config="./configs/distill_dimo.yaml" \
#       experiment.output_dir="./experiments/distill_dimo" \
#       distill.teacher_ckpt="/path/to/URSA-1.7B-IBQ1024" \
#       distill.prompt_source="/data/Koala_36M_*.csv"
#
# Smoke test (1 GPU, 50 steps — save student checkpoint):
#
#   accelerate launch --num_processes 1 \
#       scripts/train_distill_dimo.py \
#       config="./configs/distill_dimo.yaml" \
#       experiment.output_dir="./experiments/smoke" \
#       distill.teacher_ckpt="/path/to/URSA-1.7B-IBQ1024" \
#       distill.prompt_source="prompts.txt" \
#       training.max_train_steps=50 \
#       experiment.save_every=50
#
# Load student for 1-step inference (must use CFG=7, native geometry):
#
#   pipe = URSAPipeline.from_pretrained("/path/to/URSA-1.7B-IBQ1024")
#   state = torch.load("experiments/distill_dimo/checkpoints/final/student.pt")
#   pipe.transformer.load_state_dict(state, strict=True)
#   frames = pipe(prompt="...", num_inference_steps=1,
#                 height=320, width=512, num_frames=49,
#                 guidance_scale=7).frames
# ============================================================================

# ── Experiment bookkeeping ───────────────────────────────────────────────────
experiment:
  name: distill_dimo
  output_dir: ./experiments/distill_dimo
  log_every: 10
  save_every: 100
  resume_iter: 0                   # set to a step number to resume

# ── Training (framework-level) ────────────────────────────────────────────────
training:
  seed: 42
  mixed_precision: bf16            # bf16 | fp16 | fp32
  max_train_steps: 10000
  gradient_accumulation_steps: 1   # two backward passes per step; keep at 1 for distillation

# ── Distillation hyperparameters ──────────────────────────────────────────────
distill:
  # ---- Paths -----------------------------------------------------------------
  teacher_ckpt: /gfs/space/private/fengzl/World_Model/URSA-1.7B
  prompt_source: /gfs/space/private/fengzl/World_Model/Koala-36M-v1  # glob, dir, .txt, or comma-list

  # ---- Video geometry (verified native: 320×512×49) ---------------------------
  num_frames: 49
  height: 320
  width: 512
  max_prompt_length: 320

  # ---- Data --------------------------------------------------------------------
  batch_size_per_gpu: 1            # effective global batch = batch_size_per_gpu × 8 GPUs

  # ---- Loss weights --------------------------------------------------------------
  lambda_kd: 1.0                   # KL(z_T || z_S) weight (base KD weight, unchanged)
  lambda_pg: 1.0                   # [repurposed] formerly the REINFORCE policy-gradient
                                   # weight; now acts as lambda_bridge, scaling the MSE
                                   # pseudo-gradient injection
  lambda_ent: 0.0                  # [deprecated] RL entropy bonus (λ_ent_eff × H) has been
                                   # removed entirely; keep at 0.0
  tau: 1.0                         # student sampling temperature
  tau_kd: 1.0                      # KD / Jeffrey softmax temperature

  # ---- Teacher CFG (aligned to verified working regime: CFG=7) --------------------
  # A/B testing confirmed: guidance_scale=1 (no_cfg) does NOT produce valid
  # output for this URSA checkpoint. The teacher KD target must use CFG=7.
  enable_teacher_cfg: true
  teacher_cfg_scale: 7.0           # s in z_guided = z_uncond + s*(z_cond - z_uncond);
                                   # verified official working value
  teacher_cfg_prob: 1.0            # max fraction of samples using the guided target
  teacher_cfg_warmup_steps: 1000   # linear warmup 0 → teacher_cfg_prob
  teacher_cfg_trunc: 0.9           # when t ≥ trunc, scale falls back to 1 (no guidance)
  lambda_kd_uncond: 0.3            # weight for the uncond-branch KD loss
  # reward_use_guided was removed together with the reward computation.
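  # How the teacher CFG settings above interact, shown as a minimal sketch.
  # Illustrative only: the function names below are hypothetical, not the
  # training script's actual API.
  #
  #   import torch
  #
  #   def guided_teacher_logits(z_cond: torch.Tensor, z_uncond: torch.Tensor,
  #                             s: float = 7.0) -> torch.Tensor:
  #       # teacher_cfg_scale: z_guided = z_uncond + s * (z_cond - z_uncond)
  #       return z_uncond + s * (z_cond - z_uncond)
  #
  #   def guided_sample_prob(step: int, prob: float = 1.0,
  #                          warmup: int = 1000) -> float:
  #       # teacher_cfg_prob / teacher_cfg_warmup_steps: linear warmup 0 -> prob
  #       return prob * min(step / max(warmup, 1), 1.0)
  #
  #   def effective_cfg_scale(t: float, s: float = 7.0,
  #                           trunc: float = 0.9) -> float:
  #       # teacher_cfg_trunc: guidance is disabled once t >= trunc
  #       return 1.0 if t >= trunc else s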
  # ---- DiMO extensions -------------------------------------------------------------
  fake_rounds: 2                   # aux updates per student update, fitting the fake
                                   # tokens (DiMO=2; previously 1: use 2 if the aux
                                   # bridge_loss will not decrease)
  use_surrogate_grad: false
  lambda_surr: 1.0

  # ---- Stability ---------------------------------------------------------------------
  t_curriculum_steps: 10000        # curriculum steps before uniform-t sampling
  p_init_mix_ratio: 0.2            # fraction of batch drawn from corrupted x_hat_prev
  p_mix_corrupt_frac: 0.2          # token corruption rate in p_init mixing
  collapse_warn_frac: 0.2          # warn if tok_entropy < frac × initial entropy

  # ---- Aux initialisation --------------------------------------------------------------
  aux_noise_std: 1.0e-5            # tiny noise added to aux weights at init to break
                                   # symmetry; set 0.0 to keep aux == student exactly

  # ---- Gradient clipping ----------------------------------------------------------------
  grad_clip: 1.0

# ── Student optimizer ──────────────────────────────────────────────────────────
optimizer_student:
  target: torch.optim.AdamW
  params:
    lr: 1.0e-5
    betas: [0.9, 0.95]
    weight_decay: 0.01

# ── Aux optimizer ────────────────────────────────────────────────────────────────
optimizer_aux:
  target: torch.optim.AdamW
  params:
    lr: 1.0e-5
    betas: [0.9, 0.95]
    weight_decay: 0.01

# ── LR scheduler (cosine, shared warmup/decay params for both optimizers) ────────
lr_scheduler:
  target: diffnext.engine.lr_scheduler.CosineLR
  params:
    lr_max: ${optimizer_student.params.lr}   # resolves to the student LR via interpolation
    lr_min: 1.0e-6
    max_steps: ${training.max_train_steps}
    warmup_steps: 500

# ── Prompt DataLoader ─────────────────────────────────────────────────────────────
prompt_dataloader:
  shuffle_files: true
  shuffle_buffer: 50000            # in-memory shuffle buffer per shard; reduce if OOM
  num_workers: 4                   # CPU workers (no CUDA in workers)
  caption_field: caption           # CSV column name (Koala default)
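# ── Illustrative: prompt-source resolution ─────────────────────────────────────────
# distill.prompt_source accepts a glob, a directory, a plain .txt file, or a
# comma-separated list, and prompt_dataloader.caption_field names the CSV column
# to read. A minimal sketch of that dispatch, assuming standard-library tooling
# only; resolve_prompt_files and iter_prompts are hypothetical helpers, not the
# repo's actual loader:
#
#   import csv, glob, os
#
#   def resolve_prompt_files(src: str) -> list[str]:
#       if "," in src:                                   # comma-list of paths
#           return [p.strip() for p in src.split(",")]
#       if os.path.isdir(src):                           # directory of shards
#           return sorted(glob.glob(os.path.join(src, "*")))
#       return sorted(glob.glob(src))                    # glob or single file
#
#   def iter_prompts(path: str, caption_field: str = "caption"):
#       if path.endswith(".csv"):                        # Koala-style CSV shard
#           with open(path, newline="") as f:
#               for row in csv.DictReader(f):
#                   yield row[caption_field]
#       else:                                            # .txt: one prompt per line
#           with open(path) as f:
#               yield from (line.strip() for line in f if line.strip())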