#!/usr/bin/env python3
# Copyright (c) 2024-present, BAAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------
"""Evaluation script for distill_dimo checkpoints.

Generates videos from both the student (1-step) and teacher (multi-step)
using checkpoints saved by train_distill_dimo.py.

Verified native inference regime (from A/B testing — ground truth):
    height=320, width=512, num_frames=49, guidance_scale=7, teacher_steps=50.
    no_cfg (guidance_scale=1) does NOT produce valid output for this URSA
    checkpoint — outputs are blank or blurry.

Student generation modes
------------------------
    cfg    : 1-step, guidance_scale=7  (2× forward, inference-time CFG)
    baked  : 1-step, guidance_scale=1  (for students trained with CFG KD)
    no_cfg : 1-step, guidance_scale=1  (known bad; skipped unless
             --allow_bad_nocfg is passed)

Teacher generation modes
------------------------
    cfg    : 50-step, guidance_scale=7  (official working regime)
    no_cfg : guidance_scale=1           (known bad; skipped unless
             --allow_bad_nocfg is passed)

Usage:
    python scripts/eval_distill_dimo.py \\
        --teacher_ckpt /gfs/space/private/fengzl/World_Model/URSA-1.7B \\
        --student_ckpt ./experiments/distill_dimo_v3/checkpoints/checkpoint-200/student.pt \\
        --out_dir ./outputs/eval_distill_v3_200steps_49frames
"""

import argparse
import os
import sys

import numpy as np
import torch

_REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if _REPO_ROOT not in sys.path:
    sys.path.insert(0, _REPO_ROOT)

from diffnext.pipelines import URSAPipeline  # noqa: E402
from diffnext.pipelines.ursa.pipeline_ursa_distill_dimo import (  # noqa: E402
    VERIFIED_NATIVE_DEFAULTS,  # noqa: F401  (kept for callers importing it from here)
    check_verified_regime,
)
from diffnext.utils import export_to_video  # noqa: E402

# ---------------------------------------------------------------------------
# Default prompts and seeds
# ---------------------------------------------------------------------------

DEFAULT_PROMPTS = [
    "a lone grizzly bear walks through a misty forest at dawn, sunlight catching its fur.",
    "beautiful fireworks in the sky with red, white and blue.",
    "a wave crashes on a rocky shoreline at sunset, slow motion.",
    "a hummingbird hovers in front of a red flower, wings a blur.",
    "timelapse of clouds rolling over mountain peaks.",
    "a neon-lit city street at night with rain-soaked reflections.",
    "a kitten playing with a ball of yarn on a wooden floor.",
    "astronaut floating weightlessly inside a space station.",
]

# Official URSA negative prompt (from README / app scripts)
DEFAULT_NEGATIVE_PROMPT = (
    "worst quality, low quality, inconsistent motion, static, still, "
    "blurry, jittery, distorted, ugly"
)

DEFAULT_SEEDS = [0, 1, 2, 3]


# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------

