| |
| """Split an ONNX graph into smaller sub-models driven by sub_config rules. |
| |
| This script reads a JSON config file (matching the pulsar2 sub_config layout), |
| extracts the requested subgraphs, and optionally emits any leftover parts of the |
| model as independent ONNX graphs. A verification utility can run the original |
| model and the stitched micro-model pipeline to make sure their outputs match. |
| """ |
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import logging |
| from collections import defaultdict, deque |
| from dataclasses import dataclass, field |
| from pathlib import Path |
| from typing import Dict, Iterable, List, Optional, Sequence, Set |
|
|
| import numpy as np |
| import onnx |
| from onnx import utils as onnx_utils |
|
|
| try: |
| import onnxruntime as ort |
| except ImportError: |
| ort = None |
|
|
|
|
@dataclass
class SubGraphSpec:
    """Describes a single subgraph to extract from the full model."""


    label: str  # unique tag used in filenames and log lines, e.g. "cfg_00" / "auto_01"
    start: List[str]  # tensor names forming the subgraph's input boundary
    end: List[str]  # tensor names the subgraph must expose as outputs
    node_names: Set[str]  # names of the graph nodes captured by this subgraph
    source: str  # where the spec came from: "config" (JSON entry) or "auto" (leftover)
    output_path: Optional[Path] = None  # filled in once the sub-model .onnx is written
|
|
|
|
@dataclass
class GraphIndex:
    """Caches helpful lookups for traversing an ONNX graph."""


    tensor_to_producer: Dict[str, str]  # tensor name -> name of the node that outputs it
    tensor_to_consumers: Dict[str, List[str]]  # tensor name -> nodes that read it
    node_inputs: Dict[str, List[str]]  # node name -> its non-empty input tensor names
    node_outputs: Dict[str, List[str]]  # node name -> its non-empty output tensor names
    graph_inputs: Set[str]  # names of the model's declared graph inputs
    graph_outputs: Set[str]  # names of the model's declared graph outputs
    initializer_names: Set[str]  # names of constant (initializer) tensors
    node_order: List[str]  # deduplicated node names in original graph order
|
|
|
|
def sanitize(name: str) -> str:
    """Turn *name* into a filesystem-friendly token.

    Every non-alphanumeric character becomes an underscore, outer
    underscores are trimmed, an empty *name* maps to "anon", and a value
    that trims away entirely falls back to "tensor".
    """
    if not name:
        chars = ["anon"]
    else:
        chars = ["_" if not ch.isalnum() else ch for ch in name]
    cleaned = "".join(chars).strip("_")
    return cleaned if cleaned else "tensor"
|
|
|
|
def build_graph_index(model: onnx.ModelProto) -> GraphIndex:
    """Pre-compute the lookup tables needed to traverse *model*'s graph.

    Node names are made unique: anonymous nodes get a positional
    ``node_<idx>`` name, and duplicates are suffixed with their index
    until no clash remains.
    """
    producer_of: Dict[str, str] = {}
    consumers_of: Dict[str, List[str]] = defaultdict(list)
    inputs_by_node: Dict[str, List[str]] = {}
    outputs_by_node: Dict[str, List[str]] = {}
    ordering: List[str] = []
    seen_names: Set[str] = set()

    for position, node in enumerate(model.graph.node):
        trimmed = node.name.strip() if node.name else ""
        unique = trimmed or f"node_{position}"
        while unique in seen_names:
            # Grow the name until it no longer collides.
            unique = f"{unique}_{position}"
        seen_names.add(unique)

        ordering.append(unique)
        inputs_by_node[unique] = [t for t in node.input if t]
        outputs_by_node[unique] = [t for t in node.output if t]
        for produced in outputs_by_node[unique]:
            producer_of[produced] = unique
        for consumed in inputs_by_node[unique]:
            consumers_of[consumed].append(unique)

    return GraphIndex(
        tensor_to_producer=producer_of,
        tensor_to_consumers=consumers_of,
        node_inputs=inputs_by_node,
        node_outputs=outputs_by_node,
        graph_inputs={vi.name for vi in model.graph.input},
        graph_outputs={vi.name for vi in model.graph.output},
        initializer_names={init.name for init in model.graph.initializer},
        node_order=ordering,
    )
