import argparse
import os
import sys
from typing import Dict, Union

import yaml
from ml_collections import ConfigDict
from omegaconf import OmegaConf


def save_config(config: Union[ConfigDict, OmegaConf], save_path: str):
    """
    Save a configuration object (ConfigDict or OmegaConf config) to a YAML file.

    Parameters:
    ----------
    config : Union[ConfigDict, OmegaConf]
        The configuration object to save. ConfigDict instances are converted
        with ``to_dict()`` and dumped through PyYAML; anything else is handed
        to ``OmegaConf.save``.
    save_path : str
        The path where the configuration file will be saved. Missing parent
        directories are created.

    Raises:
    ------
    ValueError:
        If writing or serializing the configuration fails (for example
        because the configuration type is not supported). The underlying
        exception is attached as ``__cause__``.
    """
    # dirname is "" for a bare file name, and os.makedirs("") raises
    # FileNotFoundError, so only create the directory when there is one.
    parent_dir = os.path.dirname(save_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    try:
        if isinstance(config, ConfigDict):
            # UTF-8 is set explicitly because allow_unicode=True emits
            # non-ASCII characters verbatim.
            with open(save_path, "w", encoding="utf-8") as f:
                yaml.dump(
                    config.to_dict(),
                    f,
                    default_flow_style=False,
                    sort_keys=False,
                    allow_unicode=True,
                )
        else:
            # OmegaConf.save opens the target itself; no extra handle on the
            # same path is kept open while it writes.
            OmegaConf.save(config, save_path)
    except Exception as e:
        raise ValueError(
            f"Error saving configuration to {save_path!r}: {e}. "
            f"Supported configuration types: ConfigDict, OmegaConf. "
            f"Config type is {type(config)}."
        ) from e


def create_test_config(
    original_config_path: str, new_config_path: str, model_type: str
):
    """
    Create a test configuration file based on an existing configuration.

    The original configuration is loaded, its batch sizes and gradient
    accumulation are set to 1, the number of epochs to 2 and the number of
    steps to 3, and the result is written to ``new_config_path``.

    Parameters:
    ----------
    original_config_path : str
        Path to the original configuration file.
    new_config_path : str
        Path where the new configuration file will be saved.
    model_type : str
        The type of model (e.g., 'scnet', 'htdemucs').

    Returns:
    -------
    None
    """
    # Imported lazily: the module is resolved through sys.path, which
    # redact_config extends before calling this function.
    from utils.settings import load_config

    config = load_config(model_type=model_type, config_path=original_config_path)

    # Shrink the run so a test finishes quickly.
    config["inference"]["batch_size"] = 1
    config["training"]["batch_size"] = 1
    config["training"]["gradient_accumulation_steps"] = 1
    config["training"]["num_epochs"] = 2
    config["training"]["num_steps"] = 3

    save_config(config, new_config_path)
    print(f"Test config created at: {new_config_path}")


def parse_args(dict_args: Union[Dict, None]) -> argparse.Namespace:
    """
    Parse the arguments that select the source config, model type and output path.

    Parameters:
    ----------
    dict_args : Union[Dict, None]
        Argument values keyed by name. They are laid over the parser
        defaults. If None, arguments are parsed from sys.argv.

    Returns:
    -------
    argparse.Namespace
        Parsed arguments. When ``new_config`` is empty it is filled with
        ``tests_cache/<directory of orig_config>/<file name of orig_config>``
        and that directory is created.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--orig_config", type=str, default="", help="Path to the original config file."
    )
    parser.add_argument("--model_type", type=str, default="", help="Model type")
    parser.add_argument(
        "--new_config",
        type=str,
        default="",
        help="Path to save the new test configuration file.",
    )

    if dict_args is not None:
        # Start from the parser defaults, then let the caller's values win.
        merged = dict(vars(parser.parse_args([])))
        merged.update(dict_args)
        args = argparse.Namespace(**merged)
    else:
        args = parser.parse_args()

    # Determine the default path for the new configuration if not provided
    if not args.new_config:
        original_dir = os.path.dirname(args.orig_config)
        # os.path.join drops "tests_cache" when the next component is
        # absolute, which would make the output path equal to the original
        # config and overwrite it. Strip the drive and leading separators so
        # the result always lives under tests_cache.
        separators = os.sep + (os.altsep or "")
        relative_dir = os.path.splitdrive(original_dir)[1].lstrip(separators)
        tests_dir = os.path.join("tests_cache", relative_dir)
        os.makedirs(tests_dir, exist_ok=True)
        args.new_config = os.path.join(tests_dir, os.path.basename(args.orig_config))

    return args


def redact_config(args):
    """
    Build a test configuration from the given arguments.

    Parameters:
    ----------
    args : Union[Dict, None]
        Forwarded to ``parse_args``; None means "read sys.argv".

    Returns:
    -------
    str
        Path of the test configuration file that was written.
    """
    # Ensure proper imports for utilities: make the parent of this file's
    # directory importable, without adding a duplicate entry on every call.
    project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
    if project_root not in sys.path:
        sys.path.append(project_root)

    args = parse_args(args)

    # Create the test configuration
    create_test_config(args.orig_config, args.new_config, args.model_type)

    return args.new_config


if __name__ == "__main__":
    redact_config(None)