Spaces:
Runtime error
Runtime error
| import argparse | |
| import os | |
| import sys | |
| # Добавляем корень репозитория в системный путь | |
| sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) | |
| from inference import proc_folder | |
| from scripts.redact_config import redact_config | |
| from scripts.trim import trim_directory | |
| from scripts.valid_to_inference import copying_files | |
| from train import train_model | |
| from valid import check_validation | |
# Default settings consumed by ``test_settings``; any non-None argument parsed
# from the command line (or supplied via ``dict_args``) takes precedence over
# the matching entry here.
base_args = dict(
    device_ids="0",
    model_type="",
    start_check_point="",
    config_path="",
    data_path="",
    valid_path="",
    results_path="tests/train_results",
    store_dir="tests/valid_inference_result",
    input_folder="",
    metrics=[
        "neg_log_wmse",
        "l1_freq",
        "si_sdr",
        "sdr",
        "aura_stft",
        "aura_mrstft",
        "bleedless",
        "fullness",
    ],
    max_folders=2,
)
def parse_args(dict_args):
    """Build the test-harness argument parser and resolve the arguments.

    Args:
        dict_args: Mapping of argument overrides. When not ``None`` the command
            line is ignored: the parser defaults are taken (as if no CLI
            arguments were given) and then updated with ``dict_args``; keys the
            parser does not know are kept as extra attributes. When ``None``,
            ``sys.argv`` is parsed.

    Returns:
        argparse.Namespace with the resolved arguments. Options declared
        without a default (e.g. ``--model_type``) are ``None`` unless supplied.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--check_train", action="store_true", help="Check train or not")
    parser.add_argument(
        "--check_valid", action="store_true", help="Check validation or not"
    )
    parser.add_argument(
        "--check_inference", action="store_true", help="Check inference or not"
    )
    parser.add_argument(
        "--device_ids", type=str, help="Device IDs for training/inference"
    )
    parser.add_argument("--model_type", type=str, help="Model type")
    parser.add_argument(
        "--start_check_point", type=str, help="Path to the checkpoint to start from"
    )
    parser.add_argument(
        "--config_path", type=str, help="Path to the configuration file"
    )
    parser.add_argument("--data_path", type=str, help="Path to the training data")
    parser.add_argument("--valid_path", type=str, help="Path to the validation data")
    parser.add_argument(
        "--results_path", type=str, help="Path to save training results"
    )
    parser.add_argument(
        "--store_dir", type=str, help="Path to store validation/inference results"
    )
    parser.add_argument(
        "--input_folder", type=str, help="Path to the input folder for inference"
    )
    parser.add_argument("--metrics", nargs="+", help="List of metrics to evaluate")
    # Parsed as int so a CLI value has the same type as the integer default
    # for "max_folders" in base_args.
    parser.add_argument(
        "--max_folders", type=int, help="Maximum number of folders to process"
    )
    parser.add_argument(
        "--dataset_type",
        type=int,
        default=1,
        help="Dataset type. Must be one of: 1, 2, 3 or 4.",
    )
    parser.add_argument(
        "--num_workers", type=int, default=0, help="dataloader num_workers"
    )
    parser.add_argument(
        "--pin_memory", action="store_true", help="dataloader pin_memory"
    )
    parser.add_argument("--seed", type=int, default=0, help="random seed")
    parser.add_argument(
        "--use_multistft_loss",
        action="store_true",
        help="Use MultiSTFT Loss (from auraloss package)",
    )
    parser.add_argument(
        "--use_mse_loss", action="store_true", help="Use default MSE loss"
    )
    parser.add_argument("--use_l1_loss", action="store_true", help="Use L1 loss")
    parser.add_argument("--wandb_key", type=str, default="", help="wandb API Key")
    parser.add_argument(
        "--pre_valid", action="store_true", help="Run validation before training"
    )
    parser.add_argument(
        "--metric_for_scheduler",
        default="sdr",
        choices=[
            "sdr",
            "l1_freq",
            "si_sdr",
            "neg_log_wmse",
            "aura_stft",
            "aura_mrstft",
            "bleedless",
            "fullness",
        ],
        help="Metric which will be used for scheduler.",
    )
    parser.add_argument("--train_lora", action="store_true", help="Train with LoRA")
    parser.add_argument(
        "--lora_checkpoint",
        type=str,
        default="",
        help="Initial checkpoint to LoRA weights",
    )
    parser.add_argument(
        "--extension", type=str, default="wav", help="Choose extension for validation"
    )
    parser.add_argument(
        "--use_tta",
        action="store_true",
        help="Flag adds test time augmentation during inference (polarity and channel inverse)."
        " While this triples the runtime, it reduces noise and slightly improves prediction quality.",
    )
    parser.add_argument(
        "--extract_instrumental",
        action="store_true",
        help="invert vocals to get instrumental if provided",
    )
    parser.add_argument(
        "--disable_detailed_pbar",
        action="store_true",
        help="disable detailed progress bar",
    )
    parser.add_argument(
        "--force_cpu",
        action="store_true",
        help="Force the use of CPU even if CUDA is available",
    )
    parser.add_argument(
        "--flac_file", action="store_true", help="Output flac file instead of wav"
    )
    parser.add_argument(
        "--pcm_type",
        type=str,
        choices=["PCM_16", "PCM_24"],
        default="PCM_24",
        help="PCM type for FLAC files (PCM_16 or PCM_24)",
    )
    parser.add_argument(
        "--draw_spectro",
        type=float,
        default=0,
        help="If --store_dir is set then code will generate spectrograms for resulted stems as well."
        " Value defines for how many seconds of track spectrogram will be generated.",
    )
    if dict_args is not None:
        # Take the parser defaults without reading sys.argv, then overlay the
        # caller-supplied values.
        args_dict = vars(parser.parse_args([]))
        args_dict.update(dict_args)
        return argparse.Namespace(**args_dict)
    return parser.parse_args()
def _pick(settings, keys):
    # Return a new dict holding only ``keys`` from ``settings`` (KeyError if absent).
    return {key: settings[key] for key in keys}


def _check_required_args(settings):
    # Raise ValueError naming every required user-mode setting that is empty.
    required = (
        "model_type",
        "config_path",
        "start_check_point",
        "data_path",
        "valid_path",
    )
    missing_args = [arg for arg in required if not settings[arg]]
    if missing_args:
        missing_args_str = ", ".join(f"--{arg}" for arg in missing_args)
        raise ValueError(
            f"The following arguments are required but missing: {missing_args_str}."
            f" Please specify them either via command-line arguments or directly in `base_args`."
        )


def _prepare_test_data(settings):
    # Rewrite config/data/valid/input paths in ``settings`` in place and copy
    # validation mixtures into the inference input folder.
    settings["config_path"] = redact_config(
        {
            "orig_config": settings["config_path"],
            "model_type": settings["model_type"],
            "new_config": "",
        }
    )
    # --max_folders may arrive as a string from the command line; presumably
    # trim_directory expects an integer count (the base_args default is int).
    max_folders = int(settings["max_folders"])
    settings["data_path"] = trim_directory(
        {"input_directory": settings["data_path"], "max_folders": max_folders}
    )
    settings["valid_path"] = trim_directory(
        {"input_directory": settings["valid_path"], "max_folders": max_folders}
    )
    # Without an explicit inference input folder, use a sibling
    # "for_inference" directory next to the (trimmed) validation path.
    if not settings["input_folder"]:
        settings["input_folder"] = os.path.join(
            os.path.dirname(settings["valid_path"]), "for_inference"
        )
    copying_files(
        {
            "valid_path": settings["valid_path"],
            "inference_dir": settings["input_folder"],
            "max_mixtures": 1,
        }
    )


def test_settings(dict_args, test_type):
    """Prepare trimmed test data and run the requested valid/inference/train checks.

    Settings are the module-level ``base_args`` defaults overlaid with every
    non-None argument returned by ``parse_args(dict_args)``.

    Args:
        dict_args: Mapping of argument overrides, or ``None`` to parse the
            command line.
        test_type: When ``"user"``, model type, config path, checkpoint, data
            path and validation path must all be non-empty.

    Raises:
        ValueError: If ``test_type == "user"`` and a required setting is empty.
    """
    # Work on a per-call copy: the paths rewritten below (config, trimmed data
    # dirs, derived input_folder) must not leak into the shared base_args and
    # thereby into later calls of this function.
    settings = dict(base_args)
    cli_args = parse_args(dict_args)
    for key, value in vars(cli_args).items():
        if value is not None:
            settings[key] = value

    if test_type == "user":
        _check_required_args(settings)

    _prepare_test_data(settings)

    if settings["check_valid"]:
        valid_args = _pick(
            settings,
            [
                "model_type",
                "config_path",
                "start_check_point",
                "store_dir",
                "device_ids",
                "num_workers",
                "pin_memory",
                "extension",
                "use_tta",
                "metrics",
                "lora_checkpoint",
                "draw_spectro",
            ],
        )
        # check_validation receives the validation path wrapped in a list.
        valid_args["valid_path"] = [settings["valid_path"]]
        print("Start validation.")
        check_validation(valid_args)
        print(f"Validation ended. See results in {settings['store_dir']}")

    if settings["check_inference"]:
        inference_args = _pick(
            settings,
            [
                "model_type",
                "config_path",
                "start_check_point",
                "input_folder",
                "store_dir",
                "device_ids",
                "extract_instrumental",
                "disable_detailed_pbar",
                "force_cpu",
                "flac_file",
                "pcm_type",
                "use_tta",
                "lora_checkpoint",
                "draw_spectro",
            ],
        )
        print("Start inference.")
        proc_folder(inference_args)
        print(f"Inference ended. See results in {settings['store_dir']}")

    if settings["check_train"]:
        train_args = _pick(
            settings,
            [
                "model_type",
                "config_path",
                "start_check_point",
                "results_path",
                "data_path",
                "dataset_type",
                "valid_path",
                "num_workers",
                "pin_memory",
                "seed",
                "device_ids",
                "use_multistft_loss",
                "use_mse_loss",
                "use_l1_loss",
                "wandb_key",
                "pre_valid",
                "metrics",
                "metric_for_scheduler",
                "train_lora",
                "lora_checkpoint",
            ],
        )
        print("Start train.")
        train_model(train_args)
    print("End!")
if __name__ == "__main__":
    # Script entry point: read settings from the command line (no dict
    # overrides) and run in "user" mode, which enforces required arguments.
    test_settings(dict_args=None, test_type="user")