# Z-Image-Turbo / VideoX-Fun / scripts / z_image / export_transformer_body_onnx.py
# yongqiang
# initialize this repo
# ba96580
#!/usr/bin/env python3
"""Export the pure Z-Image transformer body (no ControlNet) to ONNX."""
import argparse
import logging
import os
import sys
from collections import OrderedDict
from typing import Dict, List, Optional, OrderedDict as OrderedDictType, Tuple
import numpy as np
import torch
import onnx
from onnx import numpy_helper
import subprocess
from omegaconf import OmegaConf
# Repository root: two directory levels above this script's folder.
REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir))
# Put the repository root first on sys.path (once) so project-local packages
# imported below resolve when this file is run as a plain script.
if REPO_ROOT not in sys.path:
    sys.path.insert(0, REPO_ROOT)
from videox_fun.models import ZImageTransformer2DModel # noqa: E402
LOGGER = logging.getLogger("export_transformer_body_onnx")
# Configure root logging once for this CLI script: INFO level with timestamps.
logging.basicConfig(format="[%(asctime)s] %(levelname)s: %(message)s", level=logging.INFO)
# Prompt sequence lengths must be a multiple of this value (enforced by _validate_sequence_length).
SEQ_MULTI_OF = 32
def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse the command-line options for the transformer-body ONNX export.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses ``sys.argv[1:]``,
            matching the previous behavior; passing a list makes the parser reusable
            from other scripts and tests.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description="Export the pure Z-Image transformer body to ONNX (no ControlNet)")
    parser.add_argument("--config", default="config/z_image/z_image.yaml", help="Path to YAML config (if needed)")
    parser.add_argument("--model-root", default="models/Diffusion_Transformer/Z-Image-Turbo/", help="Directory that stores the pretrained weights")
    parser.add_argument("--checkpoint", default=None, help="Optional fine-tuned checkpoint to load")
    parser.add_argument("--output", default="onnx-models/z_image_transformer_body_only.onnx", help="Target ONNX file path")
    parser.add_argument("--height", type=int, default=864, help="Target image height used to derive latent resolution")
    parser.add_argument("--width", type=int, default=496, help="Target image width used to derive latent resolution")
    parser.add_argument("--batch-size", type=int, default=1, help="Batch size for the exported graph")
    parser.add_argument("--sequence-length", type=int, default=512, help="Prompt embedding sequence length (must be a multiple of 32)")
    parser.add_argument("--frames", type=int, default=1, help="Number of frames in the latent tensor")
    parser.add_argument("--latent-downsample-factor", type=int, default=8, help="Downsampling ratio between spatial image size and latent size")
    parser.add_argument("--latent-height", type=int, default=None, help="Override latent height (after downsampling)")
    parser.add_argument("--latent-width", type=int, default=None, help="Override latent width (after downsampling)")
    parser.add_argument("--dtype", choices=["fp16", "fp32"], default="fp16", help="Export precision")
    parser.add_argument("--patch-size", type=int, default=2, help="Spatial patch size used by the transformer")
    parser.add_argument("--f-patch-size", type=int, default=1, help="Frame patch size used by the transformer")
    parser.add_argument("--opset", type=int, default=17, help="ONNX opset version")
    parser.add_argument("--no-external-data", action="store_true", help="Disable external data format even if the model is larger than 2GB")
    parser.add_argument("--skip-ort-check", action="store_true", help="Skip running an ONNX Runtime correctness check")
    parser.add_argument("--ort-provider", default="CPUExecutionProvider", help="ONNX Runtime provider used during validation")
    parser.add_argument("--save-calib-inputs", action="store_true", help="Dump ONNX input dictionaries as .npy for calibration")
    parser.add_argument("--calib-dir", default="onnx-calibration", help="Directory for storing calibration npy files")
    parser.add_argument("--dynamic-axes", action="store_true", help="Export ONNX with dynamic batch/seq/latent dims; default is static shape")
    parser.add_argument("--skip-slim", action="store_true", help="Skip onnxslim simplification for faster debug export")
    return parser.parse_args(argv)