def parse_args():
    """Parse command-line arguments for the student/teacher evaluation run."""
    p = argparse.ArgumentParser(
        description="Evaluate distill_dimo student (1-step) vs teacher (multi-step)"
    )
    p.add_argument("--teacher_ckpt", required=True,
                   help="URSA diffusers pipeline directory (teacher weights)")
    p.add_argument("--student_ckpt", required=True,
                   help="student.pt from train_distill_dimo.py checkpoint")
    p.add_argument("--out_dir", default="./outputs/eval_distill")

    # Geometry — verified native: 320×512×49 (from A/B testing)
    p.add_argument("--num_frames", type=int, default=49)
    p.add_argument("--height", type=int, default=320)
    p.add_argument("--width", type=int, default=512)
    p.add_argument("--fps", type=int, default=12)

    # Student generation — default: cfg only (no_cfg is known to fail)
    p.add_argument("--student_modes", nargs="+", default=["cfg"],
                   choices=["no_cfg", "cfg", "baked"],
                   help="Student generation modes to evaluate. "
                        "Default: ['cfg']. no_cfg is known to produce blank/blurry "
                        "output for this checkpoint.")
    p.add_argument("--eval_cfg_scale", type=float, default=7.0,
                   help="Guidance scale for 'cfg' mode (verified working value=7)")

    # Teacher generation — default: cfg only (no_cfg is known to fail)
    p.add_argument("--teacher_modes", nargs="+", default=["cfg"],
                   choices=["no_cfg", "cfg"],
                   help="Teacher generation modes. Default: ['cfg']. "
                        "no_cfg is NOT a valid baseline for this URSA checkpoint.")
    p.add_argument("--teacher_steps", type=int, default=50,
                   help="Number of inference steps for teacher (verified default=50)")

    # Shared generation params (match verified official defaults)
    p.add_argument("--guidance_trunc", type=float, default=0.9,
                   help="Truncation threshold for inference CFG")
    p.add_argument("--negative_prompt", type=str, default=DEFAULT_NEGATIVE_PROMPT,
                   help="Negative prompt for CFG (official URSA uses one)")
    p.add_argument("--max_prompt_length", type=int, default=320)
    p.add_argument("--vae_batch_size", type=int, default=1)

    # Safety override for no_cfg
    p.add_argument("--allow_bad_nocfg", action="store_true", default=False,
                   help="Suppress the no_cfg warning/block. Use at your own risk.")

    # Data
    p.add_argument("--prompt_file", default=None,
                   help="Text file with one prompt per line (overrides defaults)")
    p.add_argument("--seeds", nargs="*", type=int, default=DEFAULT_SEEDS)

    # Device
    p.add_argument("--device", type=int, default=0)
    p.add_argument("--mixed_precision", default="bf16",
                   choices=["fp16", "bf16", "fp32"])
    return p.parse_args()


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def slug(text: str, max_len: int = 40) -> str:
    """Turn *text* into a lowercase, underscore-joined filename fragment.

    Non-alphanumeric characters (other than spaces) are dropped, whitespace runs
    become single underscores, and the result is cut to *max_len* characters.
    Falls back to ``"prompt"`` when nothing survives.
    """
    s = text.lower()
    s = "".join(c if c.isalnum() or c == " " else "" for c in s)
    s = "_".join(s.split())[:max_len]
    return s or "prompt"


def frames_to_mp4(frames, path: str, fps: int = 12):
    """Write *frames* to an MP4 at *path*, creating parent directories first.

    A 4-D numpy array is split into a list of per-frame arrays before being
    handed to ``export_to_video``.
    """
    os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)
    if isinstance(frames, np.ndarray) and frames.ndim == 4:
        frames = list(frames)
    export_to_video(frames, output_video_path=path, fps=fps)


def _extract_frames(frames_output):
    """Normalise pipeline output → list of uint8 numpy arrays [H, W, 3].

    Accepts a numpy array (5-D: only the first video of the batch is kept;
    otherwise iterated along its first axis) or a list of frames (non-ndarray
    items are converted with ``np.array``).

    Non-uint8 frames are quantised to uint8. They are treated as being in
    [0, 1] (and multiplied by 255) when every such frame peaks at or below
    1.0, otherwise as already being in [0, 255]. The decision is made once per
    clip so all frames of a video are quantised consistently.

    Raises:
        TypeError: if *frames_output* is neither an ndarray nor a list.
    """
    if isinstance(frames_output, np.ndarray):
        if frames_output.ndim == 5:
            frames_output = frames_output[0]
        frames = list(frames_output)
    elif isinstance(frames_output, list):
        frames = [f if isinstance(f, np.ndarray) else np.array(f)
                  for f in frames_output]
    else:
        raise TypeError(f"Unexpected frames type: {type(frames_output)}")

    to_convert = [f for f in frames if f.dtype != np.uint8]
    # Decide the value range for the whole clip: a per-frame test would rescale
    # a near-black [0, 255] frame (max <= 1) by 255 but leave its neighbours alone.
    unit_range = bool(to_convert) and all(f.max() <= 1.0 for f in to_convert)
    scale = 255.0 if unit_range else 1.0

    result = []
    for f in frames:
        if f.dtype != np.uint8:
            # Round rather than truncate, and clip so out-of-range values
            # saturate instead of wrapping around on the uint8 cast.
            f = (f * scale).round().clip(0, 255).astype(np.uint8)
        result.append(f)
    return result


def _gen(pipe, prompt, negative_prompt, seed, num_frames, height, width,
         guidance_scale, num_inference_steps, guidance_trunc,
         max_prompt_length, vae_batch_size, device):
    """Single generation call, returns list of uint8 frames.

    A fresh ``torch.Generator`` on *device* seeded with *seed* is created for
    every call, and the pipeline runs under ``torch.inference_mode``.
    """
    generator = torch.Generator(device=device).manual_seed(seed)
    with torch.inference_mode():
        out = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            height=height,
            width=width,
            num_frames=num_frames,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            guidance_trunc=guidance_trunc,
            max_prompt_length=max_prompt_length,
            vae_batch_size=vae_batch_size,
            output_type="np",
            generator=generator,
        )
    return _extract_frames(out.frames)


