# World_Model/URSA/scripts/ab_test_inference.py
# (HuggingFace upload metadata removed: uploaded by BryanW via the
#  upload-large-folder tool, revision 2ee4cd6)
#!/usr/bin/env python3
"""A/B test: official URSA inference vs eval_distill_dimo inference.
This script runs the EXACT same pipeline call in two ways:
A) "official" — follows README Quick Start verbatim
B) "eval" — follows eval_distill_dimo.py logic
Both use the same pipeline instance, same prompt, same seed.
Saves side-by-side outputs + prints every intermediate diagnostic.
Usage:
python scripts/ab_test_inference.py \
--model /gfs/space/private/fengzl/World_Model/URSA-1.7B \
--device 0
This will generate (under --out_dir, default ./outputs/ab_test):
    A_official_t2v_320x512_49f_50step.mp4
    A_official_t2i_320x512.jpg
    B_<HxW>_<steps>step_{nocfg,cfg7}.mp4
    C_student_<HxW>_1step_{nocfg,cfg7}.mp4  (if --student_ckpt given)
"""
import argparse
import os
import sys
import numpy as np
import torch
_REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if _REPO_ROOT not in sys.path:
sys.path.insert(0, _REPO_ROOT)
from diffnext.pipelines import URSAPipeline
from diffnext.utils import export_to_image, export_to_video
def parse_args():
    """Build the CLI parser and return the parsed arguments."""
    parser = argparse.ArgumentParser(description="A/B test URSA inference")
    parser.add_argument("--model", required=True, help="URSA model path")
    parser.add_argument("--student_ckpt", default=None, help="Optional student.pt")
    parser.add_argument("--device", type=int, default=0)
    parser.add_argument("--precision", default="float16", choices=["float16", "bfloat16"])
    parser.add_argument("--out_dir", default="./outputs/ab_test")
    # Test different resolutions — FSQ320 native is 320x512.
    parser.add_argument(
        "--test_resolutions",
        nargs="+",
        default=["320x512"],
        help="Resolutions to test as HxW strings (FSQ320 native: 320x512)",
    )
    parser.add_argument(
        "--test_steps",
        nargs="+",
        type=int,
        default=[25, 50],
        help="Inference steps to test",
    )
    parser.add_argument("--num_frames", type=int, default=49)
    return parser.parse_args()
def diag(label, obj):
    """Emit one labelled diagnostic line to stdout."""
    line = " [{}] {}".format(label, obj)
    print(line)
def diag_tensor(label, t):
    """Print shape/dtype/range statistics for a torch tensor or numpy array.

    Anything that is neither a ``torch.Tensor`` nor an ``np.ndarray`` is
    silently ignored.
    """
    if isinstance(t, torch.Tensor):
        stats = f"min={t.min().item():.4f} max={t.max().item():.4f} mean={t.mean().item():.4f}"
        print(f" [{label}] shape={t.shape} dtype={t.dtype} device={t.device} {stats}")
    elif isinstance(t, np.ndarray):
        stats = f"min={t.min()} max={t.max()} mean={t.mean():.2f}"
        print(f" [{label}] shape={t.shape} dtype={t.dtype} {stats}")