def run_onnxslim(input_file="vae.onnx", output_file="vae_slim.onnx", *, executable: str = "onnxslim") -> bool:
    """Run the onnxslim CLI to simplify an ONNX model, streaming its output.

    This can be slow for large models; the script's ``--skip-slim`` flag bypasses it.

    stderr is merged into stdout so the child's output flows through a single pipe
    that is drained continuously. Reading stdout line by line while stderr sits in a
    separate, undrained pipe deadlocks once the child fills that pipe's buffer.

    Args:
        input_file: Path of the ONNX model to simplify.
        output_file: Path where the simplified model is written.
        executable: Command to invoke (keyword-only, defaults to ``onnxslim``).

    Returns:
        True if the command exited with status 0. False if it failed, could not be
        found, or could not be started.
    """
    cmd = [executable, input_file, output_file]
    print(f"执行命令: {' '.join(cmd)}")
    try:
        with subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            bufsize=1,
        ) as process:
            for line in process.stdout:
                print(line, end="")
            returncode = process.wait()
    except FileNotFoundError:
        # The package that provides the ``onnxslim`` CLI is ``onnxslim``, not ``onnx-simplifier``.
        print(f"错误: 未找到 {executable} 命令, 请确保已安装 onnxslim (pip install onnxslim)")
        return False
    except Exception as exc:  # pragma: no cover
        print(f"执行命令时发生错误: {exc}")
        return False
    if returncode != 0:
        # All diagnostics (stdout + stderr) were already streamed above.
        print(f"命令执行失败 (返回码 {returncode}), 错误信息见上方输出")
        return False
    print("ONNX模型压缩完成!")
    return True
def _resolve_path(path: str) -> str:
    """Return ``path`` unchanged if absolute, otherwise resolve it against REPO_ROOT."""
    if os.path.isabs(path):
        return path
    anchored = os.path.join(REPO_ROOT, path)
    return os.path.abspath(anchored)
def load_transformer(args: argparse.Namespace, torch_dtype: torch.dtype, device: torch.device) -> ZImageTransformer2DModel:
    """Load the pretrained Z-Image transformer and optionally overlay a checkpoint.

    The model is loaded from ``<model_root>/transformer`` via ``from_pretrained``, put
    in eval mode, and moved to ``device`` / ``torch_dtype``. If ``args.checkpoint`` is
    set, its state dict is loaded with ``strict=False`` and the counts of missing and
    unexpected keys are logged.

    Raises:
        FileNotFoundError: if the model root directory or the requested checkpoint does
            not exist. Both are checked before the slow model load.
    """
    model_root = _resolve_path(args.model_root)
    checkpoint_path = _resolve_path(args.checkpoint) if args.checkpoint else None
    # Validate every path up front so a typo in --checkpoint fails immediately,
    # instead of after the full (slow, memory-heavy) pretrained load.
    if not os.path.isdir(model_root):
        raise FileNotFoundError(f"Model root not found: {model_root}")
    if checkpoint_path and not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

    LOGGER.info("Loading transformer from %s", model_root)
    transformer = ZImageTransformer2DModel.from_pretrained(
        model_root,
        subfolder="transformer",
        low_cpu_mem_usage=True,
        torch_dtype=torch_dtype,
    )
    transformer.eval()
    transformer.to(device=device, dtype=torch_dtype)

    if checkpoint_path:
        LOGGER.info("Loading checkpoint %s", checkpoint_path)
        if checkpoint_path.endswith(".safetensors"):
            from safetensors.torch import load_file  # type: ignore
            state_dict = load_file(checkpoint_path)
        else:
            # NOTE(security): torch.load unpickles arbitrary objects; only load trusted checkpoints.
            state_dict = torch.load(checkpoint_path, map_location="cpu")
            # Some training checkpoints nest the weights under a "state_dict" key.
            state_dict = state_dict.get("state_dict", state_dict)
        missing, unexpected = transformer.load_state_dict(state_dict, strict=False)
        LOGGER.info("Checkpoint loaded (missing=%d, unexpected=%d)", len(missing), len(unexpected))
    return transformer
