File size: 1,476 Bytes
3e5b046
 
 
 
 
 
 
 
 
 
 
 
eb4abb8
3e5b046
 
 
 
eb4abb8
 
 
 
 
 
 
3e5b046
 
 
 
 
 
 
 
eb4abb8
 
 
 
3e5b046
 
 
 
 
 
 
 
 
 
eb4abb8
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import os
import sys
import numpy as np


# The project root is taken to be the directory two levels above this file.
# It is put at the front of sys.path (only if absent) before the
# `data_preparation` import that follows -- presumably so that import resolves
# regardless of the directory the tests are launched from.
_THIS_DIR = os.path.dirname(os.path.abspath(__file__))
PROJECT_ROOT = os.path.dirname(_THIS_DIR)
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

from data_preparation.prepare_dataset import (
    SELECTED_FEATURES,
    _generate_synthetic_data,
    get_default_split_config,
    get_numpy_splits,
)


def test_get_default_split_config():
    """Check the default split configuration returned by the data module.

    The configuration must unpack into a pair ``(ratios, seed)`` where
    ``ratios`` has exactly three entries that sum to 1 (within 1e-6) and
    ``seed`` is non-negative.
    """
    default_ratios, default_seed = get_default_split_config()

    # Three entries are expected: one ratio per partition.
    assert len(default_ratios) == 3

    # The ratios must account for the whole dataset, up to float rounding.
    deviation_from_one = sum(default_ratios) - 1.0
    assert abs(deviation_from_one) < 1e-6

    assert default_seed >= 0


def test_generate_synthetic_data_shape():
    """Check the shape of the synthetic data for the face-orientation task.

    The test expects 500 rows in both the feature matrix and the label
    vector, and one feature column per entry in
    ``SELECTED_FEATURES["face_orientation"]``.
    """
    task_name = "face_orientation"
    expected_rows = 500

    features, labels = _generate_synthetic_data(task_name)

    # Features and labels must describe the same number of samples.
    assert features.shape[0] == expected_rows
    assert labels.shape[0] == expected_rows

    # Column count is driven by the feature list selected for this task.
    assert features.shape[1] == len(SELECTED_FEATURES[task_name])


def test_get_numpy_splits_consistency():
    """Check that ``get_numpy_splits`` is well-formed and reproducible.

    Using the default ratios and seed for the face-orientation task, the test
    verifies that:

    * the train, validation and test label arrays are all non-empty;
    * the training feature matrix has ``num_features`` columns;
    * at least two classes are reported;
    * a second call with identical arguments yields the same test labels.
    """
    task_name = "face_orientation"
    ratios, rng_seed = get_default_split_config()

    first_splits, feature_count, class_count, _scaler = get_numpy_splits(
        task_name, split_ratios=ratios, seed=rng_seed
    )

    # Measure every partition first, then require each one to be non-empty.
    partition_sizes = [
        len(first_splits[label_key])
        for label_key in ("y_train", "y_val", "y_test")
    ]
    for size in partition_sizes:
        assert size > 0

    assert first_splits["X_train"].shape[1] == feature_count
    assert class_count >= 2

    # Same seed and ratios must reproduce the same split (deterministic).
    repeated_splits = get_numpy_splits(
        task_name, split_ratios=ratios, seed=rng_seed
    )[0]
    np.testing.assert_array_equal(
        first_splits["y_test"], repeated_splits["y_test"]
    )