"""Tests for the split-config, synthetic-data and numpy-split helpers in
``data_preparation.prepare_dataset``."""

import math
import os
import sys

import numpy as np

# PROJECT_ROOT is the parent of the directory containing this file; put it on
# sys.path so ``data_preparation`` is importable regardless of where pytest
# was launched from.
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

from data_preparation.prepare_dataset import (  # noqa: E402  (needs sys.path tweak above)
    SELECTED_FEATURES,
    _generate_synthetic_data,
    get_default_split_config,
    get_numpy_splits,
)

# Name passed to every helper under test (also a key of SELECTED_FEATURES).
FACE_ORIENTATION = "face_orientation"


def test_get_default_split_config():
    """Default config: three positive ratios summing to 1, and a non-negative seed."""
    ratios, seed = get_default_split_config()

    assert len(ratios) == 3
    # Every split must receive data; test_get_numpy_splits_consistency relies on it.
    assert all(ratio > 0 for ratio in ratios)
    assert math.isclose(sum(ratios), 1.0, abs_tol=1e-6)
    assert seed >= 0


def test_generate_synthetic_data_shape():
    """Synthetic data has 500 rows and one column per selected feature."""
    X, y = _generate_synthetic_data(FACE_ORIENTATION)

    assert X.shape[0] == 500
    assert y.shape[0] == 500
    assert X.shape[1] == len(SELECTED_FEATURES[FACE_ORIENTATION])


def test_get_numpy_splits_consistency():
    """Splits are non-empty, shape-consistent, and fully reproducible for a fixed seed."""
    split_ratios, seed = get_default_split_config()
    splits, num_features, num_classes, _scaler = get_numpy_splits(
        FACE_ORIENTATION, split_ratios=split_ratios, seed=seed
    )

    n_train = len(splits["y_train"])
    n_val = len(splits["y_val"])
    n_test = len(splits["y_test"])
    assert n_train > 0
    assert n_val > 0
    assert n_test > 0

    # Features and labels of the training split must line up row-for-row.
    assert splits["X_train"].shape[0] == n_train
    assert splits["X_train"].shape[1] == num_features
    assert num_classes >= 2

    # Same seed and ratios must reproduce the split exactly. Compare every
    # array in the returned dict (not only y_test), so a non-deterministic
    # shuffle of features or of the train/val partitions is caught too.
    splits2, _, _, _ = get_numpy_splits(
        FACE_ORIENTATION, split_ratios=split_ratios, seed=seed
    )
    assert splits2.keys() == splits.keys()
    for key in splits:
        np.testing.assert_array_equal(
            splits[key], splits2[key], err_msg=f"split {key!r} differs between runs"
        )