| |
| |
| |
| |
| |
| |
| import os, glob, sys, torchaudio |
| import numpy as np |
| import scipy.io.wavfile as Wavfile |
| import matplotlib.pyplot as plt |
| from sklearn.metrics import confusion_matrix |
|
|
| os.system('git clone https://github.com/snakers4/silero-vad.git && cd silero-vad && git checkout bbf22a00640614309d60aba5467189b48c7c6ecc && cd ..') |
| sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "./silero-vad/src"))) |
| from silero_vad.utils_vad import VADIterator, init_jit_model |
|
|
| sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../include"))) |
| from ten_vad import TenVad |
|
|
| def convert_label_to_framewise(label_file, hop_size): |
| frame_duration = hop_size / 16000 |
| with open(label_file, "r") as f: |
| lines = f.readlines() |
| content = lines[0].strip().split(",")[1:] |
| start = np.array( |
| content[::3], dtype=float |
| ) |
| end = np.array( |
| content[1:][::3], dtype=float |
| ) |
| lab_manual = np.array( |
| content[2:][::3], dtype=int |
| ) |
| assert ( |
| len(start) == len(end) |
| and len(start) == len(lab_manual) |
| and len(end) == len(lab_manual) |
| ) |
| |
| num = np.array( |
| np.round(((end - start) / frame_duration)), dtype=np.int32 |
| ) |
| label_framewise = np.array([]) |
| for segment_idx in range(len(num)): |
| cur_lab = int(lab_manual[segment_idx]) |
| num_segment = num[segment_idx] |
|
|
| if cur_lab == 1: |
| vad_result_this_segment = np.ones(num_segment) |
| elif cur_lab == 0: |
| vad_result_this_segment = np.zeros(num_segment) |
| label_framewise = np.append(label_framewise, vad_result_this_segment) |
| frame_num = min( |
| label_framewise.__len__(), int((end[-1] - start[0]) / frame_duration) |
| ) |
| label_framewise = label_framewise[:frame_num] |
|
|
| return label_framewise |
|
|
|
|
| def read_file(file_path): |
| with open(file_path, "r") as f: |
| lines = f.readlines() |
| lines_arr = np.array([]) |
| for line in lines: |
| lines_arr = np.append(lines_arr, float(line.strip())) |
|
|
| return lines_arr |
|
|
| def get_precision_recall(VAD_result, label, threshold): |
| vad_result_hard = np.where(VAD_result >= threshold, 1, 0) |
|
|
| |
| TN, FP, FN, TP = confusion_matrix(label, vad_result_hard).ravel() |
|
|
| |
| precision = TP / (TP + FP) if (TP + FP) > 0 else 0 |
| recall = TP / (TP + FN) if (TP + FN) > 0 else 0 |
| FPR = FP / (FP + TN) if (FP + TN) > 0 else 0 |
| FNR = FN / (TP + FN) if (TP + FN) > 0 else 0 |
|
|
| return precision, recall, FPR, FNR |
|
|
| def silero_vad_inference_single_file(wav_path): |
| current_directory = os.path.dirname(os.path.abspath(__file__)) |
| model = init_jit_model(f'{current_directory}/silero-vad/src/silero_vad/data/silero_vad.jit') |
| vad_iterator = VADIterator(model) |
| window_size_samples = 512 |
| speech_probs = np.array([]) |
| |
| wav, sr = torchaudio.load(wav_path) |
| wav = wav.squeeze(0) |
| for i in range(0, len(wav), window_size_samples): |
| chunk = wav[i: i+ window_size_samples] |
| if len(chunk) < window_size_samples: |
| break |
| speech_prob = model(chunk, sr).item() |
| speech_probs = np.append(speech_probs, speech_prob) |
| vad_iterator.reset_states() |
| |
| return speech_probs, window_size_samples |
|
|
| def ten_vad_process_wav(ten_vad_instance, wav_path, hop_size=256): |
| _, data = Wavfile.read(wav_path) |
| num_frames = data.shape[0] // hop_size |
| voice_prob_arr = np.array([]) |
| for i in range(num_frames): |
| input_data = data[i * hop_size: (i + 1) * hop_size] |
| voice_prob, _ = ten_vad_instance.process(input_data) |
| voice_prob_arr = np.append(voice_prob_arr, voice_prob) |
|
|
| return voice_prob_arr |
|
|
| if __name__ == "__main__": |
| |
| script_dir = os.path.dirname(os.path.abspath(__file__)) |
|
|
| |
| test_dir = f"{script_dir}/../testset" |
|
|
| |
| hop_size = 256 |
| threshold = 0.5 |
| label_all, vad_result_ten_vad_all = np.array([]), np.array([]) |
| label_hop_512_all, vad_result_silero_vad_all = np.array([]), np.array([]) |
| wav_list = glob.glob(f"{test_dir}/*.wav") |
|
|
| |
| print("Start processing") |
| for wav_path in wav_list: |
| |
| ten_vad_instance = TenVad(hop_size, threshold) |
| label_file = wav_path.replace(".wav", ".scv") |
| label = convert_label_to_framewise( |
| label_file, hop_size=hop_size |
| ) |
| vad_result_ten_vad = ten_vad_process_wav( |
| ten_vad_instance, wav_path, hop_size=hop_size |
| ) |
| frame_num = min(label.__len__(), vad_result_ten_vad.__len__()) |
| vad_result_ten_vad_all = np.append( |
| vad_result_ten_vad_all, vad_result_ten_vad[1:frame_num] |
| ) |
| label_all = np.append(label_all, label[:frame_num - 1]) |
| del ten_vad_instance |
|
|
| |
| label_hop_512 = convert_label_to_framewise( |
| label_file, hop_size=512 |
| ) |
| vad_result_silero_vad, _ = silero_vad_inference_single_file(wav_path) |
| frame_num_silero_vad = min(label_hop_512.__len__(), vad_result_silero_vad.__len__()) |
| vad_result_silero_vad_all = np.append(vad_result_silero_vad_all, vad_result_silero_vad[:frame_num_silero_vad]) |
| label_hop_512_all = np.append(label_hop_512_all, label_hop_512[:frame_num_silero_vad]) |
|
|
| |
| threshold_arr = np.arange(0, 1.01, 0.01) |
| pr_data_arr = np.zeros((threshold_arr.__len__(), 3)) |
| pr_data_silero_vad_arr = np.zeros((threshold_arr.__len__(), 3)) |
|
|
| for ind, threshold in enumerate(threshold_arr): |
| precision, recall, FPR, FNR = get_precision_recall(vad_result_ten_vad_all, label_all, threshold) |
| pr_data_arr[ind] = precision, recall, threshold |
|
|
| precision_silero_vad, recall_silero_vad, FPR_silero_vad, FNR_silero_vad = get_precision_recall(vad_result_silero_vad_all, label_hop_512_all, threshold) |
| pr_data_silero_vad_arr[ind] = precision_silero_vad, recall_silero_vad, threshold |
|
|
| |
| print("Plotting PR Curve") |
| pr_data_arr_to_plot = pr_data_arr[:-1] |
| plt.plot( |
| pr_data_arr_to_plot[:, 1], |
| pr_data_arr_to_plot[:, 0], |
| color="red", |
| label="TEN VAD", |
| ) |
| pr_data_silero_vad_arr_to_plot = pr_data_silero_vad_arr[:-1] |
| plt.plot( |
| pr_data_silero_vad_arr_to_plot[:, 1], |
| pr_data_silero_vad_arr_to_plot[:, 0], |
| color="blue", |
| label="Silero VAD", |
| ) |
|
|
| plt.xlabel("Recall", fontsize=14, fontweight="bold", color="black") |
| plt.ylabel("Precision", fontsize=14, fontweight="bold", color="black") |
| legend = plt.legend() |
| legend.get_texts()[0].set_fontweight("bold") |
| legend.get_texts()[1].set_fontweight("bold") |
| plt.grid(True) |
| plt.xlim(0.65, 1) |
| plt.ylim(0.7, 1) |
| plt.title( |
| "Precision-Recall Curve of TEN VAD on TEN-VAD-TestSet", |
| fontsize=12, |
| color="black", |
| fontweight="bold", |
| ) |
| save_path = f"{script_dir}/PR_Curves.png" |
| plt.savefig(save_path, dpi=300, bbox_inches="tight") |
| print(f"PR Curves png file saved, save path: {save_path}") |
|
|
| |
| pr_data_save_path = f"{script_dir}/PR_data_TEN_VAD.txt" |
| with open(pr_data_save_path, "w") as f: |
| for ind in range(pr_data_arr.shape[0]): |
| precision, recall, threshold = ( |
| pr_data_arr[ind, 0], |
| pr_data_arr[ind, 1], |
| pr_data_arr[ind, 2], |
| ) |
| f.write(f"{threshold:.2f} {precision:.4f} {recall:.4f}\n") |
| print("Processing done!") |
|
|
|
|
|
|