| |
| """导出 VAE Encoder/Decoder 的 ONNX 模型。""" |
|
|
| import argparse |
| import logging |
| import os |
| import sys |
| from collections import OrderedDict |
| from typing import Any, Dict, Optional |
|
|
| import numpy as np |
| import torch |
| from loguru import logger |
| import onnx |
| from onnx import numpy_helper |
| import subprocess |
|
|
# Make the repository root importable when this script is executed directly
# (it lives two directories below the repo root).
REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
if REPO_ROOT not in sys.path:
    sys.path.insert(0, REPO_ROOT)


from videox_fun.models import AutoencoderKL


# Module-level logger shared by every helper in this script.
LOGGER = logging.getLogger("export_vae_onnx")
logging.basicConfig(level=logging.INFO, format="[%(asctime)s] %(levelname)s: %(message)s")
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Parse command-line options for the VAE ONNX export script."""
    ap = argparse.ArgumentParser(description="Export VAE encoder/decoder to ONNX")
    # Model and output locations.
    ap.add_argument("--model-root", default="models/Diffusion_Transformer/Z-Image-Turbo/", help="Diffusers 权重所在目录")
    ap.add_argument("--checkpoint", default=None, help="可选的 VAE finetune checkpoint")
    ap.add_argument("--encoder-output", default="onnx-models/vae_encoder.onnx", help="VAE Encoder ONNX 路径")
    ap.add_argument("--decoder-output", default="onnx-models/vae_decoder.onnx", help="VAE Decoder ONNX 路径")
    # Export geometry.
    ap.add_argument("--height", type=int, default=864, help="导出时的图片高度")
    ap.add_argument("--width", type=int, default=496, help="导出时的图片宽度")
    ap.add_argument("--latent-downsample-factor", type=int, default=8, help="VAE 下采样倍数")
    ap.add_argument("--batch-size", type=int, default=1, help="导出 batch 大小")
    # Export behavior.
    ap.add_argument("--dtype", choices=["fp16", "fp32"], default="fp16", help="导出精度")
    ap.add_argument("--opset", type=int, default=17, help="ONNX opset 版本")
    ap.add_argument("--dynamic-axes", action="store_true", help="是否导出动态维度")
    ap.add_argument("--skip-ort-check", action="store_true", help="跳过 onnxruntime 结果校验")
    ap.add_argument("--ort-provider", default="CPUExecutionProvider", help="onnxruntime provider")
    ap.add_argument("--skip-slim", action="store_true", help="跳过 onnxslim")
    ap.add_argument("--no-external-data", action="store_true", help="禁用外部数据格式保存")
    # Calibration dumps.
    ap.add_argument("--save-calib-inputs", action="store_true", help="保存校准输入 npy")
    ap.add_argument("--calib-dir", default="onnx-calibration", help="校准输入保存目录")
    return ap.parse_args()
|
|
|
|
def run_onnxslim(input_file: str, output_file: str) -> bool:
    """Invoke the ``onnxslim`` CLI to simplify an ONNX model file.

    Streams the tool's stdout to our stdout so progress is visible.

    Args:
        input_file: Path of the model to simplify.
        output_file: Path where the simplified model is written.

    Returns:
        True on success, False on any failure (missing binary,
        non-zero exit code, or unexpected error).
    """
    try:
        cmd = ["onnxslim", input_file, output_file]
        print(f"执行命令: {' '.join(cmd)}")
        process = subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,  # text=True already implies universal_newlines
            bufsize=1,  # line-buffered so progress is echoed promptly
        )
        if process.stdout is not None:  # always true with stdout=PIPE; keeps type checkers happy
            for line in process.stdout:
                print(line, end="")
        _, stderr = process.communicate()
        if process.returncode != 0:
            print(f"命令执行失败, 错误信息:\n{stderr}")
            return False
        print("ONNX模型压缩完成!")
        return True
    except FileNotFoundError:
        # FIX: the binary comes from the `onnxslim` package, not onnx-simplifier.
        print("错误: 未找到 onnxslim 命令, 请确保已安装 onnxslim (pip install onnxslim)")
        return False
    except Exception as exc:
        # Best-effort tool: report and let the caller decide how to proceed.
        print(f"执行命令时发生错误: {exc}")
        return False
|
|
|
| def _resolve_path(path: str) -> str: |
| return os.path.abspath(os.path.join(REPO_ROOT, path)) if not os.path.isabs(path) else path |
|
|
|
|
| def _check_image_dims(height: int, width: int, factor: int) -> None: |
| if height % factor != 0 or width % factor != 0: |
| raise ValueError("height 和 width 需要能被 latent_downsample_factor 整除") |
|
|
|
|
| def _compute_latent_dims(height: int, width: int, factor: int) -> Dict[str, int]: |
| _check_image_dims(height, width, factor) |
| return {"latent_h": height // factor, "latent_w": width // factor} |
|
|
|
|
def load_vae(args: argparse.Namespace, torch_dtype: torch.dtype, device: torch.device) -> AutoencoderKL:
    """Build the VAE from a diffusers directory, optionally overlaying a finetune checkpoint.

    Args:
        args: Parsed CLI namespace (reads ``model_root`` and ``checkpoint``).
        torch_dtype: Precision the model is cast to.
        device: Device the model is moved to.

    Returns:
        The VAE in eval mode on *device*.

    Raises:
        FileNotFoundError: when the model directory or checkpoint is missing.
    """
    model_root = _resolve_path(args.model_root)
    if not os.path.isdir(model_root):
        raise FileNotFoundError(f"Model root not found: {model_root}")

    LOGGER.info("Loading VAE from %s", model_root)
    vae = AutoencoderKL.from_pretrained(
        model_root,
        subfolder="vae",
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
    )
    vae.to(device=device, dtype=torch_dtype)
    vae.eval()

    checkpoint_path = _resolve_path(args.checkpoint) if args.checkpoint else None
    if checkpoint_path is None:
        return vae

    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
    LOGGER.info("Loading checkpoint %s", checkpoint_path)
    if checkpoint_path.endswith(".safetensors"):
        from safetensors.torch import load_file

        weights = load_file(checkpoint_path)
    else:
        payload = torch.load(checkpoint_path, map_location="cpu")
        # Training scripts commonly nest weights under a "state_dict" key.
        weights = payload.get("state_dict", payload)
    missing, unexpected = vae.load_state_dict(weights, strict=False)
    LOGGER.info("Checkpoint loaded (missing=%d, unexpected=%d)", len(missing), len(unexpected))
    return vae
|
|
|
|
def build_dummy_inputs(args: argparse.Namespace, vae: AutoencoderKL, torch_dtype: torch.dtype, device: torch.device) -> OrderedDict:
    """Create random pixel and latent tensors matching the export resolution.

    Returns an OrderedDict with keys ``pixel_values`` (B, 3, H, W) and
    ``latent`` (B, C_latent, H/f, W/f).
    """
    latent_dims = _compute_latent_dims(args.height, args.width, args.latent_downsample_factor)
    pixel_shape = (args.batch_size, 3, args.height, args.width)
    latent_shape = (
        args.batch_size,
        vae.config.latent_channels,
        latent_dims["latent_h"],
        latent_dims["latent_w"],
    )
    pixel_values = torch.randn(*pixel_shape, dtype=torch_dtype, device=device)
    latents = torch.randn(*latent_shape, dtype=torch_dtype, device=device)
    return OrderedDict(pixel_values=pixel_values, latent=latents)
|
|
|
|
class VAEEncoderWrapper(torch.nn.Module):
    """Export wrapper: pixel tensor -> deterministic latent (distribution mode)."""

    def __init__(self, model: AutoencoderKL):
        super().__init__()
        self.model = model

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        # encode(...)[0] is the latent distribution; .mode() yields its
        # deterministic center instead of drawing a random sample.
        return self.model.encode(pixel_values)[0].mode()
|
|
|
|
class VAEDecoderWrapper(torch.nn.Module):
    """Export wrapper: latent tensor -> decoded image tensor."""

    def __init__(self, model: AutoencoderKL):
        super().__init__()
        self.model = model

    def forward(self, latents: torch.Tensor) -> torch.Tensor:
        # return_dict=False makes decode(...) return a tuple; element 0 is the image batch.
        return self.model.decode(latents, return_dict=False)[0]
|
|
|
|
def maybe_save_calibration_inputs(tag: str, inputs: OrderedDict, args: argparse.Namespace) -> Optional[str]:
    """Persist sample inputs as a pickled ``.npy`` dict when --save-calib-inputs is set.

    Returns the saved file path, or None when saving is disabled.
    """
    if not getattr(args, "save_calib_inputs", False):
        return None
    calib_dir = _resolve_path(args.calib_dir)
    os.makedirs(calib_dir, exist_ok=True)
    arrays = {}
    for name, tensor in inputs.items():
        arrays[name] = tensor.detach().cpu().numpy()
    file_path = os.path.join(calib_dir, f"{tag}_inputs.npy")
    # A dict payload needs allow_pickle=True on both save and load.
    np.save(file_path, arrays, allow_pickle=True)
    LOGGER.info("Saved calibration inputs (%s) to %s", tag, file_path)
    return file_path
|
|
|
|
def dump_initializer_parameters(model_path: str) -> str:
    """Extract every graph initializer from an ONNX file into a sibling ``.npz`` archive.

    Returns the path of the written archive (``<model_path>.params.npz``).
    """
    model_proto = onnx.load(model_path, load_external_data=True)
    params = {
        initializer.name: numpy_helper.to_array(initializer)
        for initializer in model_proto.graph.initializer
    }
    param_path = f"{model_path}.params.npz"
    np.savez(param_path, **params)
    LOGGER.info("Saved %d parameters to %s", len(params), param_path)
    return param_path
|
|
|
|
def export_onnx(
    wrapper: torch.nn.Module,
    sample_inputs: OrderedDict,
    output_path: str,
    output_names: list,
    args: argparse.Namespace,
) -> str:
    """Export *wrapper* to ONNX, optionally re-save with external data, then slim.

    Args:
        wrapper: Module to trace/export.
        sample_inputs: Ordered name -> tensor mapping used as example inputs.
        output_path: Target .onnx path (relative paths resolve against REPO_ROOT).
        output_names: Names assigned to the graph outputs.
        args: CLI namespace (reads ``dynamic_axes``, ``opset``, ``no_external_data``,
            ``skip_slim``).

    Returns:
        Path of the final model (slimmed unless --skip-slim).

    Raises:
        RuntimeError: when onnxslim fails.
    """
    export_path = _resolve_path(output_path)
    export_dir = os.path.dirname(export_path)
    if export_dir:
        os.makedirs(export_dir, exist_ok=True)

    input_names = list(sample_inputs.keys())
    wrapper.eval()

    dynamic_axes = None
    if args.dynamic_axes:
        candidate_axes = {
            "pixel_values": {0: "batch", 2: "height", 3: "width"},
            "latents": {0: "batch", 2: "latent_h", 3: "latent_w"},
            "images": {0: "batch", 2: "height", 3: "width"},
        }
        # Keep only the entries that name this graph's actual inputs/outputs.
        dynamic_axes = {k: v for k, v in candidate_axes.items() if k in input_names + output_names}

    LOGGER.info("Exporting ONNX to %s", export_path)
    with torch.inference_mode():
        torch.onnx.export(
            wrapper,
            args=tuple(sample_inputs[name] for name in input_names),
            f=export_path,
            input_names=input_names,
            output_names=output_names,
            opset_version=args.opset,
            do_constant_folding=True,
            export_params=True,
            dynamic_axes=dynamic_axes,
        )
    LOGGER.info("Raw ONNX export finished")

    if args.no_external_data:
        # FIX: --no-external-data was previously parsed but ignored; weights were
        # always re-saved into an external data file. Honor the flag by feeding
        # the raw export straight to the next stage.
        slim_input = export_path
    else:
        onnx_model = onnx.load(export_path)
        slim_input = os.path.splitext(export_path)[0] + "_simp.onnx"
        onnx.save(
            onnx_model,
            slim_input,
            save_as_external_data=True,
            all_tensors_to_one_file=True,
        )
        LOGGER.info("Saved external-data ONNX to %s", slim_input)

    if args.skip_slim:
        LOGGER.info("Skip onnxslim as requested")
        return slim_input

    slim_path = os.path.splitext(slim_input)[0] + "_slim.onnx"
    LOGGER.info("Start onnxslim simplification")
    if not run_onnxslim(slim_input, slim_path):
        raise RuntimeError("onnxslim simplification failed")
    LOGGER.info("onnxslim done: %s", slim_path)
    return slim_path
|
|
|
|
def run_ort_validation(wrapper: torch.nn.Module, sample_inputs: OrderedDict, onnx_path: str, provider: str) -> None:
    """Compare the PyTorch wrapper output with onnxruntime on identical inputs.

    Logs the max absolute/relative difference; silently returns when
    onnxruntime is not installed. No threshold is enforced.
    """
    try:
        import onnxruntime as ort
    except ImportError:
        LOGGER.warning("onnxruntime not installed, skip validation")
        return

    wrapper.eval()
    with torch.inference_mode():
        reference = wrapper(*sample_inputs.values()).detach().cpu().numpy()

    feeds = {name: tensor.detach().cpu().numpy() for name, tensor in sample_inputs.items()}
    session = ort.InferenceSession(onnx_path, providers=[provider])
    candidate = session.run(None, feeds)[0]

    abs_diff = np.max(np.abs(reference - candidate))
    # max(1.0, ...) guards against dividing by a near-zero reference magnitude.
    rel_diff = abs_diff / max(1.0, float(np.max(np.abs(reference))))
    LOGGER.info("ONNX Runtime check done (abs=%.6f, rel=%.6f)", abs_diff, rel_diff)
|
|
|
|
def main() -> None:
    """CLI entry point: load the VAE, export encoder and decoder, optionally validate."""
    args = parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    torch_dtype = torch.float16 if args.dtype == "fp16" else torch.float32
    if torch_dtype == torch.float16 and device.type == "cpu":
        LOGGER.warning("CPU 上不支持 fp16, 自动回退为 fp32")
        torch_dtype = torch.float32

    torch.set_grad_enabled(False)
    vae = load_vae(args, torch_dtype, device)
    sample_inputs = build_dummy_inputs(args, vae, torch_dtype, device)

    # --- Encoder export ---
    encoder_inputs = OrderedDict(pixel_values=sample_inputs["pixel_values"])
    encoder_wrapper = VAEEncoderWrapper(vae)
    maybe_save_calibration_inputs("vae_encoder", encoder_inputs, args)

    # The encoder's latent output doubles as the decoder's sample input.
    with torch.inference_mode():
        latent_sample = encoder_wrapper(*encoder_inputs.values()).detach()

    encoder_onnx = export_onnx(
        encoder_wrapper,
        encoder_inputs,
        args.encoder_output,
        ["latents"],
        args,
    )

    # --- Decoder export ---
    decoder_inputs = OrderedDict(latents=latent_sample)
    decoder_wrapper = VAEDecoderWrapper(vae)
    maybe_save_calibration_inputs("vae_decoder", decoder_inputs, args)

    decoder_onnx = export_onnx(
        decoder_wrapper,
        decoder_inputs,
        args.decoder_output,
        ["images"],
        args,
    )

    if args.skip_ort_check:
        return
    try:
        run_ort_validation(encoder_wrapper, encoder_inputs, encoder_onnx, args.ort_provider)
        run_ort_validation(decoder_wrapper, decoder_inputs, decoder_onnx, args.ort_provider)
    except Exception as exc:  # validation is best-effort; never fail the export
        LOGGER.warning("ONNX Runtime validation failed: %s", exc)
|
|
|
|
| if __name__ == "__main__": |
| """ |
| 示例: |
| python scripts/z_image_fun/export_vae_onnx.py \ |
| --model-root models/Diffusion_Transformer/Z-Image-Turbo/ \ |
| --height 512 --width 512 \ |
| --encoder-output onnx-models/vae_encoder.onnx \ |
| --decoder-output onnx-models/vae_decoder.onnx \ |
| --dtype fp32 \ |
| --save-calib-inputs \ |
| --skip-ort-check |
| """ |
| main() |
|
|