File size: 4,081 Bytes
64ec292
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import argparse
import os
import sys
from typing import Dict, Union

import yaml
from ml_collections import ConfigDict
from omegaconf import OmegaConf


def save_config(config: Union[ConfigDict, OmegaConf], save_path: str):
    """
    Save a configuration object (ConfigDict or OmegaConf) to a YAML file.

    ``ml_collections.ConfigDict`` objects are converted to a plain dict and
    written with ``yaml.dump`` (key order preserved, UTF-8 output). Any other
    object is handed to ``OmegaConf.save``. OmegaConf configs (``DictConfig`` /
    ``ListConfig``) are not instances of the ``OmegaConf`` helper class, so
    they are handled by that fallback path.

    Parameters:
    ----------
    config : Union[ConfigDict, OmegaConf]
        The configuration object to save.
    save_path : str
        The path where the configuration file will be saved. Its parent
        directory is created if needed; a bare file name saves into the
        current working directory.

    Raises:
    ------
    ValueError:
        If the configuration cannot be serialized or written (for example
        because its type is not supported). The underlying exception is
        chained as ``__cause__``.
    """
    # dirname() is "" for a bare file name, and os.makedirs("") raises.
    parent_dir = os.path.dirname(save_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    try:
        if isinstance(config, ConfigDict):
            # allow_unicode=True emits non-ASCII characters verbatim, so the
            # file must be opened with an explicit UTF-8 encoding.
            with open(save_path, "w", encoding="utf-8") as f:
                yaml.dump(
                    config.to_dict(),
                    f,
                    default_flow_style=False,
                    sort_keys=False,
                    allow_unicode=True,
                )
        else:
            # OmegaConf opens and writes the file itself; do not hold a second
            # handle on the same path.
            OmegaConf.save(config, save_path)
    except Exception as e:
        raise ValueError(
            f"Error saving configuration to {save_path!r}: {e}. "
            "Supported types: ConfigDict, OmegaConf. "
            f"Config type is {type(config)}"
        ) from e


def create_test_config(
    original_config_path: str, new_config_path: str, model_type: str
):
    """
    Build a lightweight test configuration from an existing one and save it.

    The original configuration is loaded with ``utils.settings.load_config``.
    A fixed set of inference/training values is then overwritten with small
    numbers, and the result is written to ``new_config_path`` via
    ``save_config``.

    Parameters:
    ----------
    original_config_path : str
        Path of the configuration file to start from.
    new_config_path : str
        Destination path for the reduced configuration.
    model_type : str
        Model type forwarded to ``load_config`` (e.g. 'scnet', 'htdemucs').

    Returns:
    -------
    None
    """
    # Imported lazily; presumably relies on the sys.path entry that
    # redact_config adds before calling this function.
    from utils.settings import load_config

    # (section, key) -> value, applied in insertion order.
    overrides = {
        ("inference", "batch_size"): 1,
        ("training", "batch_size"): 1,
        ("training", "gradient_accumulation_steps"): 1,
        ("training", "num_epochs"): 2,
        ("training", "num_steps"): 3,
    }

    cfg = load_config(model_type=model_type, config_path=original_config_path)
    for (section, key), value in overrides.items():
        cfg[section][key] = value

    save_config(cfg, new_config_path)
    print(f"Test config created at: {new_config_path}")


def parse_args(dict_args: Union[Dict, None]) -> argparse.Namespace:
    """
    Parse the arguments that control test-config generation.

    Args:
        dict_args: Mapping of argument names to values. When given, it is
            layered on top of the parser defaults and ``sys.argv`` is not
            read. When None, arguments are parsed from ``sys.argv``.

    Returns:
        Namespace with ``orig_config``, ``model_type`` and ``new_config``
        (plus any extra keys supplied in ``dict_args``). If ``new_config`` is
        empty, it defaults to
        ``tests_cache/<directory of orig_config>/<file name of orig_config>``
        relative to the current working directory, and that directory is
        created.
    """

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--orig_config", type=str, default="", help="Path to the original config file."
    )
    parser.add_argument("--model_type", type=str, default="", help="Model type")
    parser.add_argument(
        "--new_config",
        type=str,
        default="",
        help="Path to save the new test configuration file.",
    )

    if dict_args is not None:
        # Start from parser defaults, then overlay caller-supplied values.
        args = parser.parse_args([])
        args_dict = vars(args)
        args_dict.update(dict_args)
        args = argparse.Namespace(**args_dict)
    else:
        args = parser.parse_args()

    # Determine the default path for the new configuration if not provided
    if not args.new_config:
        original_dir = os.path.dirname(args.orig_config)
        # os.path.join discards every component before an absolute one. An
        # absolute orig_config would therefore yield new_config == orig_config
        # and the original file would be overwritten. Strip the drive and
        # leading separators so the copy always nests under tests_cache.
        _, original_dir = os.path.splitdrive(original_dir)
        original_dir = original_dir.lstrip(os.sep + (os.altsep or ""))
        tests_dir = os.path.join("tests_cache", original_dir)
        os.makedirs(tests_dir, exist_ok=True)
        args.new_config = os.path.join(tests_dir, os.path.basename(args.orig_config))

    return args


def redact_config(args: Union[Dict, None]) -> str:
    """
    Generate a reduced test configuration and return its path.

    Args:
        args: Optional mapping of argument overrides passed to ``parse_args``.
            When None, arguments are read from ``sys.argv``.

    Returns:
        Path of the newly written test configuration file.
    """
    # Make the parent of this file's directory importable (presumably the
    # repository root, so create_test_config can import ``utils.settings``).
    # Only add it once so repeated calls do not keep growing sys.path.
    parent_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
    if parent_root not in sys.path:
        sys.path.append(parent_root)

    parsed = parse_args(args)

    # Create the test configuration
    create_test_config(parsed.orig_config, parsed.new_config, parsed.model_type)
    return parsed.new_config


if __name__ == "__main__":
    # Script entry point: no dict overrides, so arguments come from sys.argv.
    redact_config(args=None)