|
|
|
|
|
|
|
|
|
|
|
|
| import torch
|
|
|
| from segment_anything import sam_model_registry
|
| from segment_anything.utils.onnx import SamOnnxModel
|
|
|
| import argparse
|
| import warnings
|
|
|
# Probe for onnxruntime. It is only needed for the optional post-export
# sanity check and for quantization, so a missing install is not fatal here.
try:
    import onnxruntime
except ImportError:
    onnxruntime_exists = False
else:
    onnxruntime_exists = True
|
|
|
# Command-line interface for the export script. The three required arguments
# identify the checkpoint, the output path, and the SAM variant; the remaining
# flags tune how the ONNX graph is produced.
parser = argparse.ArgumentParser(
    description="Export the SAM prompt encoder and mask decoder to an ONNX model."
)

_arg = parser.add_argument  # local alias to keep the declarations compact

# Required inputs/outputs.
_arg("--checkpoint", type=str, required=True, help="The path to the SAM model checkpoint.")
_arg("--output", type=str, required=True, help="The filename to save the ONNX model to.")
_arg(
    "--model-type",
    type=str,
    required=True,
    help="In ['default', 'vit_h', 'vit_l', 'vit_b']. Which type of SAM model to export.",
)

# Optional flags controlling the exported graph and post-processing.
_arg(
    "--return-single-mask",
    action="store_true",
    help=(
        "If true, the exported ONNX model will only return the best mask, "
        "instead of returning multiple masks. For high resolution images "
        "this can improve runtime when upscaling masks is expensive."
    ),
)
_arg("--opset", type=int, default=17, help="The ONNX opset version to use. Must be >=11")
_arg(
    "--quantize-out",
    type=str,
    default=None,
    help=(
        "If set, will quantize the model and save it with this name. "
        "Quantization is performed with quantize_dynamic from onnxruntime.quantization.quantize."
    ),
)
_arg(
    "--gelu-approximate",
    action="store_true",
    help=(
        "Replace GELU operations with approximations using tanh. Useful "
        "for some runtimes that have slow or unimplemented erf ops, used in GELU."
    ),
)
_arg(
    "--use-stability-score",
    action="store_true",
    help=(
        "Replaces the model's predicted mask quality score with the stability "
        "score calculated on the low resolution masks using an offset of 1.0. "
    ),
)
_arg(
    "--return-extra-metrics",
    action="store_true",
    help=(
        "The model will return five results: (masks, scores, stability_scores, "
        "areas, low_res_logits) instead of the usual three. This can be "
        "significantly slower for high resolution outputs."
    ),
)
|
|
|
|
|
def run_export(
    model_type: str,
    checkpoint: str,
    output: str,
    opset: int,
    return_single_mask: bool,
    gelu_approximate: bool = False,
    use_stability_score: bool = False,
    return_extra_metrics=False,
):
    """Export the SAM prompt encoder and mask decoder to an ONNX file.

    Loads the checkpoint for `model_type`, wraps it in `SamOnnxModel`, traces
    it with fixed-size dummy inputs (only the number of prompt points is
    dynamic), and writes the ONNX graph to `output`. When onnxruntime is
    installed, the exported file is sanity-checked by running it once.
    """
    print("Loading model...")
    sam = sam_model_registry[model_type](checkpoint=checkpoint)

    onnx_model = SamOnnxModel(
        model=sam,
        return_single_mask=return_single_mask,
        use_stability_score=use_stability_score,
        return_extra_metrics=return_extra_metrics,
    )

    # Optionally swap exact GELU for its tanh approximation, for runtimes
    # with slow or unimplemented erf ops.
    if gelu_approximate:
        for _, module in onnx_model.named_modules():
            if isinstance(module, torch.nn.GELU):
                module.approximate = "tanh"

    # Only the point-prompt count varies at inference time.
    dynamic_axes = {"point_coords": {1: "num_points"}, "point_labels": {1: "num_points"}}

    embed_dim = sam.prompt_encoder.embed_dim
    embed_size = sam.prompt_encoder.image_embedding_size
    # The low-res mask prompt is 4x the embedding's spatial size.
    mask_input_size = [4 * side for side in embed_size]
    dummy_inputs = {
        "image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float),
        "point_coords": torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.float),
        "point_labels": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float),
        "mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float),
        "has_mask_input": torch.tensor([1], dtype=torch.float),
        "orig_im_size": torch.tensor([1500, 2250], dtype=torch.float),
    }

    # Warm-up forward pass before tracing.
    _ = onnx_model(**dummy_inputs)

    output_names = ["masks", "iou_predictions", "low_res_masks"]

    # Tracing SAM emits benign TracerWarnings/UserWarnings; silence them for
    # a readable export log.
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
        warnings.filterwarnings("ignore", category=UserWarning)
        with open(output, "wb") as f:
            print(f"Exporting onnx model to {output}...")
            torch.onnx.export(
                onnx_model,
                tuple(dummy_inputs.values()),
                f,
                export_params=True,
                verbose=False,
                opset_version=opset,
                do_constant_folding=True,
                input_names=list(dummy_inputs.keys()),
                output_names=output_names,
                dynamic_axes=dynamic_axes,
            )

    # Best-effort validation: load the exported file with onnxruntime on CPU
    # and run it once on the same dummy inputs.
    if onnxruntime_exists:
        ort_inputs = {name: to_numpy(value) for name, value in dummy_inputs.items()}

        providers = ["CPUExecutionProvider"]
        ort_session = onnxruntime.InferenceSession(output, providers=providers)
        _ = ort_session.run(None, ort_inputs)
        print("Model has successfully been run with ONNXRuntime.")
|
|
|
|
|
def to_numpy(tensor):
    """Move *tensor* to the CPU and return it as a numpy array."""
    host_tensor = tensor.cpu()
    return host_tensor.numpy()
|
|
|
|
|
if __name__ == "__main__":
    args = parser.parse_args()
    run_export(
        model_type=args.model_type,
        checkpoint=args.checkpoint,
        output=args.output,
        opset=args.opset,
        return_single_mask=args.return_single_mask,
        gelu_approximate=args.gelu_approximate,
        use_stability_score=args.use_stability_score,
        return_extra_metrics=args.return_extra_metrics,
    )

    # Optional second step: dynamically quantize the just-exported model.
    if args.quantize_out is not None:
        # An `assert` here would be stripped under `python -O`, silently
        # skipping the dependency check; raise explicitly instead.
        if not onnxruntime_exists:
            raise RuntimeError("onnxruntime is required to quantize the model.")
        # Imported lazily: only needed when --quantize-out is given.
        from onnxruntime.quantization import QuantType
        from onnxruntime.quantization.quantize import quantize_dynamic

        print(f"Quantizing model and writing to {args.quantize_out}...")
        quantize_dynamic(
            model_input=args.output,
            model_output=args.quantize_out,
            optimize_model=True,
            per_channel=False,
            reduce_range=False,
            weight_type=QuantType.QUInt8,
        )
        print("Done!")
|
|
|