def diag_pipeline(pipe):
    """Full pipeline diagnostic.

    Prints the class/device/dtype of the transformer and VAE, the scheduler
    configuration (including its ``path`` attribute and that path's ``emb``
    table if present), and the pipeline's VAE stride attributes, so that two
    runs can be compared line-by-line.
    """
    print("\n" + "=" * 70)
    print(" PIPELINE DIAGNOSTICS")
    print("=" * 70)
    print(f" pipeline class : {type(pipe).__name__}")
    print(f" transformer class : {type(pipe.transformer).__name__}")
    # Device/dtype are read off the first parameter of each module.
    print(f" transformer device : {next(pipe.transformer.parameters()).device}")
    print(f" transformer dtype : {next(pipe.transformer.parameters()).dtype}")
    print(f" vae class : {type(pipe.vae).__name__}")
    print(f" vae device : {next(pipe.vae.parameters()).device}")
    print(f" scheduler class : {type(pipe.scheduler).__name__}")
    print(f" scheduler repr : {repr(pipe.scheduler)}")
    sched = pipe.scheduler
    # ``path`` is an optional scheduler sub-object; all of its attributes
    # are probed defensively with hasattr/getattr.
    if hasattr(sched, 'path') and sched.path is not None:
        print(f" scheduler.path class: {type(sched.path).__name__}")
        if hasattr(sched.path, 'emb'):
            # NOTE(review): emb looks like an embedding table (2-D, indexed
            # as emb[0, :5] below) — confirm against the scheduler source.
            emb = sched.path.emb
            print(f" path.emb shape : {emb.shape}")
            print(f" path.emb device : {emb.device}")
            print(f" path.emb dtype : {emb.dtype}")
            print(f" path.emb[0,:5] : {emb[0,:5].tolist()}")
        if hasattr(sched.path, 'alpha'):
            print(f" path.alpha : {getattr(sched.path, 'alpha', 'N/A')}")
        if hasattr(sched.path, 'c'):
            print(f" path.c : {getattr(sched.path, 'c', 'N/A')}")
    else:
        print(f" scheduler.path : MISSING or None!")
    print(f" codebook_size : {getattr(sched, 'codebook_size', 'N/A')}")
    print(f" shift : {getattr(sched, 'shift', 'N/A')}")
    if hasattr(sched, 'config'):
        print(f" scheduler.config : {dict(sched.config)}")
    print(f" vae_temporal_stride : {getattr(pipe, 'vae_temporal_stride', 'N/A')}")
    print(f" vae_spatial_stride : {getattr(pipe, 'vae_spatial_stride', 'N/A')}")
    print(f" tokenizer class : {type(pipe.tokenizer).__name__}")
    print("=" * 70 + "\n")
def diag_output(frames_output, label):
    """Diagnose a pipeline output that may be an ndarray, a list, or other."""
    print(f"\n --- Output diagnostics: {label} ---")
    if isinstance(frames_output, np.ndarray):
        diag_tensor(f"{label} raw", frames_output)
        return
    if isinstance(frames_output, list):
        print(f" [{label}] list of {len(frames_output)} items")
        if frames_output:
            first = frames_output[0]
            if isinstance(first, np.ndarray):
                diag_tensor(f"{label}[0]", first)
            else:
                print(f" [{label}[0]] type={type(first)}")
        return
    print(f" [{label}] type={type(frames_output)}")
def save_frames(frames, path, fps=12):
    """Save frames to *path* as a video (.mp4) or a single image (.jpg/.png).

    Args:
        frames: np.ndarray ([T, H, W, C] for video, [H, W, C] for image),
            a list of frame arrays, or a PIL image (anything with ``.save``).
        path: Output file path; its extension selects the format.
        fps: Frames per second for video output.
    """
    os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)
    if path.endswith(".mp4"):
        # export_to_video takes a list of frames; unpack a [T, H, W, C] array.
        # (The original list/else branches were identical and are collapsed.)
        if isinstance(frames, np.ndarray) and frames.ndim == 4:
            frames = list(frames)
        export_to_video(frames, output_video_path=path, fps=fps)
    elif path.endswith((".jpg", ".png")):
        from PIL import Image  # local import: only needed for image output
        if isinstance(frames, np.ndarray):
            Image.fromarray(frames).save(path)
        elif hasattr(frames, 'save'):
            frames.save(path)
        else:
            # Previously this fell through silently; surface the problem.
            print(f" [save_frames] cannot save type {type(frames)} as image: {path}")
    else:
        # Previously an unknown extension was silently ignored.
        print(f" [save_frames] unsupported extension, nothing saved: {path}")
