# World_Model / URSA / scripts / eval_distill_dimo.py
# (Hugging Face file-page residue: uploaded by BryanW via the
#  upload-large-folder tool, commit 2ee4cd6, verified.)
#!/usr/bin/env python3
# Copyright (c) 2024-present, BAAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------
"""Evaluation script for distill_dimo checkpoints.
Generates videos from both the student (1-step) and teacher (multi-step)
using checkpoints saved by train_distill_dimo.py.
Verified native inference regime (from A/B testing — ground truth):
height=320, width=512, num_frames=49, guidance_scale=7, teacher_steps=50.
no_cfg (guidance_scale=1) does NOT produce valid output for this URSA
checkpoint — outputs are blank or blurry.
Student generation modes
------------------------
cfg : 1-step, guidance_scale=7 (2× forward, inference-time CFG)
Teacher generation modes
------------------------
cfg : 50-step, guidance_scale=7 (official working regime)
Usage:
python scripts/eval_distill_dimo.py \
--teacher_ckpt /gfs/space/private/fengzl/World_Model/URSA-1.7B \
--student_ckpt ./experiments/distill_dimo_v3/checkpoints/checkpoint-200/student.pt \
--out_dir ./outputs/eval_distill_v3_200steps_49frames
"""
import argparse
import os
import sys
import numpy as np
import torch
# Make the repository root importable so the `diffnext` package resolves
# when this script is launched directly from the scripts/ directory.
_REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if _REPO_ROOT not in sys.path:
    sys.path.insert(0, _REPO_ROOT)
from diffnext.pipelines import URSAPipeline
from diffnext.pipelines.ursa.pipeline_ursa_distill_dimo import (
    VERIFIED_NATIVE_DEFAULTS,
    check_verified_regime,
)
from diffnext.utils import export_to_video
# ---------------------------------------------------------------------------
# Default prompts and seeds
# ---------------------------------------------------------------------------
DEFAULT_PROMPTS = [
    "a lone grizzly bear walks through a misty forest at dawn, sunlight catching its fur.",
    "beautiful fireworks in the sky with red, white and blue.",
    "a wave crashes on a rocky shoreline at sunset, slow motion.",
    "a hummingbird hovers in front of a red flower, wings a blur.",
    "timelapse of clouds rolling over mountain peaks.",
    "a neon-lit city street at night with rain-soaked reflections.",
    "a kitten playing with a ball of yarn on a wooden floor.",
    "astronaut floating weightlessly inside a space station.",
]
# Official URSA negative prompt (from README / app scripts)
DEFAULT_NEGATIVE_PROMPT = (
    "worst quality, low quality, inconsistent motion, static, still, "
    "blurry, jittery, distorted, ugly"
)
DEFAULT_SEEDS = [0, 1, 2, 3]
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
def parse_args():
    """Build and parse the evaluation command line.

    Returns an ``argparse.Namespace`` with checkpoint paths, generation
    geometry, student/teacher mode selections, and device options.
    """
    parser = argparse.ArgumentParser(
        description="Evaluate distill_dimo student (1-step) vs teacher (multi-step)"
    )
    # Required checkpoint locations.
    parser.add_argument(
        "--teacher_ckpt", required=True,
        help="URSA diffusers pipeline directory (teacher weights)")
    parser.add_argument(
        "--student_ckpt", required=True,
        help="student.pt from train_distill_dimo.py checkpoint")
    parser.add_argument("--out_dir", default="./outputs/eval_distill")
    # Geometry — verified native: 320×512×49 (from A/B testing)
    parser.add_argument("--num_frames", type=int, default=49)
    parser.add_argument("--height", type=int, default=320)
    parser.add_argument("--width", type=int, default=512)
    parser.add_argument("--fps", type=int, default=12)
    # Student generation — default: cfg only (no_cfg is known to fail)
    parser.add_argument(
        "--student_modes", nargs="+", default=["cfg"],
        choices=["no_cfg", "cfg", "baked"],
        help="Student generation modes to evaluate. "
             "Default: ['cfg']. no_cfg is known to produce blank/blurry "
             "output for this checkpoint.")
    parser.add_argument(
        "--eval_cfg_scale", type=float, default=7.0,
        help="Guidance scale for 'cfg' mode (verified working value=7)")
    # Teacher generation — default: cfg only (no_cfg is known to fail)
    parser.add_argument(
        "--teacher_modes", nargs="+", default=["cfg"],
        choices=["no_cfg", "cfg"],
        help="Teacher generation modes. Default: ['cfg']. "
             "no_cfg is NOT a valid baseline for this URSA checkpoint.")
    parser.add_argument(
        "--teacher_steps", type=int, default=50,
        help="Number of inference steps for teacher (verified default=50)")
    # Shared generation params (match verified official defaults)
    parser.add_argument(
        "--guidance_trunc", type=float, default=0.9,
        help="Truncation threshold for inference CFG")
    parser.add_argument(
        "--negative_prompt", type=str, default=DEFAULT_NEGATIVE_PROMPT,
        help="Negative prompt for CFG (official URSA uses one)")
    parser.add_argument("--max_prompt_length", type=int, default=320)
    parser.add_argument("--vae_batch_size", type=int, default=1)
    # Safety override for no_cfg
    parser.add_argument(
        "--allow_bad_nocfg", action="store_true", default=False,
        help="Suppress the no_cfg warning/block. Use at your own risk.")
    # Data
    parser.add_argument(
        "--prompt_file", default=None,
        help="Text file with one prompt per line (overrides defaults)")
    parser.add_argument("--seeds", nargs="*", type=int, default=DEFAULT_SEEDS)
    # Device
    parser.add_argument("--device", type=int, default=0)
    parser.add_argument("--mixed_precision", default="bf16",
                        choices=["fp16", "bf16", "fp32"])
    return parser.parse_args()
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def slug(text: str, max_len: int = 40) -> str:
    """Turn a prompt into a filesystem-safe lowercase identifier.

    Keeps only alphanumerics and spaces, joins words with underscores,
    truncates to *max_len*, and falls back to "prompt" when nothing remains.
    """
    kept = []
    for ch in text.lower():
        if ch.isalnum() or ch == " ":
            kept.append(ch)
    words = "".join(kept).split()
    slugged = "_".join(words)[:max_len]
    return slugged if slugged else "prompt"
