from __future__ import annotations import argparse import json import shutil from pathlib import Path from typing import Optional import torch from safetensors.torch import load_file as load_safetensors from transformers import AutoTokenizer, Qwen3ForCausalLM from .modeling_autoencoder import BitDanceAutoencoder from .modeling_diffusion_head import BitDanceDiffusionHead from .modeling_projector import BitDanceProjector from .pipeline_bitdance import BitDanceDiffusionPipeline def _resolve_dtype(dtype: str) -> torch.dtype: mapping = { "float32": torch.float32, "float16": torch.float16, "bfloat16": torch.bfloat16, } if dtype not in mapping: raise ValueError(f"Unsupported torch dtype '{dtype}'. Choose from {sorted(mapping)}.") return mapping[dtype] def _load_json(path: Path): with path.open("r", encoding="utf-8") as handle: return json.load(handle) def _copy_runtime_source(output_path: Path) -> None: package_root = Path(__file__).resolve().parent target_pkg = output_path / "bitdance_diffusers" shutil.copytree(package_root, target_pkg, dirs_exist_ok=True) loader_script = output_path / "load_pipeline.py" loader_script.write_text( "\n".join( [ "import sys", "from pathlib import Path", "", "from diffusers import DiffusionPipeline", "", "model_dir = Path(__file__).resolve().parent", "sys.path.insert(0, str(model_dir))", 'pipe = DiffusionPipeline.from_pretrained(model_dir, custom_pipeline=model_dir).to("cuda")', 'images = pipe(prompt="A scenic mountain lake at sunrise.").images', 'images[0].save("sample.png")', ] ) + "\n", encoding="utf-8", ) def convert_bitdance_to_diffusers( source_model_path: str, output_path: str, torch_dtype: str = "bfloat16", device: str = "cpu", copy_runtime_source: bool = True, ) -> Path: source = Path(source_model_path) output = Path(output_path) output.mkdir(parents=True, exist_ok=True) dtype = _resolve_dtype(torch_dtype) tokenizer = AutoTokenizer.from_pretrained(source) text_encoder = Qwen3ForCausalLM.from_pretrained( source, torch_dtype=dtype, low_cpu_mem_usage=True, ).eval() ae_config = _load_json(source / "ae_config.json") ddconfig = ae_config.get("ddconfig", ae_config) gan_decoder = bool(ae_config.get("gan_decoder", False)) autoencoder = BitDanceAutoencoder(ddconfig=ddconfig, gan_decoder=gan_decoder).eval() autoencoder.load_state_dict(load_safetensors(source / "ae.safetensors"), strict=True, assign=True) vision_head_config = _load_json(source / "vision_head_config.json") diffusion_head = BitDanceDiffusionHead(**vision_head_config).eval() diffusion_head.load_state_dict(load_safetensors(source / "vision_head.safetensors"), strict=True, assign=True) projector = BitDanceProjector( in_dim=int(ddconfig["z_channels"]), out_dim=int(text_encoder.config.hidden_size), hidden_act="gelu_pytorch_tanh", ).eval() projector.load_state_dict(load_safetensors(source / "projector.safetensors"), strict=True, assign=True) if device: text_encoder.to(device=device) autoencoder.to(device=device) diffusion_head.to(device=device) projector.to(device=device) pipeline = BitDanceDiffusionPipeline( tokenizer=tokenizer, text_encoder=text_encoder, autoencoder=autoencoder, diffusion_head=diffusion_head, projector=projector, ) pipeline.save_pretrained(output, safe_serialization=True) if copy_runtime_source: _copy_runtime_source(output) return output def parse_args(argv: Optional[list[str]] = None) -> argparse.Namespace: parser = argparse.ArgumentParser(description="Convert BitDance checkpoints to Diffusers format.") parser.add_argument("--source_model_path", type=str, required=True) parser.add_argument("--output_path", type=str, required=True) parser.add_argument("--torch_dtype", type=str, default="bfloat16", choices=["float32", "float16", "bfloat16"]) parser.add_argument("--device", type=str, default="cpu") parser.add_argument( "--copy_runtime_source", action=argparse.BooleanOptionalAction, default=True, help="Copy self-contained runtime source into output directory.", ) return parser.parse_args(argv) def main(argv: Optional[list[str]] = None) -> None: args = parse_args(argv) converted = convert_bitdance_to_diffusers( source_model_path=args.source_model_path, output_path=args.output_path, torch_dtype=args.torch_dtype, device=args.device, copy_runtime_source=args.copy_runtime_source, ) print(f"Saved converted Diffusers pipeline to: {converted}") if __name__ == "__main__": main()