def _debug_pipeline(pipe, label=""):
    """Print diagnostic info about the pipeline state."""
    print(f"\n{'='*60}")
    print(f" Pipeline diagnostics {label}")
    print(f"{'='*60}")
    print(f" scheduler class : {type(pipe.scheduler).__name__}")
    print(f" scheduler type : {type(pipe.scheduler)}")
    if hasattr(pipe.scheduler, 'config'):
        print(f" scheduler.config : {dict(pipe.scheduler.config)}")
    if hasattr(pipe.scheduler, 'path'):
        print(f" scheduler.path : {type(pipe.scheduler.path).__name__}")
        if hasattr(pipe.scheduler.path, 'emb'):
            emb = pipe.scheduler.path.emb
            print(f" path.emb shape : {emb.shape}")
            print(f" path.emb device : {emb.device}")
            print(f" path.emb dtype : {emb.dtype}")
    else:
        print(" scheduler.path : MISSING (scheduler not fully loaded!)")
    print(f" codebook_size : {getattr(pipe.scheduler, 'codebook_size', 'N/A')}")
    print(f" transformer class : {type(pipe.transformer).__name__}")
    print(f" transformer device : {next(pipe.transformer.parameters()).device}")
    print(f" vae class : {type(pipe.vae).__name__}")
    if hasattr(pipe, 'image_processor'):
        print(f" image_processor : {type(pipe.image_processor).__name__}")
    print(f"{'='*60}\n")


def _debug_frames(frames, label=""):
    """Print diagnostic info about generated frames."""
    if not frames:
        print(f" [{label}] No frames generated!")
        return
    f0 = frames[0]
    print(f" [{label}] n_frames={len(frames)} shape={f0.shape} "
          f"dtype={f0.dtype} min={f0.min()} max={f0.max()}")


def _verify_state_dict_swap(pipe, state_dict, label=""):
    """Verify transformer weights actually changed after load_state_dict.

    Compares the first 8 elements of the first key of *state_dict* against the
    transformer's current value for that key and prints whether they match.
    """
    sample_key = next(iter(state_dict.keys()))
    loaded_val = state_dict[sample_key].flatten()[:8]
    current_val = pipe.transformer.state_dict()[sample_key].flatten()[:8]
    match = torch.allclose(loaded_val.cpu().float(), current_val.cpu().float(), atol=1e-6)
    print(f" [{label}] state_dict match for '{sample_key}': {match}")
    if not match:
        print(f" loaded : {loaded_val[:4]}")
        print(f" current : {current_val[:4]}")


def _apply_nocfg_gate(args):
    """Enforce the no_cfg safety gate by editing ``args`` in place.

    no_cfg is known to produce blank/blurry output for this checkpoint. Without
    --allow_bad_nocfg, "no_cfg" is removed from ``args.student_modes`` and
    ``args.teacher_modes``. With the flag, the modes are kept and a warning is
    printed.
    """
    requested = list(args.student_modes) + list(args.teacher_modes)
    if "no_cfg" not in requested:
        return
    if args.allow_bad_nocfg:
        print("[WARN] no_cfg is known to fail for this URSA checkpoint. "
              "Outputs may be blank or blurry. Proceeding because --allow_bad_nocfg is set.")
        return
    print("[WARN] no_cfg is known to fail for this URSA checkpoint. "
          "Outputs may be blank or blurry. "
          "Skipping no_cfg modes; pass --allow_bad_nocfg to run them anyway.")
    args.student_modes = [m for m in args.student_modes if m != "no_cfg"]
    args.teacher_modes = [m for m in args.teacher_modes if m != "no_cfg"]


def _load_prompts(prompt_file):
    """Return the prompts from *prompt_file* (one per line), or DEFAULT_PROMPTS.

    Blank lines and lines whose first non-whitespace character is '#' are
    skipped.
    """
    if not prompt_file:
        return DEFAULT_PROMPTS
    with open(prompt_file, encoding="utf-8") as f:
        stripped = (line.strip() for line in f)
        return [line for line in stripped if line and not line.startswith("#")]