def _validate_sequence_length(seq_len: int) -> None:
    """Check that the prompt sequence length is positive and a multiple of SEQ_MULTI_OF.

    Raises:
        ValueError: if ``seq_len`` is zero or negative (which would produce an empty
            prompt tensor), or not a multiple of ``SEQ_MULTI_OF``.
    """
    if seq_len <= 0:
        raise ValueError(f"sequence_length must be positive, got {seq_len}")
    if seq_len % SEQ_MULTI_OF != 0:
        raise ValueError(f"sequence_length must be a multiple of {SEQ_MULTI_OF} to satisfy transformer padding rules")
def _compute_latent_dims(args: argparse.Namespace) -> Dict[str, int]:
    """Derive the latent spatial size used for the dummy latent tensor.

    ``--latent-height`` and ``--latent-width`` each override their own dimension. A
    dimension without an override is derived as ``image_size // latent_downsample_factor``,
    and the image size must divide evenly.

    Returns:
        ``{"latent_h": int, "latent_w": int}``.

    Raises:
        ValueError: if a derived image size is not divisible by the (positive) downsample
            factor, or if any latent dimension, frame count or patch size is non-positive,
            or if the latent dims / frames are not divisible by ``patch_size`` /
            ``f_patch_size``.
    """
    factor = args.latent_downsample_factor

    def _latent_size(override: Optional[int], image_size: int) -> int:
        # Overrides apply per dimension, so a lone --latent-height is honored too.
        if override is not None:
            return override
        if factor <= 0:
            raise ValueError("latent_downsample_factor must be positive")
        if image_size % factor != 0:
            raise ValueError("height and width must be divisible by latent_downsample_factor")
        return image_size // factor

    latent_h = _latent_size(args.latent_height, args.height)
    latent_w = _latent_size(args.latent_width, args.width)
    if latent_h <= 0 or latent_w <= 0:
        raise ValueError(f"latent dimensions must be positive, got {latent_h}x{latent_w}")
    if args.patch_size <= 0 or args.f_patch_size <= 0:
        raise ValueError("patch_size and f_patch_size must be positive")
    if latent_h % args.patch_size != 0 or latent_w % args.patch_size != 0:
        raise ValueError("latent dimensions must be divisible by patch_size")
    if args.frames <= 0:
        raise ValueError("frames must be positive")
    if args.frames % args.f_patch_size != 0:
        raise ValueError("frames must be divisible by f_patch_size")
    return {"latent_h": latent_h, "latent_w": latent_w}
def build_dummy_inputs(
    args: argparse.Namespace,
    model: ZImageTransformer2DModel,
    torch_dtype: torch.dtype,
    device: torch.device,
) -> OrderedDictType[str, torch.Tensor]:
    """Create random example inputs matching TransformerBodyOnlyWrapper.forward.

    The returned mapping has the keys ``latent_model_input``, ``timestep`` and
    ``prompt_embeds``, in that order. Callers use the key order as the ONNX input order.

    - latent_model_input: randn of shape (batch, in_channels, frames, latent_h, latent_w)
    - timestep: float32 linspace over [0, 1] with ``batch`` entries
    - prompt_embeds: randn of shape (batch, sequence_length, cap_feat_dim)
    """
    _validate_sequence_length(args.sequence_length)
    latent_dims = _compute_latent_dims(args)
    n = args.batch_size
    config = model.config
    latent_shape = (n, config.in_channels, args.frames, latent_dims["latent_h"], latent_dims["latent_w"])
    prompt_shape = (n, args.sequence_length, config.cap_feat_dim)

    # Draw the latent before the prompts so the RNG consumption order is fixed.
    latent = torch.randn(*latent_shape, dtype=torch_dtype, device=device)
    timestep = torch.linspace(0.0, 1.0, steps=n, dtype=torch.float32, device=device)
    prompts = torch.randn(*prompt_shape, dtype=torch_dtype, device=device)

    return OrderedDict(
        [
            ("latent_model_input", latent),
            ("timestep", timestep),
            ("prompt_embeds", prompts),
        ]
    )
