| import sys |
| import os |
| import torch |
| import json |
| from safetensors.torch import load_file |
|
|
| |
# Make the DA-2 source tree importable. The location is resolved against the
# current working directory, so the script must be launched from the folder
# that contains the `DA-2-repo` checkout.
da2_src_dir = os.path.join(os.getcwd(), 'DA-2-repo/src')
sys.path.append(da2_src_dir)
|
|
# SphereViT is provided by the DA-2 source tree (expected under
# DA-2-repo/src relative to the working directory). If it cannot be imported,
# report why and stop with a non-zero exit status instead of failing later.
try:
    from da2.model.spherevit import SphereViT
except ImportError as import_error:
    print(f"Error importing SphereViT: {import_error}")
    raise SystemExit(1)
|
|
| |
# Inference configuration shipped with the DA-2 repository. JSON text is
# UTF-8 by specification, so decode it explicitly rather than relying on the
# platform's locale-dependent default encoding (e.g. cp1252 on Windows).
config_path = 'DA-2-repo/configs/infer.json'
with open(config_path, 'r', encoding='utf-8') as f:
    config = json.load(f)
|
|
| |
| |
| |
# Fixed export resolution. Width is exactly twice the height (2:1 aspect,
# presumably an equirectangular panorama — confirm), and both values are
# multiples of 14 (546 = 14 * 39), likely matching the ViT patch size —
# verify against the SphereViT backbone.
H, W = 546, 1092
pixel_budget = H * W

# Pin the lower and upper pixel limits to the same value. NOTE(review): this
# presumably forces the model's preprocessing to keep exactly H*W pixels so
# the traced graph sees a single resolution; verify in the da2 source.
inference_cfg = config['inference']
inference_cfg['min_pixels'] = pixel_budget
inference_cfg['max_pixels'] = pixel_budget

print(f"Initializing model with input size {W}x{H}...")
model = SphereViT(config)
|
|
| |
def _report_key_mismatch(label, keys, limit=10):
    """Print how many state-dict keys fall under *label* and list up to *limit* of them.

    Args:
        label: Heading printed before the count (e.g. "Missing keys").
        keys: Key names returned by ``load_state_dict``; nothing is printed if empty.
        limit: Maximum number of individual key names to print.
    """
    if not keys:
        return
    print(f"{label}: {len(keys)}")
    for key in keys[:limit]:
        print(f"  {key}")
    if len(keys) > limit:
        print(f"  ... and {len(keys) - limit} more")


# Load the checkpoint into the freshly built model. strict=False lets loading
# succeed even when the checkpoint and model key sets differ; parameters whose
# keys are missing keep their constructor-initialized values. The console
# report below is therefore the only signal of a partially loaded model, so
# it names the offending keys instead of only counting them.
print("Loading weights from model.safetensors...")
try:
    weights = load_file('model.safetensors')
    missing, unexpected = model.load_state_dict(weights, strict=False)
except Exception as e:
    print(f"Error loading weights: {e}")
    sys.exit(1)

_report_key_mismatch("Missing keys", missing)
_report_key_mismatch("Unexpected keys", unexpected)
|
|
# Switch to inference behaviour before tracing. No dtype conversion is applied
# here, so the model is exported in whatever precision it was built with.
print("Exporting model in FP32 (full precision)...")
model.eval()

# Example input used only to trace the graph: a batch of one 3-channel
# H x W image. The values are random; only the shape and dtype matter.
example_shape = (1, 3, H, W)
dummy_input = torch.randn(*example_shape)
|
|
| |
# Stage 1: export the FP32 graph. Only the batch axis is declared dynamic, so
# the spatial size is fixed to the dummy input's H x W. The destination
# directory is created up front because the export writes straight to
# `output_file` and would otherwise fail on a fresh checkout.
output_file = "onnx/model.onnx"
print(f"Exporting to {output_file}...")
try:
    os.makedirs(os.path.dirname(output_file) or ".", exist_ok=True)
    torch.onnx.export(
        model,
        dummy_input,
        output_file,
        opset_version=17,
        input_names=["pixel_values"],
        output_names=["predicted_depth"],
        dynamic_axes={
            "pixel_values": {0: "batch_size"},
            "predicted_depth": {0: "batch_size"},
        },
        export_params=True,
        do_constant_folding=True,
        verbose=False,
    )
except Exception as e:
    print(f"Error exporting to ONNX: {e}")
    import traceback
    traceback.print_exc()
    # Fatal, like the import and weight-loading failures: signal it to callers.
    sys.exit(1)

print(f"Successfully exported to {output_file}")

# Stage 2 (best effort): write an INT8 dynamically-quantized copy next to the
# FP32 model. A failure here, including onnxruntime not being installed, is
# reported but not fatal because the FP32 export has already succeeded.
try:
    from onnxruntime.quantization import quantize_dynamic, QuantType
    quantized_output_file = "onnx/model_quantized.onnx"
    print(f"Quantizing model to {quantized_output_file}...")
    quantize_dynamic(
        output_file,
        quantized_output_file,
        weight_type=QuantType.QInt8,
    )
    print(f"Successfully quantized to {quantized_output_file}")
except Exception as qe:
    print(f"Error during quantization: {qe}")
    import traceback
    traceback.print_exc()
|
|