| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
"""Evaluation script for distill_dimo checkpoints.

Generates videos from both the student (1-step) and teacher (multi-step)
using checkpoints saved by train_distill_dimo.py.

Verified native inference regime (from A/B testing — ground truth):
height=320, width=512, num_frames=49, guidance_scale=7, teacher_steps=50.
no_cfg (guidance_scale=1) does NOT produce valid output for this URSA
checkpoint — outputs are blank or blurry.

Student generation modes
------------------------
cfg : 1-step, guidance_scale=7 (2× forward, inference-time CFG)

Teacher generation modes
------------------------
cfg : 50-step, guidance_scale=7 (official working regime)

Usage:
    python scripts/eval_distill_dimo.py \
        --teacher_ckpt /gfs/space/private/fengzl/World_Model/URSA-1.7B \
        --student_ckpt ./experiments/distill_dimo_v3/checkpoints/checkpoint-200/student.pt \
        --out_dir ./outputs/eval_distill_v3_200steps_49frames
"""
|
|
import argparse
import os
import sys

import numpy as np
import torch

# Make the repository root importable when the script is run directly
# (e.g. `python scripts/eval_distill_dimo.py`), before project imports below.
_REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if _REPO_ROOT not in sys.path:
    sys.path.insert(0, _REPO_ROOT)

from diffnext.pipelines import URSAPipeline
from diffnext.pipelines.ursa.pipeline_ursa_distill_dimo import (
    VERIFIED_NATIVE_DEFAULTS,
    check_verified_regime,
)
from diffnext.utils import export_to_video
|
|
|
| |
| |
| |
|
|
# Built-in evaluation prompts, used when --prompt_file is not supplied.
DEFAULT_PROMPTS = [
    "a lone grizzly bear walks through a misty forest at dawn, sunlight catching its fur.",
    "beautiful fireworks in the sky with red, white and blue.",
    "a wave crashes on a rocky shoreline at sunset, slow motion.",
    "a hummingbird hovers in front of a red flower, wings a blur.",
    "timelapse of clouds rolling over mountain peaks.",
    "a neon-lit city street at night with rain-soaked reflections.",
    "a kitten playing with a ball of yarn on a wooden floor.",
    "astronaut floating weightlessly inside a space station.",
]

# Negative prompt applied whenever CFG is active (official URSA uses one).
DEFAULT_NEGATIVE_PROMPT = (
    "worst quality, low quality, inconsistent motion, static, still, "
    "blurry, jittery, distorted, ugly"
)

# Seeds evaluated per prompt by default.
DEFAULT_SEEDS = [0, 1, 2, 3]
|
|
|
|
| |
| |
| |
|
|
def parse_args():
    """Build the CLI parser and return the parsed arguments namespace."""
    parser = argparse.ArgumentParser(
        description="Evaluate distill_dimo student (1-step) vs teacher (multi-step)"
    )

    # Checkpoint locations and output directory.
    parser.add_argument("--teacher_ckpt", required=True,
                        help="URSA diffusers pipeline directory (teacher weights)")
    parser.add_argument("--student_ckpt", required=True,
                        help="student.pt from train_distill_dimo.py checkpoint")
    parser.add_argument("--out_dir", default="./outputs/eval_distill")

    # Video geometry (defaults match the verified native regime).
    parser.add_argument("--num_frames", type=int, default=49)
    parser.add_argument("--height", type=int, default=320)
    parser.add_argument("--width", type=int, default=512)
    parser.add_argument("--fps", type=int, default=12)

    # Student evaluation modes.
    parser.add_argument("--student_modes", nargs="+", default=["cfg"],
                        choices=["no_cfg", "cfg", "baked"],
                        help="Student generation modes to evaluate. "
                             "Default: ['cfg']. no_cfg is known to produce blank/blurry "
                             "output for this checkpoint.")
    parser.add_argument("--eval_cfg_scale", type=float, default=7.0,
                        help="Guidance scale for 'cfg' mode (verified working value=7)")

    # Teacher evaluation modes.
    parser.add_argument("--teacher_modes", nargs="+", default=["cfg"],
                        choices=["no_cfg", "cfg"],
                        help="Teacher generation modes. Default: ['cfg']. "
                             "no_cfg is NOT a valid baseline for this URSA checkpoint.")
    parser.add_argument("--teacher_steps", type=int, default=50,
                        help="Number of inference steps for teacher (verified default=50)")

    # Shared generation knobs forwarded to the pipeline call.
    parser.add_argument("--guidance_trunc", type=float, default=0.9,
                        help="Truncation threshold for inference CFG")
    parser.add_argument("--negative_prompt", type=str, default=DEFAULT_NEGATIVE_PROMPT,
                        help="Negative prompt for CFG (official URSA uses one)")
    parser.add_argument("--max_prompt_length", type=int, default=320)
    parser.add_argument("--vae_batch_size", type=int, default=1)

    # Escape hatch for the known-bad no_cfg regime.
    parser.add_argument("--allow_bad_nocfg", action="store_true", default=False,
                        help="Suppress the no_cfg warning/block. Use at your own risk.")

    # Prompt and seed selection.
    parser.add_argument("--prompt_file", default=None,
                        help="Text file with one prompt per line (overrides defaults)")
    parser.add_argument("--seeds", nargs="*", type=int, default=DEFAULT_SEEDS)

    # Hardware / precision.
    parser.add_argument("--device", type=int, default=0)
    parser.add_argument("--mixed_precision", default="bf16",
                        choices=["fp16", "bf16", "fp32"])

    return parser.parse_args()
