import os
import shutil
from pathlib import Path
from typing import Dict, List

import numpy as np
import soundfile as sf

from scripts.redact_config import redact_config
from test import test_settings
from utils.settings import load_config

# Config file name (relative to CONFIGS_DIR) -> extra arguments for its test run.
# Every entry provides ``model_type``; the whole dict is also merged into the
# arguments passed to ``test_settings`` (see ``run_tests``).
MODEL_CONFIGS = {
    "config_apollo.yaml": {"model_type": "apollo"},
    "config_dnr_bandit_bsrnn_multi_mus64.yaml": {"model_type": "bandit"},
    "config_dnr_bandit_v2_mus64.yaml": {"model_type": "bandit_v2"},
    "config_drumsep.yaml": {"model_type": "htdemucs"},
    "config_htdemucs_6stems.yaml": {"model_type": "htdemucs"},
    "config_musdb18_bs_roformer.yaml": {"model_type": "bs_roformer"},
    "config_musdb18_demucs3_mmi.yaml": {"model_type": "htdemucs"},
    "config_musdb18_htdemucs.yaml": {"model_type": "htdemucs"},
    "config_musdb18_mdx23c.yaml": {"model_type": "mdx23c"},
    "config_musdb18_mel_band_roformer.yaml": {"model_type": "mel_band_roformer"},
    "config_musdb18_mel_band_roformer_all_stems.yaml": {
        "model_type": "mel_band_roformer"
    },
    "config_musdb18_scnet.yaml": {"model_type": "scnet"},
    "config_musdb18_scnet_large.yaml": {"model_type": "scnet"},
    # 'config_musdb18_scnet_large_starrytong.yaml': {'model_type': 'scnet'},
    "config_vocals_bandit_bsrnn_multi_mus64.yaml": {"model_type": "bandit"},
    "config_vocals_bs_roformer.yaml": {"model_type": "bs_roformer"},
    "config_vocals_htdemucs.yaml": {"model_type": "htdemucs"},
    "config_vocals_mdx23c.yaml": {"model_type": "mdx23c"},
    "config_vocals_mel_band_roformer.yaml": {"model_type": "mel_band_roformer"},
    "config_vocals_scnet.yaml": {"model_type": "scnet"},
    "config_vocals_scnet_large.yaml": {"model_type": "scnet"},
    "config_vocals_scnet_unofficial.yaml": {"model_type": "scnet_unofficial"},
    "config_vocals_segm_models.yaml": {"model_type": "segm_models"},
    # 'config_vocals_swin_upernet.yaml': {'model_type': 'swin_upernet'},
    # 'config_musdb18_torchseg.yaml': {'model_type': 'torchseg'},
    # 'config_musdb18_segm_models.yaml': {'model_type': 'segm_models'},
    # 'config_musdb18_bs_mamba2.yaml': {'model_type': 'bs_mamba2'},
    # 'config_vocals_bs_mamba2.yaml': {'model_type': 'bs_mamba2'},
    # 'config_vocals_torchseg.yaml': {'model_type': 'torchseg'}
}

