# Source: BiliSakura — "Update all files for BitDance-Tokenizer-diffusers" (commit e754228, verified)
from __future__ import annotations
import argparse
import json
import shutil
from pathlib import Path
from typing import Optional
import torch
from safetensors.torch import load_file as load_safetensors
from transformers import AutoTokenizer, Qwen3ForCausalLM
from .modeling_autoencoder import BitDanceAutoencoder
from .modeling_diffusion_head import BitDanceDiffusionHead
from .modeling_projector import BitDanceProjector
from .pipeline_bitdance import BitDanceDiffusionPipeline
def _resolve_dtype(dtype: str) -> torch.dtype:
mapping = {
"float32": torch.float32,
"float16": torch.float16,
"bfloat16": torch.bfloat16,
}
if dtype not in mapping:
raise ValueError(f"Unsupported torch dtype '{dtype}'. Choose from {sorted(mapping)}.")
return mapping[dtype]
def _load_json(path: Path):
with path.open("r", encoding="utf-8") as handle:
return json.load(handle)
def _copy_runtime_source(output_path: Path) -> None:
package_root = Path(__file__).resolve().parent
target_pkg = output_path / "bitdance_diffusers"
shutil.copytree(package_root, target_pkg, dirs_exist_ok=True)
loader_script = output_path / "load_pipeline.py"
loader_script.write_text(
"\n".join(
[
"import sys",
"from pathlib import Path",
"",
"from diffusers import DiffusionPipeline",
"",
"model_dir = Path(__file__).resolve().parent",
"sys.path.insert(0, str(model_dir))",
'pipe = DiffusionPipeline.from_pretrained(model_dir, custom_pipeline=model_dir).to("cuda")',
'images = pipe(prompt="A scenic mountain lake at sunrise.").images',
'images[0].save("sample.png")',
]
)
+ "\n",
encoding="utf-8",
)
def convert_bitdance_to_diffusers(
    source_model_path: str,
    output_path: str,
    torch_dtype: str = "bfloat16",
    device: str = "cpu",
    copy_runtime_source: bool = True,
) -> Path:
    """Assemble a BitDance checkpoint directory into a Diffusers pipeline.

    Loads the tokenizer, the Qwen3 text encoder, the autoencoder, the
    diffusion head, and the projector from *source_model_path*, wires them
    into a ``BitDanceDiffusionPipeline``, and saves the result (safetensors)
    to *output_path*.

    Args:
        source_model_path: Directory holding the raw checkpoint files
            (``ae_config.json``, ``ae.safetensors``, ``vision_head_config.json``,
            ``vision_head.safetensors``, ``projector.safetensors``, plus the
            HF-format text-encoder/tokenizer files).
        output_path: Destination directory; created if missing.
        torch_dtype: dtype name for the text encoder ("float32"/"float16"/"bfloat16").
        device: Device to move all modules to; skipped when falsy.
        copy_runtime_source: Also copy the runtime package and a loader script
            into the output so it is self-contained.

    Returns:
        The output directory as a ``Path``.
    """
    src_dir = Path(source_model_path)
    dst_dir = Path(output_path)
    dst_dir.mkdir(parents=True, exist_ok=True)
    target_dtype = _resolve_dtype(torch_dtype)

    tokenizer = AutoTokenizer.from_pretrained(src_dir)
    text_encoder = Qwen3ForCausalLM.from_pretrained(
        src_dir,
        torch_dtype=target_dtype,
        low_cpu_mem_usage=True,
    ).eval()

    ae_config = _load_json(src_dir / "ae_config.json")
    # "ddconfig" may be nested in ae_config, or the config file may itself be
    # the ddconfig — fall back to the whole dict in that case.
    ddconfig = ae_config.get("ddconfig", ae_config)
    autoencoder = BitDanceAutoencoder(
        ddconfig=ddconfig,
        gan_decoder=bool(ae_config.get("gan_decoder", False)),
    ).eval()
    autoencoder.load_state_dict(
        load_safetensors(src_dir / "ae.safetensors"), strict=True, assign=True
    )

    diffusion_head = BitDanceDiffusionHead(
        **_load_json(src_dir / "vision_head_config.json")
    ).eval()
    diffusion_head.load_state_dict(
        load_safetensors(src_dir / "vision_head.safetensors"), strict=True, assign=True
    )

    # Projector maps autoencoder latents into the text encoder's hidden space.
    projector = BitDanceProjector(
        in_dim=int(ddconfig["z_channels"]),
        out_dim=int(text_encoder.config.hidden_size),
        hidden_act="gelu_pytorch_tanh",
    ).eval()
    projector.load_state_dict(
        load_safetensors(src_dir / "projector.safetensors"), strict=True, assign=True
    )

    if device:
        for module in (text_encoder, autoencoder, diffusion_head, projector):
            module.to(device=device)

    pipeline = BitDanceDiffusionPipeline(
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        autoencoder=autoencoder,
        diffusion_head=diffusion_head,
        projector=projector,
    )
    pipeline.save_pretrained(dst_dir, safe_serialization=True)
    if copy_runtime_source:
        _copy_runtime_source(dst_dir)
    return dst_dir
def parse_args(argv: Optional[list[str]] = None) -> argparse.Namespace:
    """Build and run the CLI argument parser.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(
        description="Convert BitDance checkpoints to Diffusers format."
    )
    parser.add_argument("--source_model_path", type=str, required=True)
    parser.add_argument("--output_path", type=str, required=True)
    parser.add_argument(
        "--torch_dtype",
        type=str,
        default="bfloat16",
        choices=["float32", "float16", "bfloat16"],
    )
    parser.add_argument("--device", type=str, default="cpu")
    # BooleanOptionalAction provides both --copy_runtime_source and
    # --no-copy_runtime_source flags.
    parser.add_argument(
        "--copy_runtime_source",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Copy self-contained runtime source into output directory.",
    )
    return parser.parse_args(argv)
def main(argv: Optional[list[str]] = None) -> None:
    """CLI entry point: parse arguments and run the conversion."""
    options = parse_args(argv)
    destination = convert_bitdance_to_diffusers(
        source_model_path=options.source_model_path,
        output_path=options.output_path,
        torch_dtype=options.torch_dtype,
        device=options.device,
        copy_runtime_source=options.copy_runtime_source,
    )
    print(f"Saved converted Diffusers pipeline to: {destination}")


if __name__ == "__main__":
    main()