Spaces:
Running on Zero
Running on Zero
File size: 4,081 Bytes
64ec292 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 | import argparse
import os
import sys
from typing import Dict, Union
import yaml
from ml_collections import ConfigDict
from omegaconf import OmegaConf
def save_config(config: Union[ConfigDict, OmegaConf], save_path: str):
    """
    Save a configuration object (ConfigDict or OmegaConf) to a YAML file.

    ``ConfigDict`` objects are converted with ``to_dict()`` and dumped with
    PyYAML (block style, key order preserved, non-ASCII characters kept).
    Every other object is serialized with ``OmegaConf.to_yaml``.

    Parameters:
    ----------
    config : Union[ConfigDict, OmegaConf]
        The configuration object to save.
    save_path : str
        The path where the configuration file will be saved. Missing parent
        directories are created.

    Raises:
    ------
    ValueError:
        If the configuration cannot be serialized (e.g. unsupported type) or
        the file cannot be written. The underlying exception is chained.
    """
    save_dir = os.path.dirname(save_path)
    # A bare filename has no directory part; os.makedirs("") would raise.
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    try:
        # Serialize before opening the file so that a serialization failure
        # does not leave a truncated/empty file at save_path.
        if isinstance(config, ConfigDict):
            text = yaml.dump(
                config.to_dict(),
                default_flow_style=False,
                sort_keys=False,
                allow_unicode=True,
            )
        else:
            # OmegaConf configs are DictConfig/ListConfig instances, never
            # instances of the OmegaConf helper class itself, so everything
            # that is not a ConfigDict is handed to OmegaConf.
            text = OmegaConf.to_yaml(config)
        # Explicit UTF-8: allow_unicode=True may emit non-ASCII characters,
        # which the platform default encoding might not support.
        with open(save_path, "w", encoding="utf-8") as f:
            f.write(text)
    except Exception as e:
        raise ValueError(
            f"Error saving configuration to '{save_path}': {e}. "
            f"Supported types: ConfigDict, OmegaConf. "
            f"Config type is {type(config)}"
        ) from e
def create_test_config(
    original_config_path: str, new_config_path: str, model_type: str
):
    """
    Write a lightweight copy of an existing configuration for quick test runs.

    The original configuration is loaded with ``utils.settings.load_config``.
    These values are then overridden, in this order, before the result is
    saved:

    - inference.batch_size = 1
    - training.batch_size = 1
    - training.gradient_accumulation_steps = 1
    - training.num_epochs = 2
    - training.num_steps = 3

    Parameters:
    ----------
    original_config_path : str
        Path to the original configuration file.
    new_config_path : str
        Path where the new configuration file will be saved.
    model_type : str
        The type of model (e.g., 'scnet', 'htdemucs').

    Returns:
    -------
    None
    """
    from utils.settings import load_config

    loaded = load_config(model_type=model_type, config_path=original_config_path)
    # (section, key, value) triples applied in the listed order.
    overrides = (
        ("inference", "batch_size", 1),
        ("training", "batch_size", 1),
        ("training", "gradient_accumulation_steps", 1),
        ("training", "num_epochs", 2),
        ("training", "num_steps", 3),
    )
    for section, key, value in overrides:
        loaded[section][key] = value
    save_config(loaded, new_config_path)
    print(f"Test config created at: {new_config_path}")
def _test_cache_dir(orig_config: str) -> str:
    """Return the ``tests_cache`` subdirectory that mirrors ``orig_config``'s directory."""
    normalized = os.path.normpath(os.path.dirname(orig_config))
    # Strip the drive plus any root, "." and ".." components. os.path.join
    # discards "tests_cache" when the second argument is absolute, and ".."
    # can climb out of it; either way the default test config path could end
    # up pointing at (and overwriting) the original config file.
    _, tail = os.path.splitdrive(normalized)
    parts = [p for p in tail.split(os.sep) if p not in ("", os.curdir, os.pardir)]
    return os.path.join("tests_cache", *parts)


def parse_args(dict_args: Union[Dict, None] = None) -> argparse.Namespace:
    """
    Parse command-line arguments for configuring the model, dataset, and training parameters.

    Args:
        dict_args: Dict of command-line arguments. If None, arguments will be parsed from sys.argv.

    Returns:
        Namespace object containing parsed arguments and their values. If
        ``new_config`` is empty, it defaults to
        ``tests_cache/<dir of orig_config>/<basename of orig_config>``.
        That directory is created, and it always stays inside
        ``tests_cache`` even for absolute or ``..`` paths.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--orig_config", type=str, default="", help="Path to the original config file."
    )
    parser.add_argument("--model_type", type=str, default="", help="Model type")
    parser.add_argument(
        "--new_config",
        type=str,
        default="",
        help="Path to save the new test configuration file.",
    )
    if dict_args is not None:
        # Start from the parser defaults, then apply the provided overrides.
        args_dict = vars(parser.parse_args([]))
        args_dict.update(dict_args)
        args = argparse.Namespace(**args_dict)
    else:
        args = parser.parse_args()
    # Determine the default path for the new configuration if not provided
    if not args.new_config:
        tests_dir = _test_cache_dir(args.orig_config)
        os.makedirs(tests_dir, exist_ok=True)
        args.new_config = os.path.join(tests_dir, os.path.basename(args.orig_config))
    return args
def redact_config(args):
    """
    Create a reduced test copy of a configuration file and return its path.

    Parameters:
    ----------
    args : dict or None
        Overrides for the options accepted by ``parse_args`` (``orig_config``,
        ``model_type``, ``new_config``). If None, the options are parsed from
        ``sys.argv``.

    Returns:
    -------
    str
        Path where the test configuration was written.
    """
    # Make the parent of this file's directory importable; presumably this is
    # what lets create_test_config's ``from utils.settings import ...``
    # resolve. Only add it once so repeated calls do not keep growing
    # sys.path with duplicate entries.
    project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
    if project_root not in sys.path:
        sys.path.append(project_root)
    args = parse_args(args)
    # Create the test configuration
    create_test_config(args.orig_config, args.new_config, args.model_type)
    return args.new_config
if __name__ == "__main__":
    # Script entry point: no dict overrides, so options come from the command line.
    redact_config(args=None)
|