#!/usr/bin/env python3
"""A/B test: official URSA inference vs eval_distill_dimo inference.

This script runs the EXACT same pipeline call in two ways:
  A) "official" — follows README Quick Start verbatim
  B) "eval"     — follows eval_distill_dimo.py logic

Both use the same pipeline instance, same prompt, same seed.
Saves side-by-side outputs + prints every intermediate diagnostic.

Usage:
    python scripts/ab_test_inference.py \
        --model /gfs/space/private/fengzl/World_Model/URSA-1.7B \
        --device 0

This will generate:
    outputs/ab_test/A_official_t2v_320x512_49f_50step.mp4
    outputs/ab_test/A_official_t2i_320x512.jpg
    outputs/ab_test/B_<HxW>_<steps>step_{nocfg,cfg7}.mp4
    outputs/ab_test/C_student_*.mp4   (if --student_ckpt given)
"""

import argparse
import os
import sys

import numpy as np
import torch

# Make the repo root importable when the script is run from scripts/.
_REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if _REPO_ROOT not in sys.path:
    sys.path.insert(0, _REPO_ROOT)

from diffnext.pipelines import URSAPipeline
from diffnext.utils import export_to_image, export_to_video


def parse_args():
    """Build and parse the CLI arguments for the A/B test."""
    p = argparse.ArgumentParser(description="A/B test URSA inference")
    p.add_argument("--model", required=True, help="URSA model path")
    p.add_argument("--student_ckpt", default=None, help="Optional student.pt")
    p.add_argument("--device", type=int, default=0)
    p.add_argument("--precision", default="float16", choices=["float16", "bfloat16"])
    p.add_argument("--out_dir", default="./outputs/ab_test")
    # Test different resolutions — FSQ320 native is 320x512
    p.add_argument(
        "--test_resolutions",
        nargs="+",
        default=["320x512"],
        help="Resolutions to test as HxW strings (FSQ320 native: 320x512)",
    )
    p.add_argument(
        "--test_steps",
        nargs="+",
        type=int,
        default=[25, 50],
        help="Inference steps to test",
    )
    p.add_argument("--num_frames", type=int, default=49)
    return p.parse_args()


def diag(label, obj):
    """Print diagnostic."""
    print(f" [{label}] {obj}")


def diag_tensor(label, t):
    """Print tensor diagnostics (shape/dtype/device and value range)."""
    if isinstance(t, torch.Tensor):
        print(
            f" [{label}] shape={t.shape} dtype={t.dtype} device={t.device} "
            f"min={t.min().item():.4f} max={t.max().item():.4f} mean={t.mean().item():.4f}"
        )
    elif isinstance(t, np.ndarray):
        print(
            f" [{label}] shape={t.shape} dtype={t.dtype} "
            f"min={t.min()} max={t.max()} mean={t.mean():.2f}"
        )


def diag_pipeline(pipe):
    """Full pipeline diagnostic: dump classes, devices, dtypes, scheduler state."""
    print("\n" + "=" * 70)
    print(" PIPELINE DIAGNOSTICS")
    print("=" * 70)
    print(f" pipeline class : {type(pipe).__name__}")
    print(f" transformer class : {type(pipe.transformer).__name__}")
    print(f" transformer device : {next(pipe.transformer.parameters()).device}")
    print(f" transformer dtype : {next(pipe.transformer.parameters()).dtype}")
    print(f" vae class : {type(pipe.vae).__name__}")
    print(f" vae device : {next(pipe.vae.parameters()).device}")
    print(f" scheduler class : {type(pipe.scheduler).__name__}")
    print(f" scheduler repr : {repr(pipe.scheduler)}")
    sched = pipe.scheduler
    # The scheduler's `path` object carries the flow-matching embedding table;
    # a missing/None path is one suspected cause of blurry output.
    if hasattr(sched, 'path') and sched.path is not None:
        print(f" scheduler.path class: {type(sched.path).__name__}")
        if hasattr(sched.path, 'emb'):
            emb = sched.path.emb
            print(f" path.emb shape : {emb.shape}")
            print(f" path.emb device : {emb.device}")
            print(f" path.emb dtype : {emb.dtype}")
            print(f" path.emb[0,:5] : {emb[0,:5].tolist()}")
        if hasattr(sched.path, 'alpha'):
            print(f" path.alpha : {getattr(sched.path, 'alpha', 'N/A')}")
        if hasattr(sched.path, 'c'):
            print(f" path.c : {getattr(sched.path, 'c', 'N/A')}")
    else:
        print(f" scheduler.path : MISSING or None!")
    print(f" codebook_size : {getattr(sched, 'codebook_size', 'N/A')}")
    print(f" shift : {getattr(sched, 'shift', 'N/A')}")
    if hasattr(sched, 'config'):
        print(f" scheduler.config : {dict(sched.config)}")
    print(f" vae_temporal_stride : {getattr(pipe, 'vae_temporal_stride', 'N/A')}")
    print(f" vae_spatial_stride : {getattr(pipe, 'vae_spatial_stride', 'N/A')}")
    print(f" tokenizer class : {type(pipe.tokenizer).__name__}")
    print("=" * 70 + "\n")


