"""Tests for the dataset-splitting helpers in ``data_preparation.prepare_dataset``."""
import os
import sys

import numpy as np

# Put the project root (the parent of this file's directory) first on
# sys.path so ``data_preparation`` resolves from this checkout regardless of
# the directory the test runner is started from.
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

# Must come after the sys.path setup above.
from data_preparation.prepare_dataset import (  # noqa: E402
    SELECTED_FEATURES,
    _generate_synthetic_data,
    get_default_split_config,
    get_numpy_splits,
)


def test_get_default_split_config():
    """Default config gives three split ratios summing to 1 and a non-negative seed."""
    split_ratios, split_seed = get_default_split_config()
    assert len(split_ratios) == 3
    ratio_total = sum(split_ratios)
    assert abs(ratio_total - 1.0) < 1e-6
    assert split_seed >= 0


def test_generate_synthetic_data_shape():
    """Synthetic ``face_orientation`` data has 500 rows and one column per selected feature."""
    name = "face_orientation"
    expected_rows = 500
    features, labels = _generate_synthetic_data(name)
    assert features.shape[0] == expected_rows
    assert labels.shape[0] == expected_rows
    assert features.shape[1] == len(SELECTED_FEATURES[name])


def test_get_numpy_splits_consistency():
    """Splits are non-empty, shape-consistent, and reproducible for a fixed seed.

    Calls ``get_numpy_splits`` twice with the default ratios and seed. It checks
    that every ``X_*`` matrix lines up with its ``y_*`` labels and has
    ``num_features`` columns. It also checks that the two calls return
    identical arrays.
    """
    split_ratios, seed = get_default_split_config()
    splits, num_features, num_classes, _scaler = get_numpy_splits(
        "face_orientation", split_ratios=split_ratios, seed=seed
    )

    for part in ("train", "val", "test"):
        n_rows = len(splits[f"y_{part}"])
        assert n_rows > 0, f"{part} split is empty"
        # Each feature matrix must align row-for-row with its labels and
        # expose the advertised number of feature columns.
        features = splits[f"X_{part}"]
        assert features.shape[0] == n_rows, f"X_{part}/y_{part} row mismatch"
        assert features.shape[1] == num_features
    assert num_classes >= 2

    # Same seed and ratios must reproduce the split exactly. Compare every
    # returned array, not just the test labels, so changes in train/val
    # membership or in feature values are caught too.
    splits2, num_features2, num_classes2, _ = get_numpy_splits(
        "face_orientation", split_ratios=split_ratios, seed=seed
    )
    assert num_features2 == num_features
    assert num_classes2 == num_classes
    assert set(splits2) == set(splits)
    for key in splits:
        np.testing.assert_array_equal(splits[key], splits2[key], err_msg=key)