import os
import argparse
import warnings

import torch
import numpy as np

from mobile_sam import sam_model_registry
from .onnx_image_encoder import ImageEncoderOnnxModel

# onnxruntime is optional: it is only needed to sanity-check the exported
# model and, via --quantize-out, to quantize it.
try:
    import onnxruntime

    onnxruntime_exists = True
except ImportError:
    onnxruntime_exists = False

parser = argparse.ArgumentParser(
    description="Export the SAM image encoder to an ONNX model."
)

parser.add_argument(
    "--checkpoint",
    type=str,
    required=True,
    help="The path to the SAM model checkpoint.",
)

parser.add_argument(
    "--output", type=str, required=True, help="The filename to save the ONNX model to."
)

parser.add_argument(
    "--model-type",
    type=str,
    required=True,
    help="In ['default', 'vit_h', 'vit_l', 'vit_b', 'vit_t']. Which type of SAM model to export.",
)

parser.add_argument(
    "--use-preprocess",
    action="store_true",
    help="Whether to fold image preprocessing (resizing and normalization) into the exported graph.",
)

parser.add_argument(
    "--opset",
    type=int,
    default=17,
    help="The ONNX opset version to use. Must be >=11.",
)

parser.add_argument(
    "--quantize-out",
    type=str,
    default=None,
    help=(
        "If set, will quantize the model and save it with this name. "
        "Quantization is performed with quantize_dynamic from "
        "onnxruntime.quantization.quantize."
    ),
)

parser.add_argument(
    "--gelu-approximate",
    action="store_true",
    help=(
        "Replace GELU operations with approximations using tanh. Useful "
        "for some runtimes that have slow or unimplemented erf ops, used in GELU."
    ),
)


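# Export pipeline: load the SAM checkpoint, wrap its image encoder in
# ImageEncoderOnnxModel, trace the wrapper with a dummy input, write the ONNX
# graph, and (when onnxruntime is installed) smoke-test the result.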
def run_export(
    model_type: str,
    checkpoint: str,
    output: str,
    use_preprocess: bool,
    opset: int,
    gelu_approximate: bool = False,
):
    print("Loading model...")
    sam = sam_model_registry[model_type](checkpoint=checkpoint)

    onnx_model = ImageEncoderOnnxModel(
        model=sam,
        use_preprocess=use_preprocess,
        # Standard SAM per-channel normalization constants (0-255 pixel scale).
        pixel_mean=[123.675, 116.28, 103.53],
        pixel_std=[58.395, 57.12, 57.375],
    )

    # Some runtimes have slow or missing erf kernels; the tanh-based GELU
    # approximation avoids erf entirely.
    if gelu_approximate:
        for _, m in onnx_model.named_modules():
            if isinstance(m, torch.nn.GELU):
                m.approximate = "tanh"

    image_size = sam.image_encoder.img_size
    if use_preprocess:
        # Preprocessing is baked into the graph: it accepts a raw HxWx3 float
        # image, with height and width left dynamic.
        dummy_input = {
            "input_image": torch.randn((image_size, image_size, 3), dtype=torch.float)
        }
        dynamic_axes = {
            "input_image": {0: "image_height", 1: "image_width"},
        }
    else:
        # Without preprocessing the graph expects an already-resized,
        # normalized 1x3xSxS batch, so every axis is static.
        dummy_input = {
            "input_image": torch.randn(
                (1, 3, image_size, image_size), dtype=torch.float
            )
        }
        dynamic_axes = None

    # Run the wrapper once in eager mode so shape or dtype problems surface
    # before tracing.
    _ = onnx_model(**dummy_input)

    output_names = ["image_embeddings"]

    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
        warnings.filterwarnings("ignore", category=UserWarning)
        print(f"Exporting onnx model to {output}...")
        if model_type == "vit_h":
            # vit_h exceeds the 2 GB protobuf limit, so export to a path (not
            # a file object) so the weights can be written as external data
            # files alongside the model.
            output_dir, output_file = os.path.split(output)
            if output_dir:
                os.makedirs(output_dir, mode=0o777, exist_ok=True)
            torch.onnx.export(
                onnx_model,
                tuple(dummy_input.values()),
                output,
                export_params=True,
                verbose=False,
                opset_version=opset,
                do_constant_folding=True,
                input_names=list(dummy_input.keys()),
                output_names=output_names,
                dynamic_axes=dynamic_axes,
            )
        else:
            with open(output, "wb") as f:
                torch.onnx.export(
                    onnx_model,
                    tuple(dummy_input.values()),
                    f,
                    export_params=True,
                    verbose=False,
                    opset_version=opset,
                    do_constant_folding=True,
                    input_names=list(dummy_input.keys()),
                    output_names=output_names,
                    dynamic_axes=dynamic_axes,
                )

    if onnxruntime_exists:
        ort_inputs = {k: to_numpy(v) for k, v in dummy_input.items()}
        providers = ["CPUExecutionProvider"]

        if model_type == "vit_h":
            # The exported vit_h model references its weights as external data
            # files; register each file as an initializer *before* creating the
            # session, and pass the options in so onnxruntime actually uses them.
            session_option = onnxruntime.SessionOptions()
            param_files = os.listdir(output_dir or ".")
            param_files.remove(output_file)
            for layer in param_files:
                with open(os.path.join(output_dir, layer), "rb") as fp:
                    weights = np.frombuffer(fp.read(), dtype=np.float32)
                weights = onnxruntime.OrtValue.ortvalue_from_numpy(weights)
                session_option.add_initializer(layer, weights)
            ort_session = onnxruntime.InferenceSession(
                output, sess_options=session_option, providers=providers
            )
        else:
            ort_session = onnxruntime.InferenceSession(output, providers=providers)

        _ = ort_session.run(None, ort_inputs)
        print("Model has successfully been run with ONNXRuntime.")


def to_numpy(tensor):
    return tensor.cpu().numpy()


if __name__ == "__main__":
    args = parser.parse_args()
    run_export(
        model_type=args.model_type,
        checkpoint=args.checkpoint,
        output=args.output,
        use_preprocess=args.use_preprocess,
        opset=args.opset,
        gelu_approximate=args.gelu_approximate,
    )

    if args.quantize_out is not None:
        assert onnxruntime_exists, "onnxruntime is required to quantize the model."
        from onnxruntime.quantization import QuantType
        from onnxruntime.quantization.quantize import quantize_dynamic

        # Dynamic quantization stores the fp32 weights as uint8, shrinking the
        # model roughly 4x at some cost in accuracy.
        print(f"Quantizing model and writing to {args.quantize_out}...")
        quantize_dynamic(
            model_input=args.output,
            model_output=args.quantize_out,
            optimize_model=True,
            per_channel=False,
            reduce_range=False,
            weight_type=QuantType.QUInt8,
        )
        print("Done!")