|
|
|
|
def trace_nodes_between(
    spec: SubGraphSpec,
    index: GraphIndex,
) -> Set[str]:
    """Collect the nodes reachable backwards from ``spec.end`` tensors.

    The walk stops at boundary tensors: the spec's start tensors, the
    model's graph inputs, and every initializer.  Returns the set of node
    names whose outputs feed (transitively) into ``spec.end``.
    """
    stop_tensors = set(spec.start) | index.graph_inputs | index.initializer_names
    seen_tensors: Set[str] = set()
    pending = list(spec.end)
    nodes: Set[str] = set()

    while pending:
        current = pending.pop()
        if current in seen_tensors:
            continue
        seen_tensors.add(current)
        if current in stop_tensors:
            continue
        origin = index.tensor_to_producer.get(current)
        if not origin or origin in nodes:
            continue
        nodes.add(origin)
        pending.extend(
            upstream
            for upstream in index.node_inputs.get(origin, [])
            if upstream and upstream not in stop_tensors
        )
    return nodes
|
|
|
|
def untouched_components(
    all_nodes: Sequence[str],
    covered_nodes: Set[str],
    index: GraphIndex,
) -> List[Set[str]]:
    """Group nodes not claimed by any spec into connected components.

    Edges are data dependencies in either direction, so each returned set
    is a weakly-connected leftover region of the graph.  Returns an empty
    list when every node is already covered.
    """
    leftovers = [name for name in all_nodes if name not in covered_nodes]
    if not leftovers:
        return []

    leftover_set = set(leftovers)
    neighbours: Dict[str, Set[str]] = {name: set() for name in leftovers}

    def _link(a: str, b: str) -> None:
        # Undirected edge between two leftover nodes.
        neighbours[a].add(b)
        neighbours[b].add(a)

    for name in leftovers:
        for tensor in index.node_outputs.get(name, []):
            for consumer in index.tensor_to_consumers.get(tensor, []):
                if consumer in leftover_set:
                    _link(name, consumer)
        for tensor in index.node_inputs.get(name, []):
            producer = index.tensor_to_producer.get(tensor)
            if producer in leftover_set:
                _link(name, producer)

    groups: List[Set[str]] = []
    seen: Set[str] = set()
    for name in leftovers:
        if name in seen:
            continue
        frontier = deque([name])
        group: Set[str] = set()
        while frontier:
            current = frontier.pop()
            if current in seen:
                continue
            seen.add(current)
            group.add(current)
            frontier.extend(neighbours[current] - seen)
        groups.append(group)
    return groups
|
|
|
|
def derive_interface(
    nodes: Set[str],
    index: GraphIndex,
) -> tuple[list[str], list[str]]:
    """Compute the ``(start, end)`` tensor lists for an arbitrary node set.

    Start tensors are non-initializer inputs whose producer lies outside
    *nodes* (or that have no producer at all, e.g. graph inputs).  End
    tensors are outputs consumed outside *nodes*, plus any graph outputs
    the set produces.  When the set has no externally visible output,
    every produced tensor is exposed as a fallback.

    Returns sorted lists so the interface is deterministic.

    Note: the original return annotation ``(List[str], List[str])`` was a
    tuple of types, not a valid typing construct; it is now spelled as a
    real ``tuple[...]`` annotation (lazily evaluated via the module's
    ``from __future__ import annotations``).
    """
    produced: Set[str] = set()
    for node in nodes:
        produced.update(index.node_outputs.get(node, []))

    start: Set[str] = set()
    for node in nodes:
        for inp in index.node_inputs.get(node, []):
            if inp in index.initializer_names:
                continue  # constants travel with the sub-model, not its interface
            producer = index.tensor_to_producer.get(inp)
            if producer is None or producer not in nodes:
                start.add(inp)

    end: Set[str] = set()
    for node in nodes:
        for out in index.node_outputs.get(node, []):
            consumers = index.tensor_to_consumers.get(out, [])
            if not consumers:
                # Dangling tensor: only expose it if the model itself does.
                if out in index.graph_outputs:
                    end.add(out)
                continue
            if any(consumer not in nodes for consumer in consumers):
                end.add(out)
    # Graph outputs produced here must be exposed even if consumed internally.
    end.update(index.graph_outputs & produced)

    if not end and produced:
        end = produced.copy()

    return sorted(start), sorted(end)