def _load_pipeline(ckpt, dtype, device):
    """Load the URSA pipeline onto *device*, preferring Flash Attention 2.

    Only the ``from_pretrained`` call is retried without FA2. Errors raised
    while moving the pipeline to *device* (e.g. out-of-memory) propagate
    instead of triggering a second full load. VAE slicing/tiling is disabled
    when the VAE exposes those switches.
    """
    try:
        pipe = URSAPipeline.from_pretrained(
            ckpt, torch_dtype=dtype, trust_remote_code=True,
            attn_implementation="flash_attention_2",
        )
    except Exception as exc:  # FA2 availability fails in many different ways
        # Fall back to the default attention implementation, but say why.
        print(f"[eval] flash_attention_2 unavailable ({type(exc).__name__}: {exc}); "
              "falling back to default attention.")
        pipe = URSAPipeline.from_pretrained(
            ckpt, torch_dtype=dtype, trust_remote_code=True
        )
    pipe = pipe.to(device)

    if hasattr(pipe.vae, "disable_slicing"):
        pipe.vae.disable_slicing()
    if hasattr(pipe.vae, "disable_tiling"):
        pipe.vae.disable_tiling()
    return pipe


def _scheduler_ok(pipe):
    """Return True if the scheduler has codebook embeddings; print errors otherwise."""
    if getattr(pipe.scheduler, "path", None) is None:
        print("[ERROR] Scheduler path not loaded! This will cause blurry output.")
        print("[ERROR] The scheduler needs scheduler_model.pth with codebook embeddings.")
        return False
    if pipe.scheduler.codebook_size == 0:
        print("[ERROR] codebook_size=0 — scheduler not properly initialized!")
        return False
    return True


def _activate_weights(pipe, states, current, target):
    """Make ``states[target]`` the transformer's weights and return *target*.

    The strict ``load_state_dict`` copy is skipped when *target* is already
    the active set (*current*). This avoids a full-model copy for every
    (prompt, seed, mode) combination.
    """
    if current != target:
        pipe.transformer.load_state_dict(states[target], strict=True)
    pipe.transformer.eval()
    return target


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