def maybe_save_calibration_inputs(tag: str, inputs: OrderedDictType[str, torch.Tensor], args: argparse.Namespace) -> Optional[str]:
    """Optionally dump ``inputs`` to ``<calib_dir>/<tag>_inputs.npy`` for calibration.

    The dict of numpy arrays is written with ``np.save(..., allow_pickle=True)``, so it
    is stored as a pickled object array. Read it back with
    ``np.load(path, allow_pickle=True).item()``.

    Returns:
        The written file path, or None when ``save_calib_inputs`` is not set on ``args``.
    """
    if not getattr(args, "save_calib_inputs", False):
        return None

    target_dir = _resolve_path(args.calib_dir)
    os.makedirs(target_dir, exist_ok=True)

    arrays = {}
    for key, value in inputs.items():
        arrays[key] = value.detach().cpu().numpy()

    destination = os.path.join(target_dir, f"{tag}_inputs.npy")
    np.save(destination, arrays, allow_pickle=True)
    LOGGER.info("Saved calibration inputs (%s) to %s", tag, destination)
    return destination
def dump_initializer_parameters(model_path: str) -> str:
    """Write every graph initializer of an ONNX model to ``<model_path>.params.npz``.

    The model is loaded with ``load_external_data=True``, so externally stored weights
    are included. Archive keys are the initializer names.

    Returns:
        Path of the written ``.npz`` file.
    """
    proto = onnx.load(model_path, load_external_data=True)
    weights = {tensor.name: numpy_helper.to_array(tensor) for tensor in proto.graph.initializer}
    npz_path = f"{model_path}.params.npz"
    np.savez(npz_path, **weights)
    LOGGER.info("Saved %d parameters to %s", len(weights), npz_path)
    return npz_path
class TransformerBodyOnlyWrapper(torch.nn.Module):
    """Thin adapter exposing only the inputs the bare transformer body needs.

    It takes batched tensors, splits them along dim 0 into per-sample lists, and
    calls the wrapped model with those lists. This keeps the ONNX graph's inputs as
    plain batched tensors.
    """

    def __init__(self, model: ZImageTransformer2DModel, patch_size: int, f_patch_size: int):
        super().__init__()
        self.model = model
        # Patch sizes are forwarded verbatim to the wrapped model on every call.
        self.patch_size = patch_size
        self.f_patch_size = f_patch_size

    def forward(
        self,
        latent_model_input: torch.Tensor,
        timestep: torch.Tensor,
        prompt_embeds: torch.Tensor,
    ) -> torch.Tensor:
        """Unbind the batch dimension and run the wrapped transformer.

        Returns whatever the wrapped model returns; callers in this script treat it
        as a single tensor.
        """
        per_sample_latents = list(torch.unbind(latent_model_input, dim=0))
        per_sample_prompts = list(torch.unbind(prompt_embeds, dim=0))
        return self.model(
            per_sample_latents,
            timestep,
            per_sample_prompts,
            patch_size=self.patch_size,
            f_patch_size=self.f_patch_size,
        )
