YingMusic-Singer / initialization.py
xjsc0's picture
11
99cf7e1
"""
YingMusic-Singer Initialization Script
Downloads required checkpoints from HuggingFace based on task type.
Usage:
python initialization.py --task infer
python initialization.py --task train
"""
import argparse
import os
from huggingface_hub import hf_hub_download
# HuggingFace repository that hosts the checkpoints.
REPO_ID = "ASLP-lab/YingMusic-Singer"

# Local directory that download_files() creates and checks for existing files.
CKPT_DIR = "ckpts"

# Repo-relative paths fetched for the "infer" task.
INFER_FILES = [
    "ckpts/MelBandRoformer.ckpt",
    "ckpts/config_vocals_mel_band_roformer_kj.yaml",
]

# Additional repo-relative paths fetched only for the "train" task.
TRAIN_EXTRA_FILES = [
    "ckpts/YingMusicSinger_model.pt",
    "ckpts/model_ckpt_steps_100000_simplified.ckpt",
    "ckpts/stable_audio_2_0_vae_20hz_official.ckpt",
]

# Task name -> list of files to download; "train" is the inference set
# followed by the training-only extras.
TASK_FILES = {
    "infer": INFER_FILES,
    "train": [*INFER_FILES, *TRAIN_EXTRA_FILES],
}
def download_files(task: str):
    """Fetch every checkpoint listed for ``task`` that is not already present.

    Creates ``CKPT_DIR`` if needed, then walks ``TASK_FILES[task]``. A file is
    skipped when ``CKPT_DIR/<basename>`` already exists; otherwise it is
    requested from ``REPO_ID`` via ``hf_hub_download`` with ``local_dir="."``.
    Progress is reported with ``print``.

    Args:
        task: Key into ``TASK_FILES`` selecting which file list to download.

    Raises:
        KeyError: If ``task`` is not a key of ``TASK_FILES``.

    NOTE(review): the skip check looks in ``CKPT_DIR`` while the download is
    issued with the repo-relative path and ``local_dir="."``. Presumably the
    file then lands at ``./<repo path>``, which matches ``CKPT_DIR/<basename>``
    only because every listed path lives under ``ckpts/`` -- confirm before
    changing ``CKPT_DIR`` or the path lists.
    """
    wanted = TASK_FILES[task]
    os.makedirs(CKPT_DIR, exist_ok=True)
    print(f"Task: {task} | Downloading {len(wanted)} file(s) to {CKPT_DIR}/")

    for repo_path in wanted:
        # Only the final path component is used for the local existence check.
        _, name = os.path.split(repo_path)

        if not os.path.exists(os.path.join(CKPT_DIR, name)):
            print(f" [download] {name} ...")
            hf_hub_download(repo_id=REPO_ID, filename=repo_path, local_dir=".")
            print(f" [done] {name}")
        else:
            print(f" [skip] {name} already exists")

    print("All downloads complete.")
if __name__ == "__main__":
    # Command-line entry point: --task is mandatory and restricted to the
    # keys of TASK_FILES; the chosen value is handed to download_files().
    cli = argparse.ArgumentParser(description="Download YingMusic-Singer checkpoints")
    task_choices = list(TASK_FILES)
    cli.add_argument(
        "--task",
        type=str,
        required=True,
        choices=task_choices,
        help="Task type: 'infer' for inference, 'train' for training",
    )
    opts = cli.parse_args()
    download_files(opts.task)