| |
| """A/B test: official URSA inference vs eval_distill_dimo inference. |
| |
| This script runs the EXACT same pipeline call in two ways: |
| A) "official" — follows README Quick Start verbatim |
| B) "eval" — follows eval_distill_dimo.py logic |
| |
| Both use the same pipeline instance, same prompt, same seed. |
| Saves side-by-side outputs + prints every intermediate diagnostic. |
| |
| Usage: |
| python scripts/ab_test_inference.py \ |
| --model /gfs/space/private/fengzl/World_Model/URSA-1.7B \ |
| --device 0 |
| |
This will generate:
    outputs/ab_test/A_official_t2v_320x512_49f_50step.mp4
    outputs/ab_test/A_official_t2i_320x512.jpg
    outputs/ab_test/B_<HxW>_<steps>step_{nocfg,cfg7}.mp4
    outputs/ab_test/C_student_*_1step_{nocfg,cfg7}.mp4 (if --student_ckpt given)
| """ |
|
|
| import argparse |
| import os |
| import sys |
|
|
| import numpy as np |
| import torch |
|
|
# Make the repo root importable so `diffnext` resolves when this script is
# run directly from the scripts/ directory (one level below the repo root).
_REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if _REPO_ROOT not in sys.path:
    sys.path.insert(0, _REPO_ROOT)


# NOTE(review): export_to_image is imported but never used in this script.
from diffnext.pipelines import URSAPipeline
from diffnext.utils import export_to_image, export_to_video
|
|
|
|
def parse_args():
    """Build and parse the command-line arguments for the A/B test run."""
    parser = argparse.ArgumentParser(description="A/B test URSA inference")
    parser.add_argument("--model", required=True, help="URSA model path")
    parser.add_argument("--student_ckpt", default=None, help="Optional student.pt")
    parser.add_argument("--device", type=int, default=0)
    parser.add_argument(
        "--precision", default="float16", choices=["float16", "bfloat16"]
    )
    parser.add_argument("--out_dir", default="./outputs/ab_test")
    parser.add_argument(
        "--test_resolutions",
        nargs="+",
        default=["320x512"],
        help="Resolutions to test as HxW strings (FSQ320 native: 320x512)",
    )
    parser.add_argument(
        "--test_steps",
        nargs="+",
        type=int,
        default=[25, 50],
        help="Inference steps to test",
    )
    parser.add_argument("--num_frames", type=int, default=49)
    return parser.parse_args()
|
|
|
|
def diag(label, obj):
    """Print a single labelled diagnostic line."""
    line = " [%s] %s" % (label, obj)
    print(line)
|
|
|
|
def diag_tensor(label, t):
    """Print shape/dtype/range diagnostics for *t*.

    Handles torch.Tensor (incl. device) and numpy.ndarray; any other
    type is silently ignored.
    """
    if isinstance(t, torch.Tensor):
        lo, hi, avg = t.min().item(), t.max().item(), t.mean().item()
        print(
            f" [{label}] shape={t.shape} dtype={t.dtype} device={t.device} "
            f"min={lo:.4f} max={hi:.4f} mean={avg:.4f}"
        )
    elif isinstance(t, np.ndarray):
        print(
            f" [{label}] shape={t.shape} dtype={t.dtype} "
            f"min={t.min()} max={t.max()} mean={t.mean():.2f}"
        )