|
|
|
|
| |
| |
| |
|
|
def slug(text: str, max_len: int = 40) -> str:
    """Turn *text* into a short, filesystem-safe identifier.

    Lowercases the text, drops every character that is neither alphanumeric
    nor a space, collapses whitespace runs into single underscores, and
    truncates to *max_len*. Falls back to "prompt" for an empty result.
    """
    kept = "".join(ch for ch in text.lower() if ch.isalnum() or ch == " ")
    joined = "_".join(kept.split())
    return joined[:max_len] or "prompt"
|
|
|
|
def frames_to_mp4(frames, path: str, fps: int = 12):
    """Write *frames* to an mp4 at *path*, creating parent dirs as needed."""
    parent = os.path.dirname(os.path.abspath(path))
    os.makedirs(parent, exist_ok=True)
    # export_to_video expects a list of frames; unpack a stacked [T,H,W,C] array.
    if isinstance(frames, np.ndarray) and frames.ndim == 4:
        frames = [frame for frame in frames]
    export_to_video(frames, output_video_path=path, fps=fps)
|
|
|
|
| def _extract_frames(frames_output): |
| """Normalise pipeline output → list of uint8 numpy arrays [H, W, 3].""" |
| if isinstance(frames_output, np.ndarray): |
| frames_output = frames_output[0] if frames_output.ndim == 5 else frames_output |
| frames = list(frames_output) |
| elif isinstance(frames_output, list): |
| frames = [np.array(f) if not isinstance(f, np.ndarray) else f |
| for f in frames_output] |
| else: |
| raise TypeError(f"Unexpected frames type: {type(frames_output)}") |
| result = [] |
| for f in frames: |
| if f.dtype != np.uint8: |
| f = ((f * 255).clip(0, 255).astype(np.uint8) |
| if f.max() <= 1.0 else f.astype(np.uint8)) |
| result.append(f) |
| return result |
|
|
|
|
def _gen(pipe, prompt, negative_prompt, seed, num_frames, height, width,
         guidance_scale, num_inference_steps, guidance_trunc,
         max_prompt_length, vae_batch_size, device):
    """Single generation call, returns list of uint8 frames."""
    generator = torch.Generator(device=device)
    generator.manual_seed(seed)
    call_kwargs = dict(
        prompt=prompt,
        negative_prompt=negative_prompt,
        height=height,
        width=width,
        num_frames=num_frames,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        guidance_trunc=guidance_trunc,
        max_prompt_length=max_prompt_length,
        vae_batch_size=vae_batch_size,
        output_type="np",
        generator=generator,
    )
    with torch.inference_mode():
        result = pipe(**call_kwargs)
    return _extract_frames(result.frames)
