xjsc0's picture
1
64ec292
import argparse
import os
import sys
from typing import Dict, Union
import yaml
from ml_collections import ConfigDict
from omegaconf import OmegaConf
def save_config(config: Union[ConfigDict, OmegaConf], save_path: str):
    """
    Save a configuration object (ConfigDict or OmegaConf) to a file.

    ``ConfigDict`` objects are converted with ``to_dict()`` and written with
    ``yaml.dump`` (insertion order kept, Unicode allowed); every other object
    is handed to ``OmegaConf.save``.

    Parameters:
    ----------
    config : Union[ConfigDict, OmegaConf]
        The configuration object to save.
    save_path : str
        The path where the configuration file will be saved. Missing parent
        directories are created.

    Raises:
    ------
    ValueError:
        If writing the configuration fails (for example because the
        configuration type is not supported). The original exception is
        chained as the cause.
    """
    parent_dir = os.path.dirname(save_path)
    # os.makedirs("") raises FileNotFoundError, so only create a directory
    # when save_path actually contains one (not for a bare "config.yaml").
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    try:
        if isinstance(config, ConfigDict):
            # Explicit UTF-8 so allow_unicode=True output does not depend on
            # the platform's default locale encoding.
            with open(save_path, "w", encoding="utf-8") as f:
                yaml.dump(
                    config.to_dict(),
                    f,
                    default_flow_style=False,
                    sort_keys=False,
                    allow_unicode=True,
                )
        else:
            # OmegaConf.save opens save_path itself; we must not hold a second
            # (truncating) handle on the same file while it writes.
            OmegaConf.save(config, save_path)
    except Exception as e:
        raise ValueError(
            f"Error saving configuration: {e}. "
            "Unsupported configuration type. Supported types: ConfigDict, OmegaConf. "
            f"Config type is {type(config)}"
        ) from e
def create_test_config(
    original_config_path: str, new_config_path: str, model_type: str
):
    """
    Write a shrunken copy of an existing configuration for test runs.

    The configuration at ``original_config_path`` is loaded through
    ``utils.settings.load_config`` for ``model_type``. Its inference and
    training batch sizes, gradient accumulation, epoch count and step count
    are then overridden with small values. The result is written to
    ``new_config_path`` with ``save_config``.

    Parameters:
    ----------
    original_config_path : str
        Path to the original configuration file.
    new_config_path : str
        Path where the new configuration file will be saved.
    model_type : str
        The type of model (e.g., 'scnet', 'htdemucs').

    Returns:
    -------
    None
    """
    # Imported lazily: ``utils`` is only importable once the project root is
    # on sys.path, which redact_config arranges before calling this function.
    from utils.settings import load_config

    cfg = load_config(model_type=model_type, config_path=original_config_path)
    cfg["inference"]["batch_size"] = 1

    # (key, value) pairs applied in order to the "training" section.
    training_overrides = (
        ("batch_size", 1),
        ("gradient_accumulation_steps", 1),
        ("num_epochs", 2),
        ("num_steps", 3),
    )
    for key, value in training_overrides:
        cfg["training"][key] = value

    save_config(cfg, new_config_path)
    print(f"Test config created at: {new_config_path}")
def _relative_parts(directory: str) -> list:
    """Split *directory* into path components that are safe to nest elsewhere.

    Drive letters, root separators, "." and ".." are dropped, so that
    ``os.path.join(base, *parts)`` always stays inside ``base``.
    """
    _, tail = os.path.splitdrive(os.path.normpath(directory))
    return [
        part
        for part in tail.split(os.sep)
        if part not in ("", os.curdir, os.pardir)
    ]


def parse_args(dict_args: Union[Dict, None] = None) -> argparse.Namespace:
    """
    Parse arguments describing which config to shrink and where to save it.

    Args:
        dict_args: Dict of argument values. When given, its entries override
            the parser defaults and sys.argv is ignored. If None, arguments
            are parsed from sys.argv.

    Returns:
        Namespace with ``orig_config``, ``model_type`` and ``new_config``
        (plus any extra keys supplied in ``dict_args``). When ``new_config``
        is empty it defaults to
        ``tests_cache/<directory of orig_config>/<file name of orig_config>``,
        and that directory is created.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--orig_config", type=str, default="", help="Path to the original config file."
    )
    parser.add_argument("--model_type", type=str, default="", help="Model type")
    parser.add_argument(
        "--new_config",
        type=str,
        default="",
        help="Path to save the new test configuration file.",
    )
    if dict_args is not None:
        args_dict = vars(parser.parse_args([]))
        args_dict.update(dict_args)
        args = argparse.Namespace(**args_dict)
    else:
        args = parser.parse_args()

    # Determine the default path for the new configuration if not provided.
    # The original directory is made relative first: joining an absolute path
    # onto "tests_cache" would discard "tests_cache" and make new_config point
    # at (and later overwrite) the original config file.
    if not args.new_config:
        original_dir = os.path.dirname(args.orig_config)
        tests_dir = os.path.join("tests_cache", *_relative_parts(original_dir))
        os.makedirs(tests_dir, exist_ok=True)
        args.new_config = os.path.join(tests_dir, os.path.basename(args.orig_config))
    return args
def redact_config(args: Union[Dict, None]) -> str:
    """
    Create a reduced test configuration and return the path it was saved to.

    Parameters:
    ----------
    args : Union[Dict, None]
        Argument overrides forwarded to ``parse_args``; None means the
        arguments are parsed from sys.argv.

    Returns:
    -------
    str
        Path of the written test configuration file.
    """
    # Make the parent of this file's directory importable so that
    # create_test_config can import ``utils.settings``. Only append it once:
    # repeated calls would otherwise keep adding duplicate sys.path entries.
    project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
    if project_root not in sys.path:
        sys.path.append(project_root)

    parsed = parse_args(args)
    # Create the test configuration
    create_test_config(parsed.orig_config, parsed.new_config, parsed.model_type)
    return parsed.new_config
if __name__ == "__main__":
    # Script entry point: no overrides, so arguments come from sys.argv.
    redact_config(args=None)