def diag_output(frames_output, label):
    """Diagnose pipeline output (ndarray, list of ndarrays, or other)."""
    print(f"\n --- Output diagnostics: {label} ---")
    if isinstance(frames_output, np.ndarray):
        diag_tensor(f"{label} raw", frames_output)
    elif isinstance(frames_output, list):
        print(f" [{label}] list of {len(frames_output)} items")
        if len(frames_output) > 0:
            f0 = frames_output[0]
            if isinstance(f0, np.ndarray):
                diag_tensor(f"{label}[0]", f0)
            else:
                print(f" [{label}[0]] type={type(f0)}")
    else:
        print(f" [{label}] type={type(frames_output)}")


def save_frames(frames, path, fps=12):
    """Save frames as video (.mp4) or image (.jpg/.png)."""
    os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)
    if path.endswith(".mp4"):
        # A 4-D ndarray is [T, H, W, C]; export_to_video wants a list of frames.
        if isinstance(frames, np.ndarray) and frames.ndim == 4:
            export_to_video(list(frames), output_video_path=path, fps=fps)
        else:
            # Lists (and anything else) are passed through unchanged.
            export_to_video(frames, output_video_path=path, fps=fps)
    elif path.endswith((".jpg", ".png")):
        from PIL import Image
        if isinstance(frames, np.ndarray):
            Image.fromarray(frames).save(path)
        elif hasattr(frames, 'save'):
            frames.save(path)


def _first_video(frames):
    """Return the first video from a pipeline output.

    A 5-D ndarray is [batch, T, H, W, C] → take batch 0; everything else
    (a 4-D [T, H, W, C] array, a list of frames, ...) is returned as-is.
    """
    if isinstance(frames, np.ndarray) and frames.ndim == 5:
        return frames[0]
    return frames


def _export_video(video_frames, path, fps=12):
    """Best-effort save of ``video_frames`` to ``path``; never raises.

    Failures are reported but do not abort the remaining A/B tests.
    """
    try:
        if isinstance(video_frames, np.ndarray):
            export_to_video(list(video_frames), output_video_path=path, fps=fps)
        else:
            export_to_video(video_frames, output_video_path=path, fps=fps)
        print(f" Saved: {path}")
    except Exception as e:
        print(f" Failed to save {path}: {e}")