# Folders for tests
ROOT_DIR = Path(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
CONFIGS_DIR = ROOT_DIR / "configs/"
TEST_DIR = ROOT_DIR / "tests_cache/"
TRAIN_DIR = TEST_DIR / "train_tracks/"
VALID_DIR = TEST_DIR / "valid_tracks/"


def create_dummy_tracks(
    directory: Path,
    num_tracks: int,
    instruments: List[str],
    duration: float = 5.0,
    sample_rate: int = 44100,
) -> None:
    """
    Generates random stereo audio stems for ``num_tracks`` dummy tracks.

    One subdirectory per track is created inside ``directory``, named
    ``"1"``, ``"2"``, ..., ``str(num_tracks)``. Each subdirectory receives one
    ``<instrument>.wav`` file per entry in ``instruments``, filled with
    uniform white noise in ``[-1, 1)``. Existing files are overwritten.

    Parameters:
    ----------
    directory : Path
        Directory where the track folders will be created (``str`` is also accepted).
    num_tracks : int
        Number of track folders to generate.
    instruments : List[str]
        List of instrument names (stems) to create in every track folder.
    duration : float, optional
        Duration of each stem in seconds. Default is 5.0.
    sample_rate : int, optional
        Sampling rate of the generated audio. Default is 44100 Hz.

    Returns:
    -------
    None
    """
    directory = Path(directory)
    os.makedirs(directory, exist_ok=True)

    # Same length for every stem, so compute it once.
    samples = int(duration * sample_rate)

    for track_idx in range(1, num_tracks + 1):
        folder_path = directory / str(track_idx)
        os.makedirs(folder_path, exist_ok=True)
        for instrument in instruments:
            # Generate random noise for each stem: shape (channels=2, samples).
            track = np.random.uniform(-1.0, 1.0, (2, samples)).astype(np.float32)
            file_path = folder_path / f"{instrument}.wav"
            # soundfile expects (frames, channels), hence the transpose.
            sf.write(file_path, track.T, sample_rate)


def cleanup_test_tracks() -> None:
    """
    Removes all cached test data.

    Deletes the entire directory given by the global ``TEST_DIR`` (dummy
    tracks, redacted configs and results) if it exists. Does nothing when
    the directory is absent.

    Returns:
    -------
    None
    """
    if TEST_DIR.exists():
        shutil.rmtree(TEST_DIR)


def modify_configs() -> Dict[str, Path]:
    """
    Updates configuration files in the `configs` directory for use with test data.

    Each configuration listed in the global `MODEL_CONFIGS` dictionary is
    passed to ``redact_config``, which writes a test-compatible copy under
    ``TEST_DIR / "configs"``.

    Returns:
    -------
    Dict[str, Path]
        Maps each original configuration file name to the value returned by
        ``redact_config`` for it (the path of the updated configuration).
    """
    config_dir = CONFIGS_DIR
    updated_configs = {}
    for config, args in MODEL_CONFIGS.items():
        model_type = args["model_type"]
        config_path = config_dir / config
        updated_config_path = redact_config(
            {
                "orig_config": str(config_path),
                "model_type": model_type,
                "new_config": str(TEST_DIR / "configs" / config),
            }
        )
        updated_configs[config] = updated_config_path
    return updated_configs


def run_tests() -> None:
    """
    Executes validation tests for all configurations.

    Updates the configurations, generates random dummy data for each model's
    stems, and runs validation and inference checks (training is disabled)
    for every configuration in the global `MODEL_CONFIGS` dictionary. The
    test cache directory is removed at the end.

    Returns:
    -------
    None
    """
    updated_configs = modify_configs()

    # For every config
    for config, args in MODEL_CONFIGS.items():
        model_type = args["model_type"]
        cfg = load_config(
            model_type=model_type, config_path=TEST_DIR / "configs" / config
        )

        # Random tracks: one file per configured stem plus the mixture.
        stems = cfg.training.instruments + ["mixture"]
        create_dummy_tracks(TRAIN_DIR, instruments=stems, num_tracks=2)
        create_dummy_tracks(VALID_DIR, instruments=stems, num_tracks=2)

        print(f"\nRunning tests for model: {model_type} (config: {config})")
        test_args = {
            "check_train": False,
            "check_valid": True,
            "check_inference": True,
            "config_path": updated_configs[config],
            "data_path": str(TRAIN_DIR),
            "valid_path": str(VALID_DIR),
            "results_path": str(TEST_DIR / "results" / model_type),
            "store_dir": str(TEST_DIR / "inference_results" / model_type),
            "metrics": ["sdr", "si_sdr", "l1_freq"],
        }
        # Per-config entries (at least model_type) override the defaults above.
        test_args.update(args)

        test_settings(test_args, "admin")
        print(f"Tests for model {model_type} completed successfully.")

    # Remove test_cache
    cleanup_test_tracks()


if __name__ == "__main__":
    run_tests()