| import json
|
| from collections import defaultdict
|
| from random import shuffle
|
| from typing import Optional
|
|
|
| from tqdm import tqdm
|
| import click
|
| from text.cleaner import clean_text_bert
|
| import os
|
| import torch
|
| from text.symbols import symbols, num_languages, num_tones
|
|
|
@click.command()
@click.option(
    "--metadata",
    default="data/example/metadata.list",
    type=click.Path(exists=True, file_okay=True, dir_okay=False),
)
@click.option("--cleaned-path", default=None)
@click.option("--train-path", default=None)
@click.option("--val-path", default=None)
@click.option(
    "--config_path",
    default="configs/config.json",
    type=click.Path(exists=True, file_okay=True, dir_okay=False),
)
@click.option("--val-per-spk", default=4)
@click.option("--max-val-total", default=8)
@click.option("--clean/--no-clean", default=True)
def main(
    metadata: str,
    cleaned_path: Optional[str],
    train_path: Optional[str],
    val_path: Optional[str],
    config_path: str,
    val_per_spk: int,
    max_val_total: int,
    clean: bool,
):
    """Preprocess a metadata list for TTS training.

    Optionally cleans/phonemizes each `utt|spk|language|text` line (caching
    per-utterance BERT features next to the audio), then splits the cleaned
    list into train/val sets per speaker and writes an updated config.json
    next to the metadata file.
    """
    # Derive default output paths from the metadata file's directory.
    if train_path is None:
        train_path = os.path.join(os.path.dirname(metadata), "train.list")
    if val_path is None:
        val_path = os.path.join(os.path.dirname(metadata), "val.list")
    out_config_path = os.path.join(os.path.dirname(metadata), "config.json")

    if cleaned_path is None:
        cleaned_path = metadata + ".cleaned"

    if clean:
        # Symbols found in the data but missing from text.symbols.  The list
        # is shared across ALL languages; the per-language dump below writes
        # a snapshot of it (preserved from the original behavior).
        new_symbols = []
        # Context managers make the output file exception-safe and close the
        # metadata handle, which the original left dangling.
        with open(metadata, encoding="utf-8") as in_file, open(
            cleaned_path, "w", encoding="utf-8"
        ) as out_file:
            # readlines() up front so tqdm can show a total.
            for line in tqdm(in_file.readlines()):
                try:
                    utt, spk, language, text = line.strip().split("|")
                    # NOTE(review): device is hard-coded; assumes a CUDA GPU
                    # at cuda:0 is available — confirm before running on CPU.
                    norm_text, phones, tones, word2ph, bert = clean_text_bert(
                        text, language, device="cuda:0"
                    )
                    for ph in phones:
                        if ph not in symbols and ph not in new_symbols:
                            new_symbols.append(ph)
                            print("update!, now symbols:")
                            print(new_symbols)
                            with open(f"{language}_symbol.txt", "w") as f:
                                f.write(f"{new_symbols}")

                    # Sanity checks: one tone per phone, and word2ph must
                    # account for every phone. (Stripped under `python -O`;
                    # failures are caught and reported by the except below.)
                    assert len(phones) == len(tones)
                    assert len(phones) == sum(word2ph)
                    out_file.write(
                        "{}|{}|{}|{}|{}|{}|{}\n".format(
                            utt,
                            spk,
                            language,
                            norm_text,
                            " ".join(phones),
                            " ".join(str(i) for i in tones),
                            " ".join(str(i) for i in word2ph),
                        )
                    )
                    # Cache BERT features alongside the audio file.
                    bert_path = utt.replace(".wav", ".bert.pt")
                    os.makedirs(os.path.dirname(bert_path), exist_ok=True)
                    torch.save(bert.cpu(), bert_path)
                except Exception as error:
                    # Best-effort cleaning: report the bad line and continue.
                    print("err!", line, error)

        metadata = cleaned_path

    # Group cleaned lines by speaker and assign dense integer speaker ids
    # in first-seen order.
    spk_utt_map = defaultdict(list)
    spk_id_map = {}
    current_sid = 0
    with open(metadata, encoding="utf-8") as f:
        for line in f:
            utt, spk, language, text, phones, tones, word2ph = line.strip().split("|")
            spk_utt_map[spk].append(line)
            if spk not in spk_id_map:
                spk_id_map[spk] = current_sid
                current_sid += 1

    # Per-speaker split: the first `val_per_spk` shuffled utterances of each
    # speaker go to validation, the rest to training.
    train_list = []
    val_list = []
    for spk, utts in spk_utt_map.items():
        shuffle(utts)
        val_list += utts[:val_per_spk]
        train_list += utts[val_per_spk:]

    # Cap the total validation size; overflow goes back to training.
    if len(val_list) > max_val_total:
        train_list += val_list[max_val_total:]
        val_list = val_list[:max_val_total]

    with open(train_path, "w", encoding="utf-8") as f:
        f.writelines(train_list)

    with open(val_path, "w", encoding="utf-8") as f:
        f.writelines(val_list)

    # Update the training config with dataset-derived fields and write it
    # next to the metadata file (the original leaked the config handle).
    with open(config_path, encoding="utf-8") as f:
        config = json.load(f)
    config["data"]["spk2id"] = spk_id_map
    config["data"]["training_files"] = train_path
    config["data"]["validation_files"] = val_path
    config["data"]["n_speakers"] = len(spk_id_map)
    config["num_languages"] = num_languages
    config["num_tones"] = num_tones
    config["symbols"] = symbols

    with open(out_config_path, "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2, ensure_ascii=False)
|
|
|
|
|
# Script entry point: click parses command-line arguments and invokes main().
if __name__ == "__main__":
    main()
|
|
|