def _resave_onnx(export_path: str, use_external: bool) -> str:
    """Reload the raw export and re-save it as ``<stem>_simp.onnx``; return that path.

    When ``use_external`` is True, all weights go into one side file,
    ``<stem>_simp.onnx.data``, in the same directory. Otherwise the model is written as
    a single protobuf, which fails for models over protobuf's 2 GB limit.
    Keeping the loaded proto local to this helper releases it before onnxslim runs.
    """
    trans_onnx = onnx.load(export_path)
    simp_onnx_data = os.path.splitext(export_path)[0] + "_simp.onnx"
    if not use_external:
        onnx.save(trans_onnx, simp_onnx_data)
        LOGGER.info("Saved single-file ONNX to %s (external data disabled)", simp_onnx_data)
        return simp_onnx_data

    # Name the weight file explicitly: without ``location`` onnx picks a uuid-based
    # file name, so the logged ``.data`` path would not exist.
    data_location = os.path.basename(simp_onnx_data) + ".data"
    external_weight_file = os.path.join(os.path.dirname(simp_onnx_data), data_location)
    # onnx appends tensors to an existing external-data file; drop a stale one from a
    # previous run so it does not keep growing.
    if os.path.exists(external_weight_file):
        os.remove(external_weight_file)
    onnx.save(
        trans_onnx,
        simp_onnx_data,
        save_as_external_data=True,
        all_tensors_to_one_file=True,
        location=data_location,
    )
    LOGGER.info("Saved external-data ONNX to %s (weights -> %s)", simp_onnx_data, external_weight_file)
    return simp_onnx_data


def export_onnx(
    wrapper: torch.nn.Module,
    sample_inputs: OrderedDictType[str, torch.Tensor],
    output_path: str,
    output_names: List[str],
    args: argparse.Namespace,
) -> Tuple[str, str]:
    """Export ``wrapper`` to ONNX, re-save it, and optionally simplify it with onnxslim.

    Steps:
      1. ``torch.onnx.export`` to ``output_path`` (relative paths resolve against REPO_ROOT).
         The ONNX input order follows ``sample_inputs``' key order. With
         ``--dynamic-axes``, batch / frames / latent_h / latent_w / seq_len are made dynamic.
      2. Reload and save as ``<stem>_simp.onnx``. Weights go to ``<stem>_simp.onnx.data``
         unless ``--no-external-data`` is given.
      3. Unless ``--skip-slim`` is given, run onnxslim to produce ``<stem>_simp_slim.onnx``.

    Returns:
        ``(final_onnx_path, param_path)``. ``param_path`` is always ``""`` while the
        initializer dump below stays disabled.

    Raises:
        RuntimeError: if onnxslim fails.
    """
    export_path = _resolve_path(output_path)
    export_dir = os.path.dirname(export_path)
    if export_dir:
        os.makedirs(export_dir, exist_ok=True)
    input_names = list(sample_inputs.keys())
    use_external = not args.no_external_data
    wrapper.eval()

    dynamic_axes = None
    if args.dynamic_axes:
        candidate_axes = {
            "latent_model_input": {0: "batch", 2: "frames", 3: "latent_h", 4: "latent_w"},
            "prompt_embeds": {0: "batch", 1: "seq_len"},
            "timestep": {0: "batch"},
            "sample": {0: "batch", 2: "frames", 3: "latent_h", 4: "latent_w"},
        }
        # Only declare axes for names that actually exist in this export.
        exported_names = set(input_names) | set(output_names)
        dynamic_axes = {name: axes for name, axes in candidate_axes.items() if name in exported_names}

    LOGGER.info("Exporting ONNX to %s", export_path)
    with torch.inference_mode():
        torch.onnx.export(
            wrapper,
            args=tuple(sample_inputs[name] for name in input_names),
            f=export_path,
            input_names=input_names,
            output_names=output_names,
            opset_version=args.opset,
            do_constant_folding=True,
            export_params=True,
            dynamic_axes=dynamic_axes or None,
        )
    LOGGER.info("Raw ONNX export finished")

    # Honors --no-external-data; previously that flag was parsed but ignored.
    simp_onnx_data = _resave_onnx(export_path, use_external)

    if args.skip_slim:
        LOGGER.info("Skip onnxslim as requested, using simplified external-data ONNX: %s", simp_onnx_data)
        final_onnx = simp_onnx_data
    else:
        slim_onnx_path = os.path.splitext(simp_onnx_data)[0] + "_slim.onnx"
        LOGGER.info("Transformer ONNX model exported, start to simplify via onnxslim")
        success = run_onnxslim(simp_onnx_data, slim_onnx_path)
        if not success:
            raise RuntimeError("onnxslim simplification failed, please check logs")
        final_onnx = slim_onnx_path
    LOGGER.info("Transformer ONNX model exported successfully: %s", final_onnx)

    param_path = ""
    # param_path = dump_initializer_parameters(final_onnx)  # uncomment to also dump initializers to a standalone .npz
    return final_onnx, param_path