def frames_to_mp4(frames, path: str, fps: int = 12):
    """Encode *frames* to an mp4 at *path*, creating parent directories."""
    target_dir = os.path.dirname(os.path.abspath(path))
    os.makedirs(target_dir, exist_ok=True)
    # export_to_video expects a list of frames, not a stacked 4-D array.
    if isinstance(frames, np.ndarray) and frames.ndim == 4:
        frames = [frame for frame in frames]
    export_to_video(frames, output_video_path=path, fps=fps)
def _extract_frames(frames_output):
"""Normalise pipeline output → list of uint8 numpy arrays [H, W, 3]."""
if isinstance(frames_output, np.ndarray):
frames_output = frames_output[0] if frames_output.ndim == 5 else frames_output
frames = list(frames_output)
elif isinstance(frames_output, list):
frames = [np.array(f) if not isinstance(f, np.ndarray) else f
for f in frames_output]
else:
raise TypeError(f"Unexpected frames type: {type(frames_output)}")
result = []
for f in frames:
if f.dtype != np.uint8:
f = ((f * 255).clip(0, 255).astype(np.uint8)
if f.max() <= 1.0 else f.astype(np.uint8))
result.append(f)
return result
def _gen(pipe, prompt, negative_prompt, seed, num_frames, height, width,
         guidance_scale, num_inference_steps, guidance_trunc,
         max_prompt_length, vae_batch_size, device):
    """Single generation call, returns list of uint8 frames."""
    # Seed a per-call generator so each (prompt, seed) pair is reproducible.
    rng = torch.Generator(device=device).manual_seed(seed)
    call_kwargs = dict(
        prompt=prompt,
        negative_prompt=negative_prompt,
        height=height,
        width=width,
        num_frames=num_frames,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        guidance_trunc=guidance_trunc,
        max_prompt_length=max_prompt_length,
        vae_batch_size=vae_batch_size,
        output_type="np",
        generator=rng,
    )
    with torch.inference_mode():
        output = pipe(**call_kwargs)
    return _extract_frames(output.frames)