|
|
|
|
def diag_pipeline(pipe):
    """Print a full diagnostic dump of a loaded URSA pipeline.

    Reports the classes, devices and dtypes of the pipeline's main
    sub-modules (transformer, VAE, scheduler, tokenizer) plus selected
    scheduler internals, so two runs can be diffed line-by-line to spot
    loading differences.

    Args:
        pipe: A loaded URSAPipeline (or any object exposing the same
            .transformer / .vae / .scheduler / .tokenizer attributes).
    """
    print("\n" + "=" * 70)
    print(" PIPELINE DIAGNOSTICS")
    print("=" * 70)
    print(f" pipeline class : {type(pipe).__name__}")
    print(f" transformer class : {type(pipe.transformer).__name__}")
    # Device/dtype of the first parameter stands in for the whole module.
    print(f" transformer device : {next(pipe.transformer.parameters()).device}")
    print(f" transformer dtype : {next(pipe.transformer.parameters()).dtype}")
    print(f" vae class : {type(pipe.vae).__name__}")
    print(f" vae device : {next(pipe.vae.parameters()).device}")
    print(f" scheduler class : {type(pipe.scheduler).__name__}")
    print(f" scheduler repr : {repr(pipe.scheduler)}")

    # Scheduler internals. NOTE(review): `path` presumably holds the flow /
    # probability-path object with a codebook embedding table (`emb`) — TODO
    # confirm against the diffnext scheduler implementation.
    sched = pipe.scheduler
    if hasattr(sched, 'path') and sched.path is not None:
        print(f" scheduler.path class: {type(sched.path).__name__}")
        if hasattr(sched.path, 'emb'):
            emb = sched.path.emb
            print(f" path.emb shape : {emb.shape}")
            print(f" path.emb device : {emb.device}")
            print(f" path.emb dtype : {emb.dtype}")
            print(f" path.emb[0,:5] : {emb[0,:5].tolist()}")
        if hasattr(sched.path, 'alpha'):
            print(f" path.alpha : {getattr(sched.path, 'alpha', 'N/A')}")
        if hasattr(sched.path, 'c'):
            print(f" path.c : {getattr(sched.path, 'c', 'N/A')}")
    else:
        # A missing `path` is a strong hint the scheduler did not load fully.
        print(f" scheduler.path : MISSING or None!")

    print(f" codebook_size : {getattr(sched, 'codebook_size', 'N/A')}")
    print(f" shift : {getattr(sched, 'shift', 'N/A')}")

    if hasattr(sched, 'config'):
        print(f" scheduler.config : {dict(sched.config)}")

    # Pipeline-level stride settings used for latent <-> pixel conversion.
    print(f" vae_temporal_stride : {getattr(pipe, 'vae_temporal_stride', 'N/A')}")
    print(f" vae_spatial_stride : {getattr(pipe, 'vae_spatial_stride', 'N/A')}")
    print(f" tokenizer class : {type(pipe.tokenizer).__name__}")
    print("=" * 70 + "\n")
|
|
|
|
def diag_output(frames_output, label):
    """Print a short structural summary of a pipeline output object.

    Recognises numpy arrays and lists (summarising the first element of a
    list); anything else just has its type reported.
    """
    print(f"\n --- Output diagnostics: {label} ---")
    if isinstance(frames_output, np.ndarray):
        diag_tensor(f"{label} raw", frames_output)
        return
    if not isinstance(frames_output, list):
        print(f" [{label}] type={type(frames_output)}")
        return
    print(f" [{label}] list of {len(frames_output)} items")
    if frames_output:
        first = frames_output[0]
        if isinstance(first, np.ndarray):
            diag_tensor(f"{label}[0]", first)
        else:
            print(f" [{label}[0]] type={type(first)}")
|
|
|
|
def save_frames(frames, path, fps=12):
    """Save frames to *path* as an .mp4 video or a .jpg/.png image.

    Args:
        frames: For video: an (T, H, W, C) np.ndarray or a list of frames.
            For image: an HxWxC np.ndarray or a PIL.Image-like object with
            a ``save`` method.
        path: Output file path; the extension selects the writer.
        fps: Frames per second for video output.

    Raises:
        ValueError: If *path* has an unsupported extension (previously this
            case was silently ignored, hiding caller typos).
    """
    os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)
    if path.endswith(".mp4"):
        # export_to_video expects a list of frames; unstack a (T,H,W,C)
        # array along its first axis. (The old list/ndarray/other branches
        # were identical duplicated calls — collapsed here.)
        if isinstance(frames, np.ndarray) and frames.ndim == 4:
            frames = list(frames)
        export_to_video(frames, output_video_path=path, fps=fps)
    elif path.endswith((".jpg", ".png")):
        from PIL import Image
        if isinstance(frames, np.ndarray):
            Image.fromarray(frames).save(path)
        elif hasattr(frames, "save"):
            frames.save(path)
    else:
        raise ValueError(f"Unsupported output extension: {path}")