|
|
|
|
def extract_model_file(
    model_path: Path,
    spec: SubGraphSpec,
    output_dir: Path,
    suffix: str,
) -> Path:
    """Write the sub-model described by *spec* into *output_dir*.

    The file name combines the spec label with sanitized first start/end
    tensor names and the given suffix, so sibling files from one run are
    easy to tell apart.  Returns the path of the written .onnx file.
    """
    first_in = sanitize(spec.start[0]) if spec.start else "const"
    first_out = sanitize(spec.end[0]) if spec.end else "out"
    destination = output_dir / f"{spec.label}_{first_in}_to_{first_out}_{suffix}.onnx"
    onnx_utils.extract_model(
        model_path.as_posix(),
        destination.as_posix(),
        input_names=spec.start,
        output_names=spec.end,
        check_model=False,
    )
    logging.info("Saved %s (start=%s, end=%s)", destination.name, spec.start, spec.end)
    return destination
|
|
|
|
def ordered_specs(
    specs: Sequence[SubGraphSpec],
    index: GraphIndex,
) -> List[SubGraphSpec]:
    """Order *specs* so every start tensor is available before it runs.

    Greedy topological sort: tensors from the graph inputs and
    initializers are available up front; each scheduled spec then makes
    its end tensors available.  Raises ``RuntimeError`` when no remaining
    spec can be scheduled.
    """
    ready = set(index.graph_inputs) | index.initializer_names
    waiting = list(specs)
    result: List[SubGraphSpec] = []
    while waiting:
        made_progress = False
        for candidate in tuple(waiting):
            if not set(candidate.start) <= ready:
                continue
            result.append(candidate)
            ready.update(candidate.end)
            waiting.remove(candidate)
            made_progress = True
        if made_progress:
            continue
        missing = {spec.label: sorted(set(spec.start) - ready) for spec in waiting}
        raise RuntimeError(
            "无法解析子图拓扑,缺少以下张量: %s" % missing
        )
    return result
|
|
|
|
def run_full_model(model_path: Path, feed_dict: Dict[str, np.ndarray], providers: List[str]):
    """Run the unsplit model once and map each output name to its array.

    Raises ``RuntimeError`` when onnxruntime is not installed.
    """
    if ort is None:
        raise RuntimeError("需要 onnxruntime 才能执行验证。")
    session = ort.InferenceSession(model_path.as_posix(), providers=providers)
    values = session.run(None, feed_dict)
    return {meta.name: value for meta, value in zip(session.get_outputs(), values)}