def _debug_pipeline(pipe, label=""):
    """Print diagnostic info about the pipeline state.

    Dumps the scheduler/transformer/vae classes, the scheduler config, and
    codebook-embedding tensor metadata so a mis-loaded checkpoint (a common
    cause of blank/blurry output) is visible before generation starts.
    """
    print(f"\n{'='*60}")
    print(f" Pipeline diagnostics {label}")
    print(f"{'='*60}")
    print(f" scheduler class : {type(pipe.scheduler).__name__}")
    print(f" scheduler type : {type(pipe.scheduler)}")
    if hasattr(pipe.scheduler, 'config'):
        print(f" scheduler.config : {dict(pipe.scheduler.config)}")
    if hasattr(pipe.scheduler, 'path'):
        print(f" scheduler.path : {type(pipe.scheduler.path).__name__}")
        # path.emb presumably holds the codebook embeddings loaded from
        # scheduler_model.pth — TODO confirm against the scheduler class.
        if hasattr(pipe.scheduler.path, 'emb'):
            emb = pipe.scheduler.path.emb
            print(f" path.emb shape : {emb.shape}")
            print(f" path.emb device : {emb.device}")
            print(f" path.emb dtype : {emb.dtype}")
    else:
        # No .path attribute at all → scheduler weights were never loaded.
        print(f" scheduler.path : MISSING (scheduler not fully loaded!)")
    print(f" codebook_size : {getattr(pipe.scheduler, 'codebook_size', 'N/A')}")
    print(f" transformer class : {type(pipe.transformer).__name__}")
    print(f" transformer device : {next(pipe.transformer.parameters()).device}")
    print(f" vae class : {type(pipe.vae).__name__}")
    if hasattr(pipe, 'image_processor'):
        print(f" image_processor : {type(pipe.image_processor).__name__}")
    print(f"{'='*60}\n")