|
|
|
|
| def _debug_pipeline(pipe, label=""): |
| """Print diagnostic info about the pipeline state.""" |
| print(f"\n{'='*60}") |
| print(f" Pipeline diagnostics {label}") |
| print(f"{'='*60}") |
| print(f" scheduler class : {type(pipe.scheduler).__name__}") |
| print(f" scheduler type : {type(pipe.scheduler)}") |
| if hasattr(pipe.scheduler, 'config'): |
| print(f" scheduler.config : {dict(pipe.scheduler.config)}") |
| if hasattr(pipe.scheduler, 'path'): |
| print(f" scheduler.path : {type(pipe.scheduler.path).__name__}") |
| if hasattr(pipe.scheduler.path, 'emb'): |
| emb = pipe.scheduler.path.emb |
| print(f" path.emb shape : {emb.shape}") |
| print(f" path.emb device : {emb.device}") |
| print(f" path.emb dtype : {emb.dtype}") |
| else: |
| print(f" scheduler.path : MISSING (scheduler not fully loaded!)") |
| print(f" codebook_size : {getattr(pipe.scheduler, 'codebook_size', 'N/A')}") |
| print(f" transformer class : {type(pipe.transformer).__name__}") |
| print(f" transformer device : {next(pipe.transformer.parameters()).device}") |
| print(f" vae class : {type(pipe.vae).__name__}") |
| if hasattr(pipe, 'image_processor'): |
| print(f" image_processor : {type(pipe.image_processor).__name__}") |
| print(f"{'='*60}\n") |
|
|
|
|
| def _debug_frames(frames, label=""): |
| """Print diagnostic info about generated frames.""" |
| if not frames: |
| print(f" [{label}] No frames generated!") |
| return |
| f0 = frames[0] |
| print(f" [{label}] n_frames={len(frames)} shape={f0.shape} " |
| f"dtype={f0.dtype} min={f0.min()} max={f0.max()}") |
|
|
|
|
| def _verify_state_dict_swap(pipe, state_dict, label=""): |
| """Verify transformer weights actually changed after load_state_dict.""" |
| sample_key = next(iter(state_dict.keys())) |
| loaded_val = state_dict[sample_key].flatten()[:8] |
| current_val = pipe.transformer.state_dict()[sample_key].flatten()[:8] |
| match = torch.allclose(loaded_val.cpu().float(), current_val.cpu().float(), atol=1e-6) |
| print(f" [{label}] state_dict match for '{sample_key}': {match}") |
| if not match: |
| print(f" loaded : {loaded_val[:4]}") |
| print(f" current : {current_val[:4]}") |
|
|
|
|
| |
| |
| |
|
|
def main():
    """Generate student-vs-teacher videos for every prompt/seed/mode combo.

    Loads the teacher pipeline once, snapshots its transformer weights, then
    for each clip swaps the transformer state dict between the distilled
    student (1 inference step) and the teacher (args.teacher_steps steps).
    Videos are written under args.out_dir with self-describing filenames.
    """
    args = parse_args()

    # Resolve compute dtype / device and prepare the output directory.
    dtype_map = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float32}
    dtype = dtype_map[args.mixed_precision]
    device = (torch.device("cuda", args.device)
              if torch.cuda.is_available() else torch.device("cpu"))
    os.makedirs(args.out_dir, exist_ok=True)

    # Report whether we are inside the verified native inference regime
    # (see module docstring); anything else is unverified territory.
    is_native = check_verified_regime(
        height=args.height,
        width=args.width,
        num_frames=args.num_frames,
        guidance_scale=args.eval_cfg_scale,
        teacher_steps=args.teacher_steps,
        label="eval",
    )
    print(f"[eval] verified_native_regime={is_native}")
    print(f"[eval] geometry=({args.num_frames},{args.height},{args.width}), "
          f"guidance_scale={args.eval_cfg_scale}, teacher_steps={args.teacher_steps}")

    # no_cfg is known to fail for this checkpoint. --allow_bad_nocfg is
    # documented as suppressing the "warning/block", so actually block here:
    # strip no_cfg modes unless the user explicitly opted in (previously the
    # code only warned and then generated known-bad output anyway).
    all_modes = list(args.student_modes) + list(args.teacher_modes)
    if "no_cfg" in all_modes:
        if args.allow_bad_nocfg:
            print("[WARN] no_cfg is known to fail for this URSA checkpoint. "
                  "Outputs may be blank or blurry. Proceeding because --allow_bad_nocfg is set.")
        else:
            print("[WARN] no_cfg is known to fail for this URSA checkpoint; "
                  "skipping no_cfg modes. Pass --allow_bad_nocfg to run them anyway.")
            args.student_modes = [m for m in args.student_modes if m != "no_cfg"]
            args.teacher_modes = [m for m in args.teacher_modes if m != "no_cfg"]
            if not args.student_modes and not args.teacher_modes:
                print("[eval] Nothing left to evaluate after removing no_cfg modes.")
                return

    # Prompt list: a file overrides the built-in defaults; '#' lines and
    # blank lines are skipped.
    if args.prompt_file:
        with open(args.prompt_file, encoding="utf-8") as f:
            prompts = [l.strip() for l in f if l.strip() and not l.startswith("#")]
    else:
        prompts = DEFAULT_PROMPTS

    print(f"[eval] {len(prompts)} prompts × {len(args.seeds)} seeds "
          f"| student modes={args.student_modes} "
          f"| teacher modes={args.teacher_modes}")
    print(f"[eval] guidance_scale={args.eval_cfg_scale} "
          f"guidance_trunc={args.guidance_trunc} "
          f"teacher_steps={args.teacher_steps}")
    print(f"[eval] negative_prompt='{args.negative_prompt[:60]}...'")

    # Load the pipeline once; prefer flash attention but fall back to the
    # default attention implementation if it is unavailable on this machine.
    print(f"[eval] Loading pipeline from {args.teacher_ckpt} …")
    try:
        pipe = URSAPipeline.from_pretrained(
            args.teacher_ckpt,
            torch_dtype=dtype,
            trust_remote_code=True,
            attn_implementation="flash_attention_2"
        ).to(device)
    except Exception:
        pipe = URSAPipeline.from_pretrained(
            args.teacher_ckpt, torch_dtype=dtype, trust_remote_code=True
        ).to(device)

    # Disable VAE slicing/tiling if present — full-tensor decode for eval.
    if hasattr(pipe.vae, "disable_slicing"):
        pipe.vae.disable_slicing()
    if hasattr(pipe.vae, "disable_tiling"):
        pipe.vae.disable_tiling()

    _debug_pipeline(pipe, label="after from_pretrained + .to(device)")

    # Sanity-check the scheduler: without its codebook embeddings the
    # pipeline produces blurry output instead of failing loudly.
    if not hasattr(pipe.scheduler, 'path') or pipe.scheduler.path is None:
        print("[ERROR] Scheduler path not loaded! This will cause blurry output.")
        print("[ERROR] The scheduler needs scheduler_model.pth with codebook embeddings.")
        return
    if pipe.scheduler.codebook_size == 0:
        print("[ERROR] codebook_size=0 — scheduler not properly initialized!")
        return

    # Snapshot the teacher weights so we can swap back after student runs.
    teacher_state = {k: v.clone() for k, v in pipe.transformer.state_dict().items()}

    print(f"[eval] Loading student weights from {args.student_ckpt} …")
    student_state = torch.load(
        args.student_ckpt, map_location=device, weights_only=True
    )
    print(f"[eval] student state_dict keys: {len(student_state)} "
          f"sample key: {next(iter(student_state.keys()))}")

    # Generation kwargs shared by every call to _gen.
    gen_kwargs = dict(
        num_frames=args.num_frames,
        height=args.height,
        width=args.width,
        guidance_trunc=args.guidance_trunc,
        max_prompt_length=args.max_prompt_length,
        vae_batch_size=args.vae_batch_size,
    )

    # Mode → guidance scale. 'baked' students carry CFG inside the weights,
    # so they run with scale 1 at inference time.
    student_guidance = {
        "no_cfg": 1.0,
        "cfg": args.eval_cfg_scale,
        "baked": 1.0,
    }
    teacher_guidance = {
        "no_cfg": 1.0,
        "cfg": args.eval_cfg_scale,
    }

    for idx, prompt in enumerate(prompts):
        p_slug = slug(prompt)
        print(f"\n[{idx+1}/{len(prompts)}] {prompt[:70]}")

        for seed in args.seeds:
            # ---- Student: 1-step generation with distilled weights. ----
            for mode in args.student_modes:
                g_scale = student_guidance[mode]
                # Only pass a negative prompt when CFG is actually active.
                neg = args.negative_prompt if g_scale > 1 else None
                pipe.transformer.load_state_dict(student_state, strict=True)
                pipe.transformer.eval()

                # One-time diagnostics on the very first clip.
                if idx == 0 and seed == args.seeds[0]:
                    _verify_state_dict_swap(pipe, student_state, f"student/{mode}")

                # _gen already runs under torch.inference_mode(); no extra
                # no_grad wrapper is needed here.
                frames = _gen(pipe, prompt, neg, seed,
                              guidance_scale=g_scale,
                              num_inference_steps=1,
                              device=device, **gen_kwargs)

                if idx == 0 and seed == args.seeds[0]:
                    _debug_frames(frames, f"student/{mode}")

                path = os.path.join(
                    args.out_dir,
                    f"{idx:02d}_s{seed}_{p_slug}_student_1step_{mode}.mp4",
                )
                frames_to_mp4(frames, path, fps=args.fps)
                print(f" [student/{mode:6s}] seed={seed} scale={g_scale} → {path}")

            # ---- Teacher: multi-step baseline with original weights. ----
            for t_mode in args.teacher_modes:
                g_scale = teacher_guidance[t_mode]
                neg = args.negative_prompt if g_scale > 1 else None
                pipe.transformer.load_state_dict(teacher_state, strict=True)
                pipe.transformer.eval()

                if idx == 0 and seed == args.seeds[0]:
                    _verify_state_dict_swap(pipe, teacher_state, f"teacher/{t_mode}")

                frames = _gen(pipe, prompt, neg, seed,
                              guidance_scale=g_scale,
                              num_inference_steps=args.teacher_steps,
                              device=device, **gen_kwargs)

                if idx == 0 and seed == args.seeds[0]:
                    _debug_frames(frames, f"teacher/{t_mode}")

                path = os.path.join(
                    args.out_dir,
                    f"{idx:02d}_s{seed}_{p_slug}_teacher_{args.teacher_steps}step_{t_mode}.mp4",
                )
                frames_to_mp4(frames, path, fps=args.fps)
                print(f" [teacher/{t_mode:6s}] seed={seed} scale={g_scale} "
                      f"steps={args.teacher_steps} → {path}")

    print(f"\n[eval] Done. Results in {args.out_dir}")
    _print_guide(args)
|
|
|
|
| def _print_guide(args): |
| print(f""" |
| ╔══════════════════════════════════════════════════════════════╗ |
| ║ Interpretation guide ║ |
| ╠══════════════════════════════════════════════════════════════╣ |
| ║ student_1step_cfg : 1-step, guidance_scale={args.eval_cfg_scale:<4} ║ |
| ║ (verified working student mode) ║ |
| ║ student_1step_baked : 1-step, guidance_scale=1 ║ |
| ║ (for students trained with CFG KD) ║ |
| ║ teacher_{args.teacher_steps}step_cfg : {args.teacher_steps}-step, guidance_scale={args.eval_cfg_scale:<4} ║ |
| ║ (verified working teacher mode) ║ |
| ╠══════════════════════════════════════════════════════════════╣ |
| ║ NOTE: no_cfg (guidance_scale=1) is NOT a valid baseline ║ |
| ║ for this URSA checkpoint. Use --allow_bad_nocfg to override.║ |
| ╚══════════════════════════════════════════════════════════════╝""") |
|
|
|
|
# Script entry point.
if __name__ == "__main__":
    main()
|
|