def main():
    """Run the full A/B inference matrix and write outputs under ``--out_dir``.

    Test A  : official README T2V call at FSQ320 native resolution (320x512).
    Test A2 : single-frame T2I at the same resolution.
    Test B  : teacher sweep over resolutions x step counts x {no CFG, CFG=7}.
    Test C  : optional 1-step student checkpoint (teacher weights restored
              afterwards so the pipeline is left unmodified).
    """
    args = parse_args()
    os.makedirs(args.out_dir, exist_ok=True)
    dtype = getattr(torch, args.precision)
    device = torch.device("cuda", args.device) if torch.cuda.is_available() else torch.device("cpu")
    prompt = "a lone grizzly bear walks through a misty forest at dawn, sunlight catching its fur."
    negative_prompt = "worst quality, low quality, inconsistent motion, static, still, blurry, jittery, distorted, ugly"
    seed = 42

    def _first_video(frames):
        # Drop a leading batch axis if present: [B, T, H, W, C] -> [T, H, W, C].
        # Lists and lower-rank arrays pass through unchanged.
        if isinstance(frames, np.ndarray) and frames.ndim == 5:
            return frames[0]
        return frames

    def _export_video(video_frames, path):
        # Best-effort export — a failed save must not abort the remaining tests.
        try:
            if isinstance(video_frames, np.ndarray):
                export_to_video(list(video_frames), output_video_path=path, fps=12)
            else:
                export_to_video(video_frames, output_video_path=path, fps=12)
            print(f" Saved: {path}")
        except Exception as e:
            print(f" Failed to save {path}: {e}")

    # =====================================================================
    # Load pipeline
    # =====================================================================
    print(f"\n[1] Loading pipeline from {args.model} ...")
    pipe = URSAPipeline.from_pretrained(
        args.model, torch_dtype=dtype, trust_remote_code=True
    ).to(device)
    diag_pipeline(pipe)

    # =====================================================================
    # Test A: Official README T2V (exact copy from README for FSQ320)
    # FSQ320: height=320, width=512, num_frames=49, steps=50
    # =====================================================================
    print("\n" + "#" * 70)
    print("# TEST A: Official README T2V (FSQ320 native resolution)")
    print("#" * 70)
    gen = torch.Generator(device=device).manual_seed(seed)
    out = pipe(
        prompt=f"motion=9.0, {prompt}",  # README prepends a motion score for T2V
        negative_prompt=negative_prompt,
        height=320,
        width=512,
        num_frames=49,
        num_inference_steps=50,
        guidance_scale=7,
        generator=gen,
        output_type="np",
    )
    frames = out.frames
    diag_output(frames, "A_official_t2v")
    path_a = os.path.join(args.out_dir, "A_official_t2v_320x512_49f_50step.mp4")
    _export_video(_first_video(frames), path_a)

    # Also test T2I at native resolution (1 frame)
    print("\n# TEST A2: T2I at 320x512 (1 frame)")
    gen = torch.Generator(device=device).manual_seed(seed)
    out = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        height=320,
        width=512,
        num_frames=1,
        num_inference_steps=25,
        guidance_scale=7,
        generator=gen,
    )
    # Default output_type here, so frames[0] is expected to be a PIL image.
    image = out.frames[0]
    path_a2 = os.path.join(args.out_dir, "A_official_t2i_320x512.jpg")
    if hasattr(image, 'save'):
        image.save(path_a2)
        print(f" Saved: {path_a2} (PIL Image)")
    else:
        diag_output(out.frames, "A2_t2i")

    # =====================================================================
    # Test B: Different resolutions and step counts for video
    # =====================================================================
    for res_str in args.test_resolutions:
        h, w = map(int, res_str.split("x"))
        for steps in args.test_steps:
            for gs_label, gs_val in [("nocfg", 1.0), ("cfg7", 7.0)]:
                label = f"B_{h}x{w}_{steps}step_{gs_label}"
                print(f"\n{'#' * 70}")
                print(f"# TEST {label}")
                print(f"# height={h} width={w} num_frames={args.num_frames}")
                print(f"# steps={steps} guidance_scale={gs_val}")
                print(f"{'#' * 70}")
                gen = torch.Generator(device=device).manual_seed(seed)
                # CFG needs a negative prompt; omit it when guidance is off.
                neg = negative_prompt if gs_val > 1 else None
                # Print scheduler state before call
                print(f" scheduler.codebook_size = {pipe.scheduler.codebook_size}")
                print(f" scheduler.path type = {type(pipe.scheduler.path).__name__}")
                out = pipe(
                    prompt=prompt,
                    negative_prompt=neg,
                    height=h,
                    width=w,
                    num_frames=args.num_frames,
                    num_inference_steps=steps,
                    guidance_scale=gs_val,
                    guidance_trunc=0.9,
                    max_prompt_length=320,
                    vae_batch_size=1,
                    output_type="np",
                    generator=gen,
                )
                frames = out.frames
                diag_output(frames, label)
                _export_video(_first_video(frames), os.path.join(args.out_dir, f"{label}.mp4"))

    # =====================================================================
    # Test C: Student (if provided)
    # =====================================================================
    if args.student_ckpt:
        print(f"\n{'#' * 70}")
        print(f"# TEST C: Student 1-step")
        print(f"{'#' * 70}")
        # Snapshot teacher weights so they can be restored after the swap.
        teacher_state = {k: v.clone() for k, v in pipe.transformer.state_dict().items()}
        student_state = torch.load(args.student_ckpt, map_location=device, weights_only=True)
        print(f" student keys: {len(student_state)}")
        print(f" teacher keys: {len(teacher_state)}")
        # Check key compatibility before the strict load below.
        missing = set(teacher_state.keys()) - set(student_state.keys())
        extra = set(student_state.keys()) - set(teacher_state.keys())
        if missing:
            print(f" WARNING: {len(missing)} keys in teacher but not student: {list(missing)[:5]}")
        if extra:
            print(f" WARNING: {len(extra)} keys in student but not teacher: {list(extra)[:5]}")
        pipe.transformer.load_state_dict(student_state, strict=True)
        pipe.transformer.eval()
        for res_str in args.test_resolutions[:1]:  # Just first resolution
            h, w = map(int, res_str.split("x"))
            for gs_label, gs_val in [("nocfg", 1.0), ("cfg7", 7.0)]:
                label = f"C_student_{h}x{w}_1step_{gs_label}"
                gen = torch.Generator(device=device).manual_seed(seed)
                neg = negative_prompt if gs_val > 1 else None
                out = pipe(
                    prompt=prompt,
                    negative_prompt=neg,
                    height=h,
                    width=w,
                    num_frames=args.num_frames,
                    num_inference_steps=1,
                    guidance_scale=gs_val,
                    guidance_trunc=0.9,
                    max_prompt_length=320,
                    vae_batch_size=1,
                    output_type="np",
                    generator=gen,
                )
                frames = out.frames
                diag_output(frames, label)
                _export_video(_first_video(frames), os.path.join(args.out_dir, f"{label}.mp4"))
        # Restore teacher
        pipe.transformer.load_state_dict(teacher_state, strict=True)

    print(f"\n[DONE] All outputs in {args.out_dir}")
    print("\nCheck these files to diagnose blurriness:")
    # BUG FIX: the old hint referenced A_official_t2i_1024x1024.jpg, but the
    # file written by Test A2 above is A_official_t2i_320x512.jpg.
    print(" - A_official_t2i_320x512.jpg → should be sharp (official T2I)")
    print(" - B_*_cfg7.mp4 → teacher video with CFG")
    print(" - B_*_nocfg.mp4 → teacher video without CFG")
    print(" - Compare different resolutions and step counts")
    print(" - If ALL are blurry, the issue is in pipeline/scheduler/VAE loading")
    print(" - If only low-res are blurry, it's a resolution issue")
    print(" - If only low-step are blurry, need more steps")
# Script entry point: run the full A/B matrix only when executed directly,
# not when imported as a module.
if __name__ == "__main__":
    main()