import argparse import os import sys # Добавляем корень репозитория в системный путь sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) from inference import proc_folder from scripts.redact_config import redact_config from scripts.trim import trim_directory from scripts.valid_to_inference import copying_files from train import train_model from valid import check_validation base_args = { "device_ids": "0", "model_type": "", "start_check_point": "", "config_path": "", "data_path": "", "valid_path": "", "results_path": "tests/train_results", "store_dir": "tests/valid_inference_result", "input_folder": "", "metrics": [ "neg_log_wmse", "l1_freq", "si_sdr", "sdr", "aura_stft", "aura_mrstft", "bleedless", "fullness", ], "max_folders": 2, } def parse_args(dict_args): parser = argparse.ArgumentParser() parser.add_argument("--check_train", action="store_true", help="Check train or not") parser.add_argument("--check_valid", action="store_true", help="Check train or not") parser.add_argument( "--check_inference", action="store_true", help="Check train or not" ) parser.add_argument( "--device_ids", type=str, help="Device IDs for training/inference" ) parser.add_argument("--model_type", type=str, help="Model type") parser.add_argument( "--start_check_point", type=str, help="Path to the checkpoint to start from" ) parser.add_argument( "--config_path", type=str, help="Path to the configuration file" ) parser.add_argument("--data_path", type=str, help="Path to the training data") parser.add_argument("--valid_path", type=str, help="Path to the validation data") parser.add_argument( "--results_path", type=str, help="Path to save training results" ) parser.add_argument( "--store_dir", type=str, help="Path to store validation/inference results" ) parser.add_argument( "--input_folder", type=str, help="Path to the input folder for inference" ) parser.add_argument("--metrics", nargs="+", help="List of metrics to evaluate") parser.add_argument( "--max_folders", type=str, help="Maximum number of folders to process" ) parser.add_argument( "--dataset_type", type=int, default=1, help="Dataset type. Must be one of: 1, 2, 3 or 4.", ) parser.add_argument( "--num_workers", type=int, default=0, help="dataloader num_workers" ) parser.add_argument( "--pin_memory", action="store_true", help="dataloader pin_memory" ) parser.add_argument("--seed", type=int, default=0, help="random seed") parser.add_argument( "--use_multistft_loss", action="store_true", help="Use MultiSTFT Loss (from auraloss package)", ) parser.add_argument( "--use_mse_loss", action="store_true", help="Use default MSE loss" ) parser.add_argument("--use_l1_loss", action="store_true", help="Use L1 loss") parser.add_argument("--wandb_key", type=str, default="", help="wandb API Key") parser.add_argument( "--pre_valid", action="store_true", help="Run validation before training" ) parser.add_argument( "--metric_for_scheduler", default="sdr", choices=[ "sdr", "l1_freq", "si_sdr", "neg_log_wmse", "aura_stft", "aura_mrstft", "bleedless", "fullness", ], help="Metric which will be used for scheduler.", ) parser.add_argument("--train_lora", action="store_true", help="Train with LoRA") parser.add_argument( "--lora_checkpoint", type=str, default="", help="Initial checkpoint to LoRA weights", ) parser.add_argument( "--extension", type=str, default="wav", help="Choose extension for validation" ) parser.add_argument( "--use_tta", action="store_true", help="Flag adds test time augmentation during inference (polarity and channel inverse)." " While this triples the runtime, it reduces noise and slightly improves prediction quality.", ) parser.add_argument( "--extract_instrumental", action="store_true", help="invert vocals to get instrumental if provided", ) parser.add_argument( "--disable_detailed_pbar", action="store_true", help="disable detailed progress bar", ) parser.add_argument( "--force_cpu", action="store_true", help="Force the use of CPU even if CUDA is available", ) parser.add_argument( "--flac_file", action="store_true", help="Output flac file instead of wav" ) parser.add_argument( "--pcm_type", type=str, choices=["PCM_16", "PCM_24"], default="PCM_24", help="PCM type for FLAC files (PCM_16 or PCM_24)", ) parser.add_argument( "--draw_spectro", type=float, default=0, help="If --store_dir is set then code will generate spectrograms for resulted stems as well." " Value defines for how many seconds os track spectrogram will be generated.", ) if dict_args is not None: args = parser.parse_args([]) args_dict = vars(args) args_dict.update(dict_args) args = argparse.Namespace(**args_dict) else: args = parser.parse_args() return args def test_settings(dict_args, test_type): # Parse from cmd cli_args = parse_args(dict_args) # If args from cmd, add or replace in base_args for key, value in vars(cli_args).items(): if value is not None: base_args[key] = value if test_type == "user": # Check required arguments missing_args = [ arg for arg in [ "model_type", "config_path", "start_check_point", "data_path", "valid_path", ] if not base_args[arg] ] if missing_args: missing_args_str = ", ".join(f"--{arg}" for arg in missing_args) raise ValueError( f"The following arguments are required but missing: {missing_args_str}." f" Please specify them either via command-line arguments or directly in `base_args`." ) # Replace config base_args["config_path"] = redact_config( { "orig_config": base_args["config_path"], "model_type": base_args["model_type"], "new_config": "", } ) # Trim train trim_args_train = { "input_directory": base_args["data_path"], "max_folders": base_args["max_folders"], } base_args["data_path"] = trim_directory(trim_args_train) # Trim valid trim_args_valid = { "input_directory": base_args["valid_path"], "max_folders": base_args["max_folders"], } base_args["valid_path"] = trim_directory(trim_args_valid) # Valid to inference if not base_args["input_folder"]: tests_dir = os.path.join( os.path.dirname(base_args["valid_path"]), "for_inference" ) base_args["input_folder"] = tests_dir val_to_inf_args = { "valid_path": base_args["valid_path"], "inference_dir": base_args["input_folder"], "max_mixtures": 1, } copying_files(val_to_inf_args) if base_args["check_valid"]: valid_args = { key: base_args[key] for key in [ "model_type", "config_path", "start_check_point", "store_dir", "device_ids", "num_workers", "pin_memory", "extension", "use_tta", "metrics", "lora_checkpoint", "draw_spectro", ] } valid_args["valid_path"] = [base_args["valid_path"]] print("Start validation.") check_validation(valid_args) print(f"Validation ended. See results in {base_args['store_dir']}") if base_args["check_inference"]: inference_args = { key: base_args[key] for key in [ "model_type", "config_path", "start_check_point", "input_folder", "store_dir", "device_ids", "extract_instrumental", "disable_detailed_pbar", "force_cpu", "flac_file", "pcm_type", "use_tta", "lora_checkpoint", "draw_spectro", ] } print("Start inference.") proc_folder(inference_args) print(f"Inference ended. See results in {base_args['store_dir']}") if base_args["check_train"]: train_args = { key: base_args[key] for key in [ "model_type", "config_path", "start_check_point", "results_path", "data_path", "dataset_type", "valid_path", "num_workers", "pin_memory", "seed", "device_ids", "use_multistft_loss", "use_mse_loss", "use_l1_loss", "wandb_key", "pre_valid", "metrics", "metric_for_scheduler", "train_lora", "lora_checkpoint", ] } print("Start train.") train_model(train_args) print("End!") if __name__ == "__main__": test_settings(None, "user")