File size: 1,704 Bytes
22a6915
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import glob
import os
import sys

import numpy as np
import pytest


# Resolve the parent of this file's directory and put it at the front of
# sys.path (only if absent) so the `data_preparation` import below resolves.
_HERE = os.path.dirname(os.path.abspath(__file__))
PROJECT_ROOT = os.path.dirname(_HERE)
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

from data_preparation.prepare_dataset import (
    SELECTED_FEATURES,
    _generate_synthetic_data,
    get_default_split_config,
    get_numpy_splits,
)


def test_get_default_split_config():
    """The default config is a (ratios, seed) pair: three ratios summing to 1, seed >= 0."""
    ratios, seed = get_default_split_config()

    ratio_count = len(ratios)
    assert ratio_count == 3

    # Allow a small floating-point tolerance around an exact sum of 1.0.
    deviation = sum(ratios) - 1.0
    assert -1e-6 < deviation < 1e-6

    assert seed >= 0


def test_generate_synthetic_data_shape():
    """Synthetic "face_orientation" data has 500 rows and one column per selected feature."""
    task = "face_orientation"
    expected_rows = 500

    features, labels = _generate_synthetic_data(task)

    # Both returned arrays must agree on the leading dimension.
    assert features.shape[0] == expected_rows
    assert labels.shape[0] == expected_rows
    # The second dimension must match the number of features selected for the task.
    assert features.shape[1] == len(SELECTED_FEATURES[task])


def test_get_numpy_splits_consistency():
    """Check split sizes and metadata, then confirm the split is reproducible.

    Skipped when no file matches ``data/collected_*/*.npz`` under PROJECT_ROOT.
    """
    npz_glob = os.path.join(PROJECT_ROOT, "data", "collected_*", "*.npz")
    found_files = glob.glob(npz_glob)
    if not found_files:
        pytest.skip("No data/collected_*/*.npz — run collect_features or add dataset files.")

    ratios, rng_seed = get_default_split_config()
    # Both calls below must receive exactly the same split arguments.
    split_kwargs = {"split_ratios": ratios, "seed": rng_seed}

    first, num_features, num_classes, _scaler = get_numpy_splits(
        "face_orientation", **split_kwargs
    )

    # Every partition must hold at least one label.
    sizes = [len(first[key]) for key in ("y_train", "y_val", "y_test")]
    for size in sizes:
        assert size > 0
    assert first["X_train"].shape[1] == num_features
    assert num_classes >= 2

    # Repeating the call with identical ratios and seed must yield the same test labels.
    second, _, _, _ = get_numpy_splits("face_orientation", **split_kwargs)
    np.testing.assert_array_equal(first["y_test"], second["y_test"])