# NOTE: removed web-scrape artifacts (uploader avatar text and commit hash 64ec292)
# that were not valid Python and broke parsing of this file.
import argparse
import os
import sys
# Add the repository root to sys.path so top-level modules can be imported
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from inference import proc_folder
from scripts.redact_config import redact_config
from scripts.trim import trim_directory
from scripts.valid_to_inference import copying_files
from train import train_model
from valid import check_validation
# Default arguments shared by all test stages (train / validation / inference).
# Empty-string entries are required inputs: they must be supplied either on the
# command line or by editing this dict directly before calling test_settings().
# NOTE(review): test_settings() mutates this dict in place.
base_args = {
    "device_ids": "0",  # GPU device id(s) as a string, e.g. "0" or "0,1"
    "model_type": "",  # required: model architecture name
    "start_check_point": "",  # required: path to an initial checkpoint
    "config_path": "",  # required: path to the model configuration file
    "data_path": "",  # required: training dataset root
    "valid_path": "",  # required: validation dataset root
    "results_path": "tests/train_results",  # where training artifacts are written
    "store_dir": "tests/valid_inference_result",  # where valid/inference results go
    "input_folder": "",  # inference input; auto-derived from valid_path if empty
    # Metrics computed during validation (passed through to check_validation).
    "metrics": [
        "neg_log_wmse",
        "l1_freq",
        "si_sdr",
        "sdr",
        "aura_stft",
        "aura_mrstft",
        "bleedless",
        "fullness",
    ],
    # Cap on dataset folders kept by trim_directory, to keep test runs fast.
    "max_folders": 2,
}
def parse_args(dict_args):
    """Build the CLI argument parser and return the parsed arguments.

    Parameters
    ----------
    dict_args : dict | None
        If a dict is given, no command line is read: the parser defaults are
        taken (``parse_args([])``) and then overridden by the dict's
        key/value pairs.  If ``None``, arguments are read from ``sys.argv``.

    Returns
    -------
    argparse.Namespace
        The merged argument namespace.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--check_train", action="store_true", help="Check train or not")
    # FIX: the two help strings below were copy-pasted as "Check train or not".
    parser.add_argument(
        "--check_valid", action="store_true", help="Check validation or not"
    )
    parser.add_argument(
        "--check_inference", action="store_true", help="Check inference or not"
    )
    parser.add_argument(
        "--device_ids", type=str, help="Device IDs for training/inference"
    )
    parser.add_argument("--model_type", type=str, help="Model type")
    parser.add_argument(
        "--start_check_point", type=str, help="Path to the checkpoint to start from"
    )
    parser.add_argument(
        "--config_path", type=str, help="Path to the configuration file"
    )
    parser.add_argument("--data_path", type=str, help="Path to the training data")
    parser.add_argument("--valid_path", type=str, help="Path to the validation data")
    parser.add_argument(
        "--results_path", type=str, help="Path to save training results"
    )
    parser.add_argument(
        "--store_dir", type=str, help="Path to store validation/inference results"
    )
    parser.add_argument(
        "--input_folder", type=str, help="Path to the input folder for inference"
    )
    parser.add_argument("--metrics", nargs="+", help="List of metrics to evaluate")
    # FIX: was type=str — base_args defaults this to the int 2 and the value is
    # used as a folder count by trim_directory, so parse it as an int.
    parser.add_argument(
        "--max_folders", type=int, help="Maximum number of folders to process"
    )
    parser.add_argument(
        "--dataset_type",
        type=int,
        default=1,
        help="Dataset type. Must be one of: 1, 2, 3 or 4.",
    )
    parser.add_argument(
        "--num_workers", type=int, default=0, help="dataloader num_workers"
    )
    parser.add_argument(
        "--pin_memory", action="store_true", help="dataloader pin_memory"
    )
    parser.add_argument("--seed", type=int, default=0, help="random seed")
    parser.add_argument(
        "--use_multistft_loss",
        action="store_true",
        help="Use MultiSTFT Loss (from auraloss package)",
    )
    parser.add_argument(
        "--use_mse_loss", action="store_true", help="Use default MSE loss"
    )
    parser.add_argument("--use_l1_loss", action="store_true", help="Use L1 loss")
    parser.add_argument("--wandb_key", type=str, default="", help="wandb API Key")
    parser.add_argument(
        "--pre_valid", action="store_true", help="Run validation before training"
    )
    parser.add_argument(
        "--metric_for_scheduler",
        default="sdr",
        choices=[
            "sdr",
            "l1_freq",
            "si_sdr",
            "neg_log_wmse",
            "aura_stft",
            "aura_mrstft",
            "bleedless",
            "fullness",
        ],
        help="Metric which will be used for scheduler.",
    )
    parser.add_argument("--train_lora", action="store_true", help="Train with LoRA")
    parser.add_argument(
        "--lora_checkpoint",
        type=str,
        default="",
        help="Initial checkpoint to LoRA weights",
    )
    parser.add_argument(
        "--extension", type=str, default="wav", help="Choose extension for validation"
    )
    parser.add_argument(
        "--use_tta",
        action="store_true",
        help="Flag adds test time augmentation during inference (polarity and channel inverse)."
        " While this triples the runtime, it reduces noise and slightly improves prediction quality.",
    )
    parser.add_argument(
        "--extract_instrumental",
        action="store_true",
        help="invert vocals to get instrumental if provided",
    )
    parser.add_argument(
        "--disable_detailed_pbar",
        action="store_true",
        help="disable detailed progress bar",
    )
    parser.add_argument(
        "--force_cpu",
        action="store_true",
        help="Force the use of CPU even if CUDA is available",
    )
    parser.add_argument(
        "--flac_file", action="store_true", help="Output flac file instead of wav"
    )
    parser.add_argument(
        "--pcm_type",
        type=str,
        choices=["PCM_16", "PCM_24"],
        default="PCM_24",
        help="PCM type for FLAC files (PCM_16 or PCM_24)",
    )
    parser.add_argument(
        "--draw_spectro",
        type=float,
        default=0,
        # FIX: typo "os track" -> "of track".
        help="If --store_dir is set then code will generate spectrograms for resulted stems as well."
        " Value defines for how many seconds of track spectrogram will be generated.",
    )
    if dict_args is not None:
        # Programmatic call: start from parser defaults, then overlay the dict.
        args = parser.parse_args([])
        args_dict = vars(args)
        args_dict.update(dict_args)
        args = argparse.Namespace(**args_dict)
    else:
        args = parser.parse_args()
    return args
def test_settings(dict_args, test_type):
    """Run the requested pipeline stages (validation / inference / training).

    Merges CLI (or programmatic) arguments into the module-level ``base_args``
    dict, rewrites the config and trims both datasets for fast testing, then
    dispatches to ``check_validation``, ``proc_folder`` and/or ``train_model``
    depending on the ``check_valid`` / ``check_inference`` / ``check_train``
    flags.

    Parameters
    ----------
    dict_args : dict | None
        Forwarded to parse_args(); ``None`` means "read from sys.argv".
    test_type : str
        When ``"user"``, the required path arguments are validated up front.

    Raises
    ------
    ValueError
        If ``test_type == "user"`` and any required argument is missing.

    NOTE(review): mutates the global ``base_args`` in place, so a second call
    in the same process starts from the already-mutated values — confirm this
    is intended.
    """
    # Parse from cmd
    cli_args = parse_args(dict_args)
    # If args from cmd, add or replace in base_args
    for key, value in vars(cli_args).items():
        if value is not None:
            base_args[key] = value
    if test_type == "user":
        # Check required arguments
        missing_args = [
            arg
            for arg in [
                "model_type",
                "config_path",
                "start_check_point",
                "data_path",
                "valid_path",
            ]
            if not base_args[arg]
        ]
        if missing_args:
            missing_args_str = ", ".join(f"--{arg}" for arg in missing_args)
            raise ValueError(
                f"The following arguments are required but missing: {missing_args_str}."
                f" Please specify them either via command-line arguments or directly in `base_args`."
            )
    # Replace config — redact_config presumably writes a test-adjusted copy of
    # the config and returns its path; TODO confirm against scripts/redact_config.
    base_args["config_path"] = redact_config(
        {
            "orig_config": base_args["config_path"],
            "model_type": base_args["model_type"],
            "new_config": "",
        }
    )
    # Trim train dataset to at most `max_folders` folders (returns the new path).
    trim_args_train = {
        "input_directory": base_args["data_path"],
        "max_folders": base_args["max_folders"],
    }
    base_args["data_path"] = trim_directory(trim_args_train)
    # Trim valid dataset the same way.
    trim_args_valid = {
        "input_directory": base_args["valid_path"],
        "max_folders": base_args["max_folders"],
    }
    base_args["valid_path"] = trim_directory(trim_args_valid)
    # Valid to inference: if no explicit input folder was given, derive a
    # sibling "for_inference" directory and copy one mixture from the
    # validation set into it.
    # NOTE(review): indentation of the original paste was lost — copying_files
    # is assumed to run only when input_folder was NOT supplied; confirm.
    if not base_args["input_folder"]:
        tests_dir = os.path.join(
            os.path.dirname(base_args["valid_path"]), "for_inference"
        )
        base_args["input_folder"] = tests_dir
        val_to_inf_args = {
            "valid_path": base_args["valid_path"],
            "inference_dir": base_args["input_folder"],
            "max_mixtures": 1,
        }
        copying_files(val_to_inf_args)
    if base_args["check_valid"]:
        # Subset of base_args understood by check_validation().
        valid_args = {
            key: base_args[key]
            for key in [
                "model_type",
                "config_path",
                "start_check_point",
                "store_dir",
                "device_ids",
                "num_workers",
                "pin_memory",
                "extension",
                "use_tta",
                "metrics",
                "lora_checkpoint",
                "draw_spectro",
            ]
        }
        # check_validation expects a list of validation paths.
        valid_args["valid_path"] = [base_args["valid_path"]]
        print("Start validation.")
        check_validation(valid_args)
        print(f"Validation ended. See results in {base_args['store_dir']}")
    if base_args["check_inference"]:
        # Subset of base_args understood by proc_folder().
        inference_args = {
            key: base_args[key]
            for key in [
                "model_type",
                "config_path",
                "start_check_point",
                "input_folder",
                "store_dir",
                "device_ids",
                "extract_instrumental",
                "disable_detailed_pbar",
                "force_cpu",
                "flac_file",
                "pcm_type",
                "use_tta",
                "lora_checkpoint",
                "draw_spectro",
            ]
        }
        print("Start inference.")
        proc_folder(inference_args)
        print(f"Inference ended. See results in {base_args['store_dir']}")
    if base_args["check_train"]:
        # Subset of base_args understood by train_model().
        train_args = {
            key: base_args[key]
            for key in [
                "model_type",
                "config_path",
                "start_check_point",
                "results_path",
                "data_path",
                "dataset_type",
                "valid_path",
                "num_workers",
                "pin_memory",
                "seed",
                "device_ids",
                "use_multistft_loss",
                "use_mse_loss",
                "use_l1_loss",
                "wandb_key",
                "pre_valid",
                "metrics",
                "metric_for_scheduler",
                "train_lora",
                "lora_checkpoint",
            ]
        }
        print("Start train.")
        train_model(train_args)
    print("End!")
if __name__ == "__main__":
    # Script entry point: read arguments from sys.argv and validate that the
    # required user-supplied paths are present ("user" mode).
    test_settings(None, "user")