|
|
|
|
def run_split_pipeline(
    ordered_subgraphs: Sequence[SubGraphSpec],
    feed_dict: Dict[str, np.ndarray],
    providers: List[str],
) -> Dict[str, np.ndarray]:
    """Execute every sub-model in order, threading tensors through a store.

    The store starts as a copy of *feed_dict*; each sub-model reads its
    start tensors from the store and writes its outputs back, so later
    sub-models can consume earlier results.  Returns the final store.
    Raises ``RuntimeError`` when onnxruntime is missing or a spec has no
    generated file, and ``KeyError`` when a required input is absent.
    """
    if ort is None:
        raise RuntimeError("需要 onnxruntime 才能执行验证。")
    store: Dict[str, np.ndarray] = dict(feed_dict)
    for spec in ordered_subgraphs:
        if spec.output_path is None:
            raise RuntimeError(f"子图 {spec.label} 尚未生成 ONNX 文件。")
        session = ort.InferenceSession(spec.output_path.as_posix(), providers=providers)
        feeds: Dict[str, np.ndarray] = {}
        for name in spec.start:
            if name not in store:
                raise KeyError(
                    f"子图 {spec.label} 缺少输入张量 {name},请确认切分顺序。"
                )
            feeds[name] = store[name]
        for meta, value in zip(session.get_outputs(), session.run(None, feeds)):
            store[meta.name] = value
    return store
|
|
|
|
def verify(
    model_path: Path,
    ordered_subgraphs: Sequence[SubGraphSpec],
    feed_dict: Dict[str, np.ndarray],
    providers: List[str],
    rtol: float,
    atol: float,
) -> None:
    """Check that the stitched pipeline reproduces the full model's outputs.

    Runs both the original model and the split pipeline on *feed_dict* and
    compares every model output with ``np.allclose``.  Raises
    ``AssertionError`` on a missing output or a mismatch beyond the given
    tolerances; logs a success message otherwise.
    """
    reference = run_full_model(model_path, feed_dict, providers)
    produced = run_split_pipeline(ordered_subgraphs, feed_dict, providers)
    for name, expected in reference.items():
        actual = produced.get(name)
        if actual is None:
            raise AssertionError(f"切分流水线未产生模型输出 {name}")
        if not np.allclose(expected, actual, rtol=rtol, atol=atol):
            diff = np.max(np.abs(expected - actual))
            raise AssertionError(
                f"输出 {name} 不匹配,最大偏差 {diff:.3e}"
            )
    logging.info("切分模型与原始模型输出一致 (rtol=%g, atol=%g)。", rtol, atol)
|
|
|
|
def load_npz_inputs(npz_path: Path) -> Dict[str, np.ndarray]:
    """Load every array from *npz_path* into a name -> ndarray dict.

    Fix: ``np.load`` on an .npz returns a lazy ``NpzFile`` backed by an
    open zip handle; the original never closed it, leaking the file
    descriptor.  Using the context manager closes it after the arrays are
    materialized.  ``allow_pickle=False`` is kept so untrusted .npz files
    cannot trigger object deserialization.
    """
    with np.load(npz_path, allow_pickle=False) as data:
        return {key: data[key] for key in data.files}
