| |
| """Export the pure Z-Image transformer body (no ControlNet) to ONNX.""" |
|
|
| import argparse |
| import logging |
| import os |
| import sys |
| from collections import OrderedDict |
| from typing import Dict, List, Optional, OrderedDict as OrderedDictType, Tuple |
|
|
| import numpy as np |
| import torch |
| import onnx |
| from onnx import numpy_helper |
| import subprocess |
| from omegaconf import OmegaConf |
|
|
# Make the repository root importable so `videox_fun` resolves when this file
# is executed directly as a script; the root is two directories above this file.
REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir))
if REPO_ROOT not in sys.path:
    sys.path.insert(0, REPO_ROOT)
|
|
| from videox_fun.models import ZImageTransformer2DModel |
|
|
# Configure INFO-level logging for the script, then fetch the named logger
# used throughout this module.
logging.basicConfig(level=logging.INFO, format="[%(asctime)s] %(levelname)s: %(message)s")
LOGGER = logging.getLogger("export_transformer_body_onnx")

# Prompt-embedding sequence lengths must be a multiple of this value
# (enforced by `_validate_sequence_length`).
SEQ_MULTI_OF = 32
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Build the exporter's command-line interface and parse ``sys.argv``.

    Returns:
        The parsed namespace. Dashed option names become underscored
        attributes, e.g. ``--model-root`` -> ``args.model_root``.
    """
    parser = argparse.ArgumentParser(description="Export the pure Z-Image transformer body to ONNX (no ControlNet)")
    add = parser.add_argument

    # Model, checkpoint and output locations.
    add("--config", help="Path to YAML config (if needed)", default="config/z_image/z_image.yaml")
    add("--model-root", help="Directory that stores the pretrained weights", default="models/Diffusion_Transformer/Z-Image-Turbo/")
    add("--checkpoint", help="Optional fine-tuned checkpoint to load", default=None)
    add("--output", help="Target ONNX file path", default="onnx-models/z_image_transformer_body_only.onnx")

    # Shapes of the dummy tracing inputs.
    add("--height", default=864, type=int, help="Target image height used to derive latent resolution")
    add("--width", default=496, type=int, help="Target image width used to derive latent resolution")
    add("--batch-size", default=1, type=int, help="Batch size for the exported graph")
    add("--sequence-length", default=512, type=int, help="Prompt embedding sequence length (must be a multiple of 32)")
    add("--frames", default=1, type=int, help="Number of frames in the latent tensor")
    add("--latent-downsample-factor", default=8, type=int, help="Downsampling ratio between spatial image size and latent size")
    add("--latent-height", default=None, type=int, help="Override latent height (after downsampling)")
    add("--latent-width", default=None, type=int, help="Override latent width (after downsampling)")

    # Precision and transformer patching.
    add("--dtype", default="fp16", choices=["fp16", "fp32"], help="Export precision")
    add("--patch-size", default=2, type=int, help="Spatial patch size used by the transformer")
    add("--f-patch-size", default=1, type=int, help="Frame patch size used by the transformer")

    # Export / validation knobs.
    add("--opset", default=17, type=int, help="ONNX opset version")
    add("--no-external-data", action="store_true", help="Disable external data format even if the model is larger than 2GB")
    add("--skip-ort-check", action="store_true", help="Skip running an ONNX Runtime correctness check")
    add("--ort-provider", default="CPUExecutionProvider", help="ONNX Runtime provider used during validation")
    add("--save-calib-inputs", action="store_true", help="Dump ONNX input dictionaries as .npy for calibration")
    add("--calib-dir", default="onnx-calibration", help="Directory for storing calibration npy files")
    add("--dynamic-axes", action="store_true", help="Export ONNX with dynamic batch/seq/latent dims; default is static shape")
    add("--skip-slim", action="store_true", help="Skip onnxslim simplification for faster debug export")

    return parser.parse_args()
|
|
|
|
def run_onnxslim(input_file="vae.onnx", output_file="vae_slim.onnx") -> bool:
    """Simplify an ONNX model by running the ``onnxslim`` CLI in a subprocess.

    This can be slow for large models; the script's ``--skip-slim`` flag skips it.
    The tool's output (stdout and stderr, interleaved) is streamed to our stdout.

    Args:
        input_file: Path of the ONNX model to simplify.
        output_file: Path where the simplified model is written.

    Returns:
        True if ``onnxslim`` exited with status 0. False if it exited non-zero,
        if the executable is missing, or if launching it failed. Ordinary
        failures are reported via ``print`` rather than raised, so the caller
        decides how to react.
    """
    try:
        # str() so pathlib.Path arguments also work with the join below.
        cmd = ["onnxslim", str(input_file), str(output_file)]
        print(f"执行命令: {' '.join(cmd)}")
        # stderr is merged into stdout. Draining only stdout while stderr sits
        # in a separate, unread pipe can deadlock once the child fills the
        # stderr pipe buffer. Merging also keeps diagnostics in order with progress.
        with subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            bufsize=1,
        ) as process:
            for line in process.stdout:
                print(line, end="")
            returncode = process.wait()
    except FileNotFoundError:
        # Raised by Popen when the `onnxslim` executable is not on PATH.
        print("错误: 未找到 onnxslim 命令, 请确保已安装 onnxslim (pip install onnxslim)")
        return False
    except Exception as exc:
        # Best-effort helper: report any launch/IO failure and signal it via the return value.
        print(f"执行命令时发生错误: {exc}")
        return False

    if returncode != 0:
        # The tool's error output has already been streamed above.
        print(f"命令执行失败, 返回码: {returncode}")
        return False
    print("ONNX模型压缩完成!")
    return True
|
|
|
|
| def _resolve_path(path: str) -> str: |
| return os.path.abspath(os.path.join(REPO_ROOT, path)) if not os.path.isabs(path) else path |
|
|
|
|
def load_transformer(args: argparse.Namespace, torch_dtype: torch.dtype, device: torch.device) -> ZImageTransformer2DModel:
    """Load the pretrained Z-Image transformer, optionally overlaying a checkpoint.

    Args:
        args: Parsed CLI namespace; ``model_root`` and ``checkpoint`` are read
            (relative paths are resolved against the repository root).
        torch_dtype: dtype passed to ``from_pretrained`` and used for the final cast.
        device: Device the model is moved to.

    Returns:
        The transformer in eval mode, on ``device`` and cast to ``torch_dtype``.

    Raises:
        FileNotFoundError: If ``model_root`` is not a directory, or a checkpoint
            was requested but does not exist. Both are checked *before* the
            slow, memory-heavy model load, so bad paths fail fast.
    """
    model_root = _resolve_path(args.model_root)
    checkpoint_path = _resolve_path(args.checkpoint) if args.checkpoint else None

    if not os.path.isdir(model_root):
        raise FileNotFoundError(f"Model root not found: {model_root}")
    if checkpoint_path and not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

    LOGGER.info("Loading transformer from %s", model_root)
    transformer = ZImageTransformer2DModel.from_pretrained(
        model_root,
        subfolder="transformer",
        low_cpu_mem_usage=True,
        torch_dtype=torch_dtype,
    )
    transformer.eval()
    transformer.to(device=device, dtype=torch_dtype)

    if checkpoint_path:
        LOGGER.info("Loading checkpoint %s", checkpoint_path)
        if checkpoint_path.endswith(".safetensors"):
            from safetensors.torch import load_file

            state_dict = load_file(checkpoint_path)
        else:
            # NOTE(security): torch.load unpickles arbitrary objects on torch
            # versions where `weights_only` defaults to False; only load
            # checkpoints from trusted sources.
            state_dict = torch.load(checkpoint_path, map_location="cpu")
            # Unwrap checkpoints that nest the weights under a "state_dict" key.
            state_dict = state_dict.get("state_dict", state_dict)
        # strict=False tolerates key mismatches (presumably for partial
        # fine-tunes); the mismatch counts are logged so they are not silent.
        missing, unexpected = transformer.load_state_dict(state_dict, strict=False)
        LOGGER.info("Checkpoint loaded (missing=%d, unexpected=%d)", len(missing), len(unexpected))

    return transformer
|
|
|
|
| def _validate_sequence_length(seq_len: int) -> None: |
| if seq_len % SEQ_MULTI_OF != 0: |
| raise ValueError("sequence_length must be a multiple of 32 to satisfy transformer padding rules") |
|
|
|
|
| def _compute_latent_dims(args: argparse.Namespace) -> Dict[str, int]: |
| if args.latent_height is not None and args.latent_width is not None: |
| latent_h = args.latent_height |
| latent_w = args.latent_width |
| else: |
| if args.height % args.latent_downsample_factor != 0 or args.width % args.latent_downsample_factor != 0: |
| raise ValueError("height and width must be divisible by latent_downsample_factor") |
| latent_h = args.height // args.latent_downsample_factor |
| latent_w = args.width // args.latent_downsample_factor |
| if latent_h % args.patch_size != 0 or latent_w % args.patch_size != 0: |
| raise ValueError("latent dimensions must be divisible by patch_size") |
| if args.frames % args.f_patch_size != 0: |
| raise ValueError("frames must be divisible by f_patch_size") |
| return {"latent_h": latent_h, "latent_w": latent_w} |
|
|
|
|
def build_dummy_inputs(
    args: argparse.Namespace,
    model: ZImageTransformer2DModel,
    torch_dtype: torch.dtype,
    device: torch.device,
) -> OrderedDictType[str, torch.Tensor]:
    """Create random example inputs for tracing and validating the export.

    The insertion order of the returned mapping is the ONNX input order:

    * ``latent_model_input``: shape ``(batch, in_channels, frames, latent_h, latent_w)``,
      dtype ``torch_dtype``, standard-normal values.
    * ``timestep``: shape ``(batch,)``, float32, evenly spaced over ``[0, 1]``.
    * ``prompt_embeds``: shape ``(batch, sequence_length, cap_feat_dim)``,
      dtype ``torch_dtype``, standard-normal values.

    Raises:
        ValueError: Propagated from the sequence-length / latent-shape validators.
    """
    _validate_sequence_length(args.sequence_length)
    latent_dims = _compute_latent_dims(args)
    cfg = model.config
    n = args.batch_size

    latent_shape = (n, cfg.in_channels, args.frames, latent_dims["latent_h"], latent_dims["latent_w"])
    prompt_shape = (n, args.sequence_length, cfg.cap_feat_dim)

    sample = OrderedDict()
    # Draw the latents before the prompts to keep the RNG consumption order fixed.
    sample["latent_model_input"] = torch.randn(*latent_shape, dtype=torch_dtype, device=device)
    sample["timestep"] = torch.linspace(0.0, 1.0, steps=n, dtype=torch.float32, device=device)
    sample["prompt_embeds"] = torch.randn(*prompt_shape, dtype=torch_dtype, device=device)
    return sample
|
|
|
|
def maybe_save_calibration_inputs(tag: str, inputs: OrderedDictType[str, torch.Tensor], args: argparse.Namespace) -> Optional[str]:
    """Optionally dump the sample inputs to ``<calib_dir>/<tag>_inputs.npy``.

    The file holds a pickled ``{input_name: ndarray}`` dict; read it back with
    ``np.load(path, allow_pickle=True).item()``.

    Returns:
        The written file path, or ``None`` when ``args.save_calib_inputs`` is
        missing or falsy (nothing is written in that case).
    """
    if getattr(args, "save_calib_inputs", False):
        target_dir = _resolve_path(args.calib_dir)
        os.makedirs(target_dir, exist_ok=True)
        arrays = {}
        for input_name, value in inputs.items():
            arrays[input_name] = value.detach().cpu().numpy()
        target = os.path.join(target_dir, f"{tag}_inputs.npy")
        np.save(target, arrays, allow_pickle=True)
        LOGGER.info("Saved calibration inputs (%s) to %s", tag, target)
        return target
    return None
|
|
|
|
def dump_initializer_parameters(model_path: str) -> str:
    """Write every graph initializer of an ONNX model to ``<model_path>.params.npz``.

    External tensor data is loaded, so the archive contains the real weights.
    The whole model is materialized in memory first.

    Returns:
        Path of the written ``.npz`` archive; its keys are the initializer names.
    """
    proto = onnx.load(model_path, load_external_data=True)
    arrays = {init.name: numpy_helper.to_array(init) for init in proto.graph.initializer}
    out_path = f"{model_path}.params.npz"
    np.savez(out_path, **arrays)
    LOGGER.info("Saved %d parameters to %s", len(arrays), out_path)
    return out_path
|
|
|
|
class TransformerBodyOnlyWrapper(torch.nn.Module):
    """Lightweight wrapper exposing only the inputs the pure transformer body needs.

    The wrapped model is called with per-sample lists of tensors. This module
    accepts batched tensors instead (so the exported graph has plain tensor
    inputs), splits them along dim 0, and forwards the configured patch sizes.
    """

    def __init__(self, model: ZImageTransformer2DModel, patch_size: int, f_patch_size: int):
        super().__init__()
        self.model = model
        # Forwarded unchanged to the model's `patch_size` / `f_patch_size` kwargs.
        self.patch_size = patch_size
        self.f_patch_size = f_patch_size

    def forward(
        self,
        latent_model_input: torch.Tensor,
        timestep: torch.Tensor,
        prompt_embeds: torch.Tensor,
    ) -> torch.Tensor:
        """Run the wrapped transformer on one batch.

        Args:
            latent_model_input: Batched latents; split along dim 0 into a list.
            timestep: Passed to the model as-is.
            prompt_embeds: Batched prompt embeddings; split along dim 0 into a list.

        Returns:
            Whatever the wrapped model returns.
            NOTE(review): callers (ORT validation) treat this as a single tensor;
            confirm the model does not return a tuple/list/output dataclass.
        """
        per_sample_latents = list(torch.unbind(latent_model_input, dim=0))
        per_sample_prompts = list(torch.unbind(prompt_embeds, dim=0))
        return self.model(
            per_sample_latents,
            timestep,
            per_sample_prompts,
            patch_size=self.patch_size,
            f_patch_size=self.f_patch_size,
        )
|
|
|
|
def export_onnx(
    wrapper: torch.nn.Module,
    sample_inputs: OrderedDictType[str, torch.Tensor],
    output_path: str,
    output_names: List[str],
    args: argparse.Namespace,
) -> Tuple[str, str]:
    """Export ``wrapper`` to ONNX, re-save it, and optionally run onnxslim.

    Steps:
      1. ``torch.onnx.export`` to ``output_path``, resolved against the repo root.
      2. Re-save the graph as ``<stem>_simp.onnx``. Unless ``--no-external-data``
         is set, the weights go to one external file ``<stem>_simp.onnx.data``
         next to it. This copy is not simplified; only the weight storage changes.
      3. Unless ``--skip-slim`` is set, run onnxslim to produce ``<stem>_simp_slim.onnx``.

    Args:
        wrapper: Module to trace.
        sample_inputs: Example inputs; their keys become the ONNX input names.
        output_path: Target path for the raw export.
        output_names: ONNX output names.
        args: CLI namespace. Reads ``no_external_data``, ``dynamic_axes``,
            ``opset`` and ``skip_slim``.

    Returns:
        ``(final_onnx_path, param_path)``; ``param_path`` is currently always ``""``.

    Raises:
        RuntimeError: If onnxslim fails.
    """
    export_path = _resolve_path(output_path)
    export_dir = os.path.dirname(export_path)
    if export_dir:
        os.makedirs(export_dir, exist_ok=True)
    input_names = list(sample_inputs.keys())
    use_external = not args.no_external_data
    wrapper.eval()

    dynamic_axes = None
    if args.dynamic_axes:
        all_axes = {
            "latent_model_input": {0: "batch", 2: "frames", 3: "latent_h", 4: "latent_w"},
            "prompt_embeds": {0: "batch", 1: "seq_len"},
            "timestep": {0: "batch"},
            "sample": {0: "batch", 2: "frames", 3: "latent_h", 4: "latent_w"},
        }
        # Only pass axes for names that are actually graph inputs/outputs.
        exported_names = set(input_names) | set(output_names)
        dynamic_axes = {name: axes for name, axes in all_axes.items() if name in exported_names}

    LOGGER.info("Exporting ONNX to %s", export_path)
    # no_grad rather than inference_mode: torch.onnx.export disables autograd
    # itself, and inference-mode tensors can be incompatible with the tracer.
    with torch.no_grad():
        torch.onnx.export(
            wrapper,
            args=tuple(sample_inputs[name] for name in input_names),
            f=export_path,
            input_names=input_names,
            output_names=output_names,
            opset_version=args.opset,
            do_constant_folding=True,
            export_params=True,
            dynamic_axes=dynamic_axes,
        )
    LOGGER.info("Raw ONNX export finished")

    trans_onnx = onnx.load(export_path)
    simp_onnx_data = os.path.splitext(export_path)[0] + "_simp.onnx"
    if use_external:
        # Pin the data file name so the logged path is exactly what onnx writes.
        data_name = os.path.basename(simp_onnx_data) + ".data"
        external_weight_file = os.path.join(os.path.dirname(simp_onnx_data), data_name)
        # onnx appends tensors at the end of an existing data file; drop any
        # stale file from a previous run so it does not keep growing.
        if os.path.exists(external_weight_file):
            os.remove(external_weight_file)
        onnx.save(
            trans_onnx,
            simp_onnx_data,
            save_as_external_data=True,
            all_tensors_to_one_file=True,
            location=data_name,
        )
        LOGGER.info("Saved external-data ONNX to %s (weights -> %s)", simp_onnx_data, external_weight_file)
    else:
        # Honour --no-external-data; onnx.save fails if the model exceeds the 2GB protobuf limit.
        onnx.save(trans_onnx, simp_onnx_data)
        LOGGER.info("Saved single-file ONNX to %s", simp_onnx_data)
    # Release the in-memory copy before onnxslim loads the model again in a subprocess.
    del trans_onnx

    if args.skip_slim:
        LOGGER.info("Skip onnxslim as requested, using re-saved ONNX: %s", simp_onnx_data)
        final_onnx = simp_onnx_data
    else:
        slim_onnx_path = os.path.splitext(simp_onnx_data)[0] + "_slim.onnx"
        LOGGER.info("Transformer ONNX model exported, start to simplify via onnxslim")
        if not run_onnxslim(simp_onnx_data, slim_onnx_path):
            raise RuntimeError("onnxslim simplification failed, please check logs")
        final_onnx = slim_onnx_path
    LOGGER.info("Transformer ONNX model exported successfully: %s", final_onnx)

    # Parameters are not dumped separately here (see dump_initializer_parameters).
    param_path = ""
    return final_onnx, param_path
|
|
|
|
def run_ort_validation(
    wrapper: torch.nn.Module,
    sample_inputs: OrderedDictType[str, torch.Tensor],
    onnx_path: str,
    provider: str,
) -> None:
    """Compare the PyTorch wrapper with ONNX Runtime on the same inputs and log the error.

    Skips with a warning when ``onnxruntime`` is not installed. The max absolute
    error and the error relative to ``max(1, max|torch_output|)`` are logged
    only; a large error does not fail this function.

    Args:
        wrapper: Reference PyTorch module; its output must support ``.detach()``.
        sample_inputs: Inputs fed to both runtimes; keys must match ONNX input names.
        onnx_path: Path of the ONNX model to load.
        provider: ONNX Runtime execution provider name.

    Raises:
        ValueError: If the PyTorch and ONNX Runtime outputs differ in shape.
    """
    try:
        import onnxruntime as ort
    except ImportError:
        LOGGER.warning("onnxruntime not installed, skip validation")
        return

    wrapper.eval()
    with torch.inference_mode():
        torch_output = wrapper(*sample_inputs.values()).detach().cpu().numpy()

    sess_options = ort.SessionOptions()
    session = ort.InferenceSession(onnx_path, sess_options=sess_options, providers=[provider])
    ort_inputs = {name: tensor.detach().cpu().numpy() for name, tensor in sample_inputs.items()}
    ort_output = np.asarray(session.run(None, ort_inputs)[0])

    # Without this check numpy could broadcast mismatched shapes into a meaningless diff.
    if torch_output.shape != ort_output.shape:
        raise ValueError(
            f"Output shape mismatch: torch {tuple(torch_output.shape)} vs onnxruntime {tuple(ort_output.shape)}"
        )
    if torch_output.size == 0:
        LOGGER.info("ONNX Runtime check done (empty output, nothing to compare)")
        return

    # Upcast before differencing: subtracting fp16 arrays directly can overflow
    # to inf and loses precision in the error estimate.
    reference = torch_output.astype(np.float64)
    candidate = ort_output.astype(np.float64)
    abs_diff = np.max(np.abs(reference - candidate))
    rel_diff = abs_diff / np.maximum(1.0, np.max(np.abs(reference)))
    LOGGER.info("ONNX Runtime check done (abs=%.6f, rel=%.6f)", abs_diff, rel_diff)
|
|
|
|
def main() -> None:
    """CLI entry point: load the model, build sample inputs, export to ONNX, then optionally validate."""
    cli = parse_args()
    cpu = torch.device("cpu")
    dtype = torch.float16 if cli.dtype == "fp16" else torch.float32

    # Everything below is inference/tracing only, so autograd is switched off globally.
    torch.set_grad_enabled(False)
    model = load_transformer(cli, dtype, cpu)
    body = TransformerBodyOnlyWrapper(model, cli.patch_size, cli.f_patch_size)
    inputs = build_dummy_inputs(cli, model, dtype, cpu)

    maybe_save_calibration_inputs("transformer_body_only", inputs, cli)

    final_model_path, _unused_param_path = export_onnx(body, inputs, cli.output, ["sample"], cli)

    if cli.skip_ort_check:
        return
    # Validation is best-effort: any failure is reported but does not fail the export.
    try:
        run_ort_validation(body, inputs, final_model_path, cli.ort_provider)
    except Exception as exc:
        LOGGER.warning("ONNX Runtime validation failed: %s", exc)
|
|
|
|
if __name__ == "__main__":
    # Example:
    #   python scripts/z_image/export_transformer_body_onnx.py \
    #       --output onnx-models/z_image_transformer_body_only.onnx \
    #       --height 512 --width 512 --sequence-length 128 \
    #       --latent-downsample-factor 8 \
    #       --dtype fp32 \
    #       --skip-slim
    main()
|
|