def run_ort_validation(
    wrapper: torch.nn.Module,
    sample_inputs: OrderedDictType[str, torch.Tensor],
    onnx_path: str,
    provider: str,
) -> None:
    """Compare the PyTorch wrapper's first output with ONNX Runtime's on the same inputs.

    Logs the max absolute difference, and that value divided by ``max(1, max|torch output|)``.
    Returns early with a warning if onnxruntime is not installed.

    Raises:
        ValueError: if the PyTorch and ONNX Runtime output shapes differ. Without this
            check, numpy broadcasting could hide the mismatch.
    """
    try:
        import onnxruntime as ort
    except ImportError:  # pragma: no cover
        LOGGER.warning("onnxruntime not installed, skip validation")
        return
    wrapper.eval()
    with torch.inference_mode():
        torch_output = wrapper(*sample_inputs.values()).detach().cpu().numpy()
    sess_options = ort.SessionOptions()
    session = ort.InferenceSession(onnx_path, sess_options=sess_options, providers=[provider])
    # Feed only inputs the graph still declares: export/slimming may prune unused
    # inputs, and ONNX Runtime rejects unknown feed names.
    graph_inputs = {node.name for node in session.get_inputs()}
    ort_inputs = {
        name: tensor.detach().cpu().numpy()
        for name, tensor in sample_inputs.items()
        if name in graph_inputs
    }
    ort_output = session.run(None, ort_inputs)[0]
    if torch_output.shape != ort_output.shape:
        raise ValueError(f"Output shape mismatch: torch {torch_output.shape} vs onnx {ort_output.shape}")
    # Compare in float64: differencing fp16 outputs in fp16 loses precision and can overflow.
    torch_ref = np.asarray(torch_output, dtype=np.float64)
    ort_ref = np.asarray(ort_output, dtype=np.float64)
    if torch_ref.size == 0:
        abs_diff = rel_diff = 0.0
    else:
        abs_diff = float(np.max(np.abs(torch_ref - ort_ref)))
        rel_diff = abs_diff / max(1.0, float(np.max(np.abs(torch_ref))))
    LOGGER.info("ONNX Runtime check done (abs=%.6f, rel=%.6f)", abs_diff, rel_diff)
def main() -> None:
    """CLI entry point: load the transformer, export it to ONNX, then optionally verify with ORT.

    Everything runs on CPU with autograd disabled. ORT validation failures are logged
    as warnings and do not fail the export.
    """
    cli = parse_args()
    cpu = torch.device("cpu")
    precision = {"fp16": torch.float16}.get(cli.dtype, torch.float32)
    torch.set_grad_enabled(False)

    model = load_transformer(cli, precision, cpu)
    body = TransformerBodyOnlyWrapper(model, cli.patch_size, cli.f_patch_size)
    dummy = build_dummy_inputs(cli, model, precision, cpu)
    maybe_save_calibration_inputs("transformer_body_only", dummy, cli)

    onnx_file, _unused_param_path = export_onnx(
        body,
        dummy,
        output_path=cli.output,
        output_names=["sample"],
        args=cli,
    )

    if cli.skip_ort_check:
        return
    try:
        run_ort_validation(body, dummy, onnx_file, cli.ort_provider)
    except Exception as exc:  # pragma: no cover
        # Best-effort check: report the problem but keep the exported artifacts.
        LOGGER.warning("ONNX Runtime validation failed: %s", exc)
if __name__ == "__main__":
    # Example:
    #   python scripts/z_image/export_transformer_body_onnx.py \
    #       --output onnx-models/z_image_transformer_body_only.onnx \
    #       --height 512 --width 512 --sequence-length 128 \
    #       --latent-downsample-factor 8 \
    #       --dtype fp32 \
    #       --skip-slim
    main()