def _debug_frames(frames, label=""):
"""Print diagnostic info about generated frames."""
if not frames:
print(f" [{label}] No frames generated!")
return
f0 = frames[0]
print(f" [{label}] n_frames={len(frames)} shape={f0.shape} "
f"dtype={f0.dtype} min={f0.min()} max={f0.max()}")
def _verify_state_dict_swap(pipe, state_dict, label=""):
"""Verify transformer weights actually changed after load_state_dict."""
sample_key = next(iter(state_dict.keys()))
loaded_val = state_dict[sample_key].flatten()[:8]
current_val = pipe.transformer.state_dict()[sample_key].flatten()[:8]
match = torch.allclose(loaded_val.cpu().float(), current_val.cpu().float(), atol=1e-6)
print(f" [{label}] state_dict match for '{sample_key}': {match}")
if not match:
print(f" loaded : {loaded_val[:4]}")
print(f" current : {current_val[:4]}")
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main():
    """Entry point: generate student (1-step) and teacher (multi-step) videos.

    For each prompt × seed, the transformer weights are hot-swapped between
    the student state_dict and a saved copy of the teacher state_dict,
    reusing a single pipeline instance; each result is written to an mp4
    under ``args.out_dir``.
    """
    args = parse_args()
    # Map the CLI precision flag onto the torch dtype used for the pipeline.
    dtype_map = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float32}
    dtype = dtype_map[args.mixed_precision]
    device = (torch.device("cuda", args.device)
              if torch.cuda.is_available() else torch.device("cpu"))
    os.makedirs(args.out_dir, exist_ok=True)
    # -- Verified regime check --------------------------------------------
    # Informational only: compares the requested geometry/CFG settings
    # against the A/B-tested native regime (see module docstring).
    is_native = check_verified_regime(
        height=args.height,
        width=args.width,
        num_frames=args.num_frames,
        guidance_scale=args.eval_cfg_scale,
        teacher_steps=args.teacher_steps,
        label="eval",
    )
    print(f"[eval] verified_native_regime={is_native}")
    print(f"[eval] geometry=({args.num_frames},{args.height},{args.width}), "
          f"guidance_scale={args.eval_cfg_scale}, teacher_steps={args.teacher_steps}")
    # -- no_cfg safety gate -----------------------------------------------
    # NOTE(review): despite the --allow_bad_nocfg help text ("warning/block"),
    # both branches only warn — no_cfg is never actually blocked here.
    all_modes = list(args.student_modes) + list(args.teacher_modes)
    if "no_cfg" in all_modes:
        if args.allow_bad_nocfg:
            print("[WARN] no_cfg is known to fail for this URSA checkpoint. "
                  "Outputs may be blank or blurry. Proceeding because --allow_bad_nocfg is set.")
        else:
            print("[WARN] no_cfg is known to fail for this URSA checkpoint. "
                  "Outputs may be blank or blurry. "
                  "Pass --allow_bad_nocfg to override this warning.")
    # -- Load prompts -----------------------------------------------------
    # One prompt per line; blank lines and '#'-prefixed lines are skipped.
    if args.prompt_file:
        with open(args.prompt_file, encoding="utf-8") as f:
            prompts = [l.strip() for l in f if l.strip() and not l.startswith("#")]
    else:
        prompts = DEFAULT_PROMPTS
    print(f"[eval] {len(prompts)} prompts × {len(args.seeds)} seeds "
          f"| student modes={args.student_modes} "
          f"| teacher modes={args.teacher_modes}")
    print(f"[eval] guidance_scale={args.eval_cfg_scale} "
          f"guidance_trunc={args.guidance_trunc} "
          f"teacher_steps={args.teacher_steps}")
    print(f"[eval] negative_prompt='{args.negative_prompt[:60]}...'")
    # -- Load pipeline (teacher) ------------------------------------------
    print(f"[eval] Loading pipeline from {args.teacher_ckpt} …")
    # (Change #2) Try to enable Flash Attention 2.
    try:
        pipe = URSAPipeline.from_pretrained(
            args.teacher_ckpt,
            torch_dtype=dtype,
            trust_remote_code=True,
            attn_implementation="flash_attention_2"
        ).to(device)
    except Exception:
        # If the environment does not support FA2, fall back to the default.
        pipe = URSAPipeline.from_pretrained(
            args.teacher_ckpt, torch_dtype=dtype, trust_remote_code=True
        ).to(device)
    # Slicing/tiling trade quality/speed for memory; explicitly disabled here.
    if hasattr(pipe.vae, "disable_slicing"):
        pipe.vae.disable_slicing()
    if hasattr(pipe.vae, "disable_tiling"):
        pipe.vae.disable_tiling()
    # print("[eval] Compiling transformer (this takes ~2 mins for the first time)...")
    # pipe.transformer = torch.compile(pipe.transformer, mode="reduce-overhead")
    # Diagnostic: verify scheduler loaded correctly
    _debug_pipeline(pipe, label="after from_pretrained + .to(device)")
    # CRITICAL CHECK: scheduler must have .path with embeddings
    if not hasattr(pipe.scheduler, 'path') or pipe.scheduler.path is None:
        print("[ERROR] Scheduler path not loaded! This will cause blurry output.")
        print("[ERROR] The scheduler needs scheduler_model.pth with codebook embeddings.")
        return
    if pipe.scheduler.codebook_size == 0:
        print("[ERROR] codebook_size=0 — scheduler not properly initialized!")
        return
    # Save teacher state for switching back after student inference.
    # Tensors are cloned because load_state_dict below would otherwise
    # overwrite the only copy in place.
    teacher_state = {k: v.clone() for k, v in pipe.transformer.state_dict().items()}
    # -- Load student checkpoint ------------------------------------------
    print(f"[eval] Loading student weights from {args.student_ckpt} …")
    student_state = torch.load(
        args.student_ckpt, map_location=device, weights_only=True
    )
    print(f"[eval] student state_dict keys: {len(student_state)} "
          f"sample key: {next(iter(student_state.keys()))}")
    # Common kwargs for every pipeline call
    gen_kwargs = dict(
        num_frames=args.num_frames,
        height=args.height,
        width=args.width,
        guidance_trunc=args.guidance_trunc,
        max_prompt_length=args.max_prompt_length,
        vae_batch_size=args.vae_batch_size,
    )
    # Mode → guidance_scale mapping
    student_guidance = {
        "no_cfg": 1.0,
        "cfg": args.eval_cfg_scale,
        "baked": 1.0,  # CFG assumed distilled ("baked") into the weights
    }
    teacher_guidance = {
        "no_cfg": 1.0,
        "cfg": args.eval_cfg_scale,
    }
    # -- Evaluation loop --------------------------------------------------
    for idx, prompt in enumerate(prompts):
        p_slug = slug(prompt)
        print(f"\n[{idx+1}/{len(prompts)}] {prompt[:70]}")
        for seed in args.seeds:
            # ---- Student: 1-step generation -----------------------------
            for mode in args.student_modes:
                g_scale = student_guidance[mode]
                # Negative prompt only matters when CFG is active (scale > 1).
                neg = args.negative_prompt if g_scale > 1 else None
                pipe.transformer.load_state_dict(student_state, strict=True)
                pipe.transformer.eval()
                # Extra diagnostics only on the first prompt/seed combination.
                if idx == 0 and seed == args.seeds[0]:
                    _verify_state_dict_swap(pipe, student_state, f"student/{mode}")
                with torch.no_grad():
                    frames = _gen(pipe, prompt, neg, seed,
                                  guidance_scale=g_scale,
                                  num_inference_steps=1,
                                  device=device, **gen_kwargs)
                if idx == 0 and seed == args.seeds[0]:
                    _debug_frames(frames, f"student/{mode}")
                path = os.path.join(
                    args.out_dir,
                    f"{idx:02d}_s{seed}_{p_slug}_student_1step_{mode}.mp4",
                )
                frames_to_mp4(frames, path, fps=args.fps)
                # NOTE(review): no separator between {g_scale} and {path}.
                print(f" [student/{mode:6s}] seed={seed} scale={g_scale}{path}")
            # ---- Teacher: multi-step reference --------------------------
            for t_mode in args.teacher_modes:
                g_scale = teacher_guidance[t_mode]
                neg = args.negative_prompt if g_scale > 1 else None
                # Restore the saved teacher weights before the reference run.
                pipe.transformer.load_state_dict(teacher_state, strict=True)
                pipe.transformer.eval()
                if idx == 0 and seed == args.seeds[0]:
                    _verify_state_dict_swap(pipe, teacher_state, f"teacher/{t_mode}")
                with torch.no_grad():
                    frames = _gen(pipe, prompt, neg, seed,
                                  guidance_scale=g_scale,
                                  num_inference_steps=args.teacher_steps,
                                  device=device, **gen_kwargs)
                if idx == 0 and seed == args.seeds[0]:
                    _debug_frames(frames, f"teacher/{t_mode}")
                path = os.path.join(
                    args.out_dir,
                    f"{idx:02d}_s{seed}_{p_slug}_teacher_{args.teacher_steps}step_{t_mode}.mp4",
                )
                frames_to_mp4(frames, path, fps=args.fps)
                # NOTE(review): no separator between {teacher_steps} and {path}.
                print(f" [teacher/{t_mode:6s}] seed={seed} scale={g_scale} "
                      f"steps={args.teacher_steps}{path}")
    print(f"\n[eval] Done. Results in {args.out_dir}")
    _print_guide(args)
def _print_guide(args):
    """Print a box-drawn legend mapping output-filename suffixes to modes."""
    print(f"""
╔══════════════════════════════════════════════════════════════╗
║ Interpretation guide ║
╠══════════════════════════════════════════════════════════════╣
║ student_1step_cfg : 1-step, guidance_scale={args.eval_cfg_scale:<4}
║ (verified working student mode) ║
║ student_1step_baked : 1-step, guidance_scale=1 ║
║ (for students trained with CFG KD) ║
║ teacher_{args.teacher_steps}step_cfg : {args.teacher_steps}-step, guidance_scale={args.eval_cfg_scale:<4}
║ (verified working teacher mode) ║
╠══════════════════════════════════════════════════════════════╣
║ NOTE: no_cfg (guidance_scale=1) is NOT a valid baseline ║
║ for this URSA checkpoint. Use --allow_bad_nocfg to override.║
╚══════════════════════════════════════════════════════════════╝""")
# Script entry point.
if __name__ == "__main__":
    main()