def main():
    args = parse_args()
    os.makedirs(args.out_dir, exist_ok=True)

    dtype = getattr(torch, args.precision)
    device = (
        torch.device("cuda", args.device)
        if torch.cuda.is_available()
        else torch.device("cpu")
    )

    prompt = "a lone grizzly bear walks through a misty forest at dawn, sunlight catching its fur."
    negative_prompt = "worst quality, low quality, inconsistent motion, static, still, blurry, jittery, distorted, ugly"
    seed = 42

    # =====================================================================
    # Load pipeline
    # =====================================================================
    print(f"\n[1] Loading pipeline from {args.model} ...")
    pipe = URSAPipeline.from_pretrained(
        args.model, torch_dtype=dtype, trust_remote_code=True
    ).to(device)
    diag_pipeline(pipe)

    # =====================================================================
    # Test A: Official README T2V (exact copy from README for FSQ320)
    # FSQ320: height=320, width=512, num_frames=49, steps=50
    # =====================================================================
    print("\n" + "#" * 70)
    print("# TEST A: Official README T2V (FSQ320 native resolution)")
    print("#" * 70)
    gen = torch.Generator(device=device).manual_seed(seed)
    out = pipe(
        prompt=f"motion=9.0, {prompt}",
        negative_prompt=negative_prompt,
        height=320,
        width=512,
        num_frames=49,
        num_inference_steps=50,
        guidance_scale=7,
        generator=gen,
        output_type="np",
    )
    frames = out.frames
    diag_output(frames, "A_official_t2v")
    video_frames = _first_video(frames)
    path_a = os.path.join(args.out_dir, "A_official_t2v_320x512_49f_50step.mp4")
    _export_video(video_frames, path_a)

    # Also test T2I at native resolution (1 frame)
    print("\n# TEST A2: T2I at 320x512 (1 frame)")
    gen = torch.Generator(device=device).manual_seed(seed)
    out = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        height=320,
        width=512,
        num_frames=1,
        num_inference_steps=25,
        guidance_scale=7,
        generator=gen,
    )
    image = out.frames[0]
    path_a2 = os.path.join(args.out_dir, "A_official_t2i_320x512.jpg")
    if hasattr(image, 'save'):
        image.save(path_a2)
        print(f" Saved: {path_a2} (PIL Image)")
    else:
        diag_output(out.frames, "A2_t2i")

    # =====================================================================
    # Test B: Different resolutions and step counts for video
    # =====================================================================
    for res_str in args.test_resolutions:
        h, w = map(int, res_str.split("x"))
        for steps in args.test_steps:
            for gs_label, gs_val in [("nocfg", 1.0), ("cfg7", 7.0)]:
                label = f"B_{h}x{w}_{steps}step_{gs_label}"
                print(f"\n{'#' * 70}")
                print(f"# TEST {label}")
                print(f"# height={h} width={w} num_frames={args.num_frames}")
                print(f"# steps={steps} guidance_scale={gs_val}")
                print(f"{'#' * 70}")
                gen = torch.Generator(device=device).manual_seed(seed)
                # CFG disabled (scale <= 1) needs no negative prompt.
                neg = negative_prompt if gs_val > 1 else None
                # Print scheduler state before call
                print(f" scheduler.codebook_size = {pipe.scheduler.codebook_size}")
                print(f" scheduler.path type = {type(pipe.scheduler.path).__name__}")
                out = pipe(
                    prompt=prompt,
                    negative_prompt=neg,
                    height=h,
                    width=w,
                    num_frames=args.num_frames,
                    num_inference_steps=steps,
                    guidance_scale=gs_val,
                    guidance_trunc=0.9,
                    max_prompt_length=320,
                    vae_batch_size=1,
                    output_type="np",
                    generator=gen,
                )
                frames = out.frames
                diag_output(frames, label)
                # For video output (num_frames > 1), frames is [batch, T, H, W, 3]
                video_frames = _first_video(frames)
                path = os.path.join(args.out_dir, f"{label}.mp4")
                _export_video(video_frames, path)

    # =====================================================================
    # Test C: Student (if provided)
    # =====================================================================
    if args.student_ckpt:
        print(f"\n{'#' * 70}")
        print(f"# TEST C: Student 1-step")
        print(f"{'#' * 70}")
        # Keep a clone of the teacher weights so we can restore them afterwards.
        teacher_state = {k: v.clone() for k, v in pipe.transformer.state_dict().items()}
        student_state = torch.load(args.student_ckpt, map_location=device, weights_only=True)
        print(f" student keys: {len(student_state)}")
        print(f" teacher keys: {len(teacher_state)}")
        # Check key compatibility; strict load below will still raise on a
        # mismatch — these warnings explain why before it does.
        missing = set(teacher_state.keys()) - set(student_state.keys())
        extra = set(student_state.keys()) - set(teacher_state.keys())
        if missing:
            print(f" WARNING: {len(missing)} keys in teacher but not student: {list(missing)[:5]}")
        if extra:
            print(f" WARNING: {len(extra)} keys in student but not teacher: {list(extra)[:5]}")
        pipe.transformer.load_state_dict(student_state, strict=True)
        pipe.transformer.eval()

        for res_str in args.test_resolutions[:1]:  # Just first resolution
            h, w = map(int, res_str.split("x"))
            for gs_label, gs_val in [("nocfg", 1.0), ("cfg7", 7.0)]:
                label = f"C_student_{h}x{w}_1step_{gs_label}"
                gen = torch.Generator(device=device).manual_seed(seed)
                neg = negative_prompt if gs_val > 1 else None
                out = pipe(
                    prompt=prompt,
                    negative_prompt=neg,
                    height=h,
                    width=w,
                    num_frames=args.num_frames,
                    num_inference_steps=1,
                    guidance_scale=gs_val,
                    guidance_trunc=0.9,
                    max_prompt_length=320,
                    vae_batch_size=1,
                    output_type="np",
                    generator=gen,
                )
                frames = out.frames
                diag_output(frames, label)
                video_frames = _first_video(frames)
                path = os.path.join(args.out_dir, f"{label}.mp4")
                _export_video(video_frames, path)

        # Restore teacher
        pipe.transformer.load_state_dict(teacher_state, strict=True)

    print(f"\n[DONE] All outputs in {args.out_dir}")
    print("\nCheck these files to diagnose blurriness:")
    # NOTE: filename below must match what Test A2 actually writes (320x512).
    print(" - A_official_t2i_320x512.jpg → should be sharp (official T2I)")
    print(" - B_*_cfg7.mp4 → teacher video with CFG")
    print(" - B_*_nocfg.mp4 → teacher video without CFG")
    print(" - Compare different resolutions and step counts")
    print(" - If ALL are blurry, the issue is in pipeline/scheduler/VAE loading")
    print(" - If only low-res are blurry, it's a resolution issue")
    print(" - If only low-step are blurry, need more steps")


if __name__ == "__main__":
    main()