| import torch |
| import matplotlib.pyplot as plt |
| import matplotlib.patches as patches |
| import matplotlib |
| import seaborn as sns |
| import numpy as np |
| import threading |
| import onnx2torch |
| import onnxscript |
| from nemo.collections.asr.models import SortformerEncLabelModel |
| from pydub import AudioSegment |
| import coremltools as ct |
| from pydub.playback import play as play_audio |
|
|
| |
# Prefer the Apple-silicon GPU (MPS backend) when available, else fall back to CPU.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
audio_file = "audio.wav"



print("Loading audio file for playback...")
# Load the whole recording with pydub. NOTE(review): `full_audio` is not used
# anywhere in this section — presumably consumed later in the file (playback
# via `play_audio`?); verify before removing.
full_audio = AudioSegment.from_wav(audio_file)
|
|
| |
# Fetch the pretrained streaming Sortformer diarization checkpoint
# (4-speaker variant) and place it on the selected device.
model = SortformerEncLabelModel.from_pretrained(
    "nvidia/diar_streaming_sortformer_4spk-v2.1",
    map_location=device
)
# Inference mode: disables dropout and uses running batch-norm statistics.
model.eval()
model.to(device)


# NOTE(review): assumes the NeMo model exposes an `output_names` attribute
# (export metadata) — confirm against the installed NeMo version.
print(model.output_names)
|
|
def streaming_input_examples(self):
    """Input tensor examples for exporting the streaming version of the model.

    Builds one random batch shaped like the streaming forward signature.
    Tensors are allocated directly on ``self.device`` via the ``device=``
    factory argument, avoiding the original CPU-allocate-then-copy
    (``.to(device)``) round trip.

    Args:
        self: object exposing ``cfg`` (a mapping with an optional
            ``preprocessor.features`` entry) and ``device`` — normally the
            NeMo model itself, since this is written in method style.

    Returns:
        Tuple of ``(chunk, chunk_lengths, spkcache, spkcache_lengths,
        fifo, fifo_lengths)``.
    """
    batch_size = 4
    device = self.device
    # Feature dim comes from the model's preprocessor config; 128 matches the
    # pretrained checkpoint's default.
    feat_in = self.cfg.get("preprocessor", {}).get("features", 128)
    chunk = torch.rand([batch_size, 120, feat_in], device=device)
    chunk_lengths = torch.tensor([120] * batch_size, device=device)
    spkcache = torch.randn([batch_size, 188, 512], device=device)
    # Mixed per-sample valid lengths (hard-coded for batch_size == 4) exercise
    # full, partial, and empty cache states during export tracing.
    spkcache_lengths = torch.tensor([40, 188, 0, 68], device=device)
    fifo = torch.randn([batch_size, 188, 512], device=device)
    fifo_lengths = torch.tensor([50, 88, 0, 90], device=device)
    return chunk, chunk_lengths, spkcache, spkcache_lengths, fifo, fifo_lengths
|
|
|
|
# Build example tensors from the model's own config/device for export tracing.
inputs = streaming_input_examples(model)


# NeMo's export() writes the ONNX graph to disk; onnx2torch then reconstructs
# a torch module from it (presumably a torch.fx GraphModule — confirm), which
# coremltools can trace/convert below.
export_out = model.export("streaming-sortformer.onnx", input_example=inputs)
scripted_model = onnx2torch.convert('streaming-sortformer.onnx')
|
|
# Fixed shapes for the Core ML conversion — these must stay in sync with the
# tensors produced by streaming_input_examples() above.
BATCH_SIZE = 4   # batch size baked into the export example
CHUNK_LEN = 120  # frames per streaming chunk
FEAT_DIM = 128   # feature dimension (preprocessor "features" default)
CACHE_LEN = 188  # speaker-cache / FIFO buffer length
EMBED_DIM = 512  # embedding dimension of cached frames
|
|
# Fixed-shape Core ML input declarations; the names must line up with the
# input names of the exported graph.
_input_specs = [
    ("chunk", (BATCH_SIZE, CHUNK_LEN, FEAT_DIM)),
    ("chunk_lens", (BATCH_SIZE,)),
    ("spkcache", (BATCH_SIZE, CACHE_LEN, EMBED_DIM)),
    ("spkcache_lens", (BATCH_SIZE,)),
    ("fifo", (BATCH_SIZE, CACHE_LEN, EMBED_DIM)),
    ("fifo_lens", (BATCH_SIZE,)),
]
ct_inputs = [
    ct.TensorType(name=spec_name, shape=spec_shape)
    for spec_name, spec_shape in _input_specs
]
|
|
# Name the converted model's outputs; shapes are inferred by the converter.
ct_outputs = [
    ct.TensorType(name=out_name)
    for out_name in (
        "preds",
        "new_spkcache",
        "new_spkcache_lens",
        "new_fifo",
        "new_fifo_lens",
    )
]
|
|
|
|
# Convert the reconstructed torch module to a Core ML ML Program targeting
# iOS 17, with FP16 compute precision.
# BUG FIX: the converted model was previously discarded (return value never
# assigned or saved), so the conversion produced no on-disk artifact.
mlmodel = ct.convert(
    scripted_model,
    inputs=ct_inputs,
    outputs=ct_outputs,
    convert_to="mlprogram",
    minimum_deployment_target=ct.target.iOS17,
    compute_precision=ct.precision.FLOAT16,
)
# ML Programs must be persisted as an .mlpackage bundle.
mlmodel.save("streaming-sortformer.mlpackage")
|
|