|
|
|
|
def main() -> None:
    """CLI entry point: parse args, split the model, optionally verify.

    Phases: (1) load the model and build traversal indexes, (2) turn each
    ``compiler.sub_configs`` JSON entry into a SubGraphSpec, (3) wrap any
    leftover nodes into auto-generated specs, (4) topologically order and
    extract every spec to its own .onnx file, (5) optionally compare the
    stitched pipeline against the original model with onnxruntime.
    """
    parser = argparse.ArgumentParser(description="根据 sub_config 切分 ONNX 模型。")
    parser.add_argument("--model", required=True, type=Path, help="原始 ONNX 路径")
    parser.add_argument("--config", required=True, type=Path, help="pulsar2 配置 JSON")
    parser.add_argument("--output-dir", required=False, default="./split-onnx", type=Path, help="保存子模型的目录")
    parser.add_argument(
        "--verify",
        action="store_true",
        help="生成后立即用 onnxruntime 校验输出是否一致",
    )
    parser.add_argument(
        "--input-npz",
        type=Path,
        help="包含模型所有输入张量的 npz 文件 (verify 模式需要)",
    )
    parser.add_argument(
        "--providers",
        nargs="*",
        default=["CPUExecutionProvider"],
        help="onnxruntime 推理后端顺序",
    )
    parser.add_argument("--rtol", type=float, default=1e-4, help="验证 rtol")
    parser.add_argument("--atol", type=float, default=1e-5, help="验证 atol")
    parser.add_argument("--log", default="INFO", help="日志等级")
    args = parser.parse_args()

    # Unknown --log values silently fall back to INFO via getattr's default.
    logging.basicConfig(level=getattr(logging, args.log.upper(), logging.INFO))

    model = onnx.load(args.model.as_posix())
    graph_index = build_graph_index(model)

    with args.config.open("r", encoding="utf-8") as f:
        config = json.load(f)

    # The config must follow the pulsar2 layout: {"compiler": {"sub_configs": [...]}}.
    sub_configs = config.get("compiler", {}).get("sub_configs", [])
    if not sub_configs:
        raise ValueError("配置文件中未找到 compiler.sub_configs。")

    specs: List[SubGraphSpec] = []
    covered_nodes: Set[str] = set()

    # Phase 2: one spec per JSON entry; record which nodes each spec claims.
    for idx, entry in enumerate(sub_configs):
        start = [name for name in entry.get("start_tensor_names", []) if name]
        end = [name for name in entry.get("end_tensor_names", []) if name]
        if not start or not end:
            raise ValueError(f"sub_config[{idx}] 缺少 start/end tensor name。")
        spec = SubGraphSpec(
            label=f"cfg_{idx:02d}",
            start=start,
            end=end,
            node_names=set(),
            source="config",
        )
        nodes = trace_nodes_between(spec, graph_index)
        spec.node_names = nodes
        covered_nodes.update(nodes)
        specs.append(spec)

    # Phase 3: wrap each connected component of unclaimed nodes as an "auto" spec.
    leftovers = untouched_components(graph_index.node_order, covered_nodes, graph_index)
    for idx, component in enumerate(leftovers):
        start, end = derive_interface(component, graph_index)
        if not end:
            # A component with no visible outputs cannot be extracted; skip it.
            logging.warning("自动发现的剩余子图 %d 没有输出,跳过。", idx)
            continue
        spec = SubGraphSpec(
            label=f"auto_{idx:02d}",
            start=start,
            end=end,
            node_names=component,
            source="auto",
        )
        specs.append(spec)
        logging.info(
            "自动补充子图 %s: start=%s end=%s (节点数=%d)",
            spec.label,
            spec.start,
            spec.end,
            len(component),
        )

    # Phase 4: order so every spec's inputs are produced before it runs,
    # then extract each spec to its own .onnx file.
    ordered = ordered_specs(specs, graph_index)

    args.output_dir.mkdir(parents=True, exist_ok=True)
    for spec in ordered:
        spec.output_path = extract_model_file(args.model, spec, args.output_dir, spec.source)

    # Phase 5 (optional): run original vs. stitched pipeline and compare.
    if args.verify:
        if args.input_npz is None:
            raise ValueError("verify 模式需要 --input-npz 提供输入数据。")
        feed = load_npz_inputs(args.input_npz)
        missing_inputs = graph_index.graph_inputs - feed.keys()
        if missing_inputs:
            raise ValueError(f"npz 中缺少以下模型输入: {sorted(missing_inputs)}")
        verify(args.model, ordered, feed, args.providers, args.rtol, args.atol)
|
|
|
|
if __name__ == "__main__":
    # Usage example (previously a bare triple-quoted string, which Python
    # evaluated and discarded at runtime; comments are the idiomatic form):
    #
    #   python python/VideoX-Fun/scripts/split_onnx_by_subconfigs.py \
    #       --model /path/to/full.onnx \
    #       --config python/VideoX-Fun/pulsar2_configs/transformers_subgraph.json \
    #       --output-dir /tmp/sliced_models \
    #       --verify \
    #       --input-npz /path/to/inputs.npz
    main()
|
|