def main():
    """Generate student (1-step) and teacher (multi-step) videos for every prompt × seed."""
    args = parse_args()

    dtype_map = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float32}
    dtype = dtype_map[args.mixed_precision]
    device = (torch.device("cuda", args.device) if torch.cuda.is_available()
              else torch.device("cpu"))
    os.makedirs(args.out_dir, exist_ok=True)

    # -- Verified regime check --------------------------------------------
    is_native = check_verified_regime(
        height=args.height,
        width=args.width,
        num_frames=args.num_frames,
        guidance_scale=args.eval_cfg_scale,
        teacher_steps=args.teacher_steps,
        label="eval",
    )
    print(f"[eval] verified_native_regime={is_native}")
    print(f"[eval] geometry=({args.num_frames},{args.height},{args.width}), "
          f"guidance_scale={args.eval_cfg_scale}, teacher_steps={args.teacher_steps}")

    # -- no_cfg safety gate -----------------------------------------------
    _apply_nocfg_gate(args)
    if not args.student_modes and not args.teacher_modes:
        print("[ERROR] No generation modes left to evaluate.")
        sys.exit(1)

    # -- Load prompts -----------------------------------------------------
    prompts = _load_prompts(args.prompt_file)

    print(f"[eval] {len(prompts)} prompts × {len(args.seeds)} seeds "
          f"| student modes={args.student_modes} "
          f"| teacher modes={args.teacher_modes}")
    print(f"[eval] guidance_scale={args.eval_cfg_scale} "
          f"guidance_trunc={args.guidance_trunc} "
          f"teacher_steps={args.teacher_steps}")
    print(f"[eval] negative_prompt='{args.negative_prompt[:60]}...'")

    # -- Load pipeline (teacher) ------------------------------------------
    print(f"[eval] Loading pipeline from {args.teacher_ckpt} …")
    pipe = _load_pipeline(args.teacher_ckpt, dtype, device)

    # Diagnostic: verify the scheduler loaded correctly.
    _debug_pipeline(pipe, label="after from_pretrained + .to(device)")

    # CRITICAL CHECK: the scheduler must carry codebook embeddings. Exit non-zero
    # so shell/CI callers see the failure.
    if not _scheduler_ok(pipe):
        sys.exit(1)

    # Snapshot the teacher weights so they can be restored after student runs.
    # NOTE: clones live on the same device as the model (one extra full copy).
    teacher_state = {k: v.clone() for k, v in pipe.transformer.state_dict().items()}

    # -- Load student checkpoint ------------------------------------------
    print(f"[eval] Loading student weights from {args.student_ckpt} …")
    student_state = torch.load(
        args.student_ckpt, map_location=device, weights_only=True
    )
    print(f"[eval] student state_dict keys: {len(student_state)} "
          f"sample key: {next(iter(student_state.keys()))}")

    # Common kwargs for every pipeline call
    gen_kwargs = dict(
        num_frames=args.num_frames,
        height=args.height,
        width=args.width,
        guidance_trunc=args.guidance_trunc,
        max_prompt_length=args.max_prompt_length,
        vae_batch_size=args.vae_batch_size,
    )

    # Mode → guidance_scale mapping
    student_guidance = {
        "no_cfg": 1.0,
        "cfg": args.eval_cfg_scale,
        "baked": 1.0,
    }
    teacher_guidance = {
        "no_cfg": 1.0,
        "cfg": args.eval_cfg_scale,
    }

    states = {"student": student_state, "teacher": teacher_state}
    active = "teacher"  # from_pretrained left the teacher weights in place

    # -- Evaluation loop --------------------------------------------------
    for idx, prompt in enumerate(prompts):
        p_slug = slug(prompt)
        print(f"\n[{idx+1}/{len(prompts)}] {prompt[:70]}")

        for seed_idx, seed in enumerate(args.seeds):
            # Diagnostics are only printed for the very first (prompt, seed).
            first = idx == 0 and seed_idx == 0

            # ---- Student: 1-step generation -----------------------------
            for mode in args.student_modes:
                g_scale = student_guidance[mode]
                neg = args.negative_prompt if g_scale > 1 else None

                active = _activate_weights(pipe, states, active, "student")
                if first:
                    _verify_state_dict_swap(pipe, student_state, f"student/{mode}")

                frames = _gen(pipe, prompt, neg, seed,
                              guidance_scale=g_scale,
                              num_inference_steps=1,
                              device=device,
                              **gen_kwargs)
                if first:
                    _debug_frames(frames, f"student/{mode}")

                path = os.path.join(
                    args.out_dir,
                    f"{idx:02d}_s{seed}_{p_slug}_student_1step_{mode}.mp4",
                )
                frames_to_mp4(frames, path, fps=args.fps)
                print(f" [student/{mode:6s}] seed={seed} scale={g_scale} → {path}")

            # ---- Teacher: multi-step reference --------------------------
            for t_mode in args.teacher_modes:
                g_scale = teacher_guidance[t_mode]
                neg = args.negative_prompt if g_scale > 1 else None

                active = _activate_weights(pipe, states, active, "teacher")
                if first:
                    _verify_state_dict_swap(pipe, teacher_state, f"teacher/{t_mode}")

                frames = _gen(pipe, prompt, neg, seed,
                              guidance_scale=g_scale,
                              num_inference_steps=args.teacher_steps,
                              device=device,
                              **gen_kwargs)
                if first:
                    _debug_frames(frames, f"teacher/{t_mode}")

                path = os.path.join(
                    args.out_dir,
                    f"{idx:02d}_s{seed}_{p_slug}_teacher_{args.teacher_steps}step_{t_mode}.mp4",
                )
                frames_to_mp4(frames, path, fps=args.fps)
                print(f" [teacher/{t_mode:6s}] seed={seed} scale={g_scale} "
                      f"steps={args.teacher_steps} → {path}")

    print(f"\n[eval] Done. Results in {args.out_dir}")
    _print_guide(args)


def _print_guide(args):
    """Print a boxed legend explaining the output file-name suffixes.

    Rows are padded at runtime so the right border stays aligned regardless
    of the substituted guidance scale / step count.
    """
    width = 62
    scale = f"{args.eval_cfg_scale:<4}"
    steps = args.teacher_steps
    rows = [
        "  Interpretation guide",
        None,  # separator
        f"  student_1step_cfg   : 1-step, guidance_scale={scale}",
        "                        (verified working student mode)",
        "  student_1step_baked : 1-step, guidance_scale=1",
        "                        (for students trained with CFG KD)",
        f"  teacher_{steps}step_cfg : {steps}-step, guidance_scale={scale}",
        "                        (verified working teacher mode)",
        None,  # separator
        "  NOTE: no_cfg (guidance_scale=1) is NOT a valid baseline",
        "  for this URSA checkpoint. Use --allow_bad_nocfg to override.",
    ]
    lines = ["╔" + "═" * width + "╗"]
    for row in rows:
        if row is None:
            lines.append("╠" + "═" * width + "╣")
        else:
            lines.append("║" + row.ljust(width) + "║")
    lines.append("╚" + "═" * width + "╝")
    print("\n" + "\n".join(lines))


if __name__ == "__main__":
    main()