|
|
|
|
def _extract_video_frames(frames):
    """Collapse a pipeline output to one video clip.

    A 5-D ndarray is treated as batched (B, T, H, W, C) and batch 0 is
    returned; any other value is passed through unchanged.
    """
    if isinstance(frames, np.ndarray) and frames.ndim == 5:
        return frames[0]
    return frames


def _save_video(frames, path, fps=12):
    """Export *frames* to *path* as mp4, logging instead of raising on failure.

    An ndarray is unstacked along axis 0 into a list of frames, which is
    what export_to_video expects.
    """
    try:
        if isinstance(frames, np.ndarray):
            export_to_video(list(frames), output_video_path=path, fps=fps)
        else:
            export_to_video(frames, output_video_path=path, fps=fps)
        print(f" Saved: {path}")
    except Exception as e:
        # Best effort: one failed export must not abort the whole A/B run.
        print(f" Failed to save {path}: {e}")


def main():
    """Run the official-vs-eval A/B inference tests and save all outputs."""
    args = parse_args()
    os.makedirs(args.out_dir, exist_ok=True)

    dtype = getattr(torch, args.precision)
    device = torch.device("cuda", args.device) if torch.cuda.is_available() else torch.device("cpu")

    # Shared inputs: the same prompt / negative prompt / seed are used for
    # every variant so the outputs are directly comparable.
    prompt = "a lone grizzly bear walks through a misty forest at dawn, sunlight catching its fur."
    negative_prompt = "worst quality, low quality, inconsistent motion, static, still, blurry, jittery, distorted, ugly"
    seed = 42

    print(f"\n[1] Loading pipeline from {args.model} ...")
    pipe = URSAPipeline.from_pretrained(
        args.model, torch_dtype=dtype, trust_remote_code=True
    ).to(device)

    diag_pipeline(pipe)

    # ---------------------------------------------------------------- TEST A
    print("\n" + "#" * 70)
    print("# TEST A: Official README T2V (FSQ320 native resolution)")
    print("#" * 70)

    gen = torch.Generator(device=device).manual_seed(seed)
    out = pipe(
        prompt=f"motion=9.0, {prompt}",
        negative_prompt=negative_prompt,
        height=320,
        width=512,
        num_frames=49,
        num_inference_steps=50,
        guidance_scale=7,
        generator=gen,
        output_type="np",
    )
    frames = out.frames
    diag_output(frames, "A_official_t2v")
    _save_video(
        _extract_video_frames(frames),
        os.path.join(args.out_dir, "A_official_t2v_320x512_49f_50step.mp4"),
    )

    # --------------------------------------------------------------- TEST A2
    print("\n# TEST A2: T2I at 320x512 (1 frame)")
    gen = torch.Generator(device=device).manual_seed(seed)
    out = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        height=320,
        width=512,
        num_frames=1,
        num_inference_steps=25,
        guidance_scale=7,
        generator=gen,
    )
    image = out.frames[0]
    path_a2 = os.path.join(args.out_dir, "A_official_t2i_320x512.jpg")
    if hasattr(image, 'save'):
        image.save(path_a2)
        print(f" Saved: {path_a2} (PIL Image)")
    else:
        diag_output(out.frames, "A2_t2i")

    # ---------------------------------------------------------------- TEST B
    # Teacher sweep: resolutions x step counts x {no-CFG, CFG}.
    for res_str in args.test_resolutions:
        h, w = map(int, res_str.split("x"))
        for steps in args.test_steps:
            for gs_label, gs_val in [("nocfg", 1.0), ("cfg7", 7.0)]:
                label = f"B_{h}x{w}_{steps}step_{gs_label}"
                print(f"\n{'#' * 70}")
                print(f"# TEST {label}")
                print(f"# height={h} width={w} num_frames={args.num_frames}")
                print(f"# steps={steps} guidance_scale={gs_val}")
                print(f"{'#' * 70}")

                gen = torch.Generator(device=device).manual_seed(seed)
                # CFG needs a negative prompt; omit it when guidance is off.
                neg = negative_prompt if gs_val > 1 else None

                print(f" scheduler.codebook_size = {pipe.scheduler.codebook_size}")
                print(f" scheduler.path type = {type(pipe.scheduler.path).__name__}")

                out = pipe(
                    prompt=prompt,
                    negative_prompt=neg,
                    height=h,
                    width=w,
                    num_frames=args.num_frames,
                    num_inference_steps=steps,
                    guidance_scale=gs_val,
                    guidance_trunc=0.9,
                    max_prompt_length=320,
                    vae_batch_size=1,
                    output_type="np",
                    generator=gen,
                )
                frames = out.frames
                diag_output(frames, label)
                _save_video(
                    _extract_video_frames(frames),
                    os.path.join(args.out_dir, f"{label}.mp4"),
                )

    # ---------------------------------------------------------------- TEST C
    if args.student_ckpt:
        print(f"\n{'#' * 70}")
        print(f"# TEST C: Student 1-step")
        print(f"{'#' * 70}")

        # Snapshot teacher weights so they can be restored afterwards.
        teacher_state = {k: v.clone() for k, v in pipe.transformer.state_dict().items()}
        student_state = torch.load(args.student_ckpt, map_location=device, weights_only=True)

        print(f" student keys: {len(student_state)}")
        print(f" teacher keys: {len(teacher_state)}")

        # Key mismatches are only warned about here; load_state_dict(strict=True)
        # below will still fail fast if any exist.
        missing = set(teacher_state.keys()) - set(student_state.keys())
        extra = set(student_state.keys()) - set(teacher_state.keys())
        if missing:
            print(f" WARNING: {len(missing)} keys in teacher but not student: {list(missing)[:5]}")
        if extra:
            print(f" WARNING: {len(extra)} keys in student but not teacher: {list(extra)[:5]}")

        pipe.transformer.load_state_dict(student_state, strict=True)
        pipe.transformer.eval()

        try:
            for res_str in args.test_resolutions[:1]:
                h, w = map(int, res_str.split("x"))
                for gs_label, gs_val in [("nocfg", 1.0), ("cfg7", 7.0)]:
                    label = f"C_student_{h}x{w}_1step_{gs_label}"
                    gen = torch.Generator(device=device).manual_seed(seed)
                    neg = negative_prompt if gs_val > 1 else None

                    out = pipe(
                        prompt=prompt,
                        negative_prompt=neg,
                        height=h,
                        width=w,
                        num_frames=args.num_frames,
                        num_inference_steps=1,
                        guidance_scale=gs_val,
                        guidance_trunc=0.9,
                        max_prompt_length=320,
                        vae_batch_size=1,
                        output_type="np",
                        generator=gen,
                    )
                    frames = out.frames
                    diag_output(frames, label)
                    _save_video(
                        _extract_video_frames(frames),
                        os.path.join(args.out_dir, f"{label}.mp4"),
                    )
        finally:
            # BUG FIX: restore the teacher weights even if a student run
            # raises (previously an exception left the student loaded).
            pipe.transformer.load_state_dict(teacher_state, strict=True)

    print(f"\n[DONE] All outputs in {args.out_dir}")
    print("\nCheck these files to diagnose blurriness:")
    # BUG FIX: the old message referenced A_official_t2i_1024x1024.jpg, but
    # the file actually written above is A_official_t2i_320x512.jpg.
    print(" - A_official_t2i_320x512.jpg → should be sharp (official T2I)")
    print(" - B_*_cfg7.mp4 → teacher video with CFG")
    print(" - B_*_nocfg.mp4 → teacher video without CFG")
    print(" - Compare different resolutions and step counts")
    print(" - If ALL are blurry, the issue is in pipeline/scheduler/VAE loading")
    print(" - If only low-res are blurry, it's a resolution issue")
    print(" - If only low-step are blurry, need more steps")
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|