File size: 1,008 Bytes
4a5bfab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e69e3a3
4a5bfab
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import os
import sys
import numpy as np


# Make the repository root importable so `data_preparation` resolves no matter
# which directory the test runner is launched from. The root is taken to be
# the parent of the directory containing this file; it is prepended to
# sys.path only once.
_HERE = os.path.dirname(os.path.abspath(__file__))
PROJECT_ROOT = os.path.dirname(_HERE)
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

from data_preparation.prepare_dataset import (
    SELECTED_FEATURES,
    _generate_synthetic_data,
    get_numpy_splits,
)


def test_generate_synthetic_data_shape():
    """Synthetic face-orientation data has the expected row and column counts.

    ``_generate_synthetic_data`` is expected to yield 500 samples, with one
    feature column per entry in ``SELECTED_FEATURES["face_orientation"]``.
    """
    task = "face_orientation"
    expected_rows = 500

    features, labels = _generate_synthetic_data(task)

    # Sample counts of features and labels must both match the fixed size.
    assert features.shape[0] == expected_rows
    assert labels.shape[0] == expected_rows
    # Column count must track the feature list configured for this task.
    assert features.shape[1] == len(SELECTED_FEATURES[task])


def test_get_numpy_splits_consistency():
    """Check that ``get_numpy_splits`` returns non-empty, mutually consistent splits.

    For each of the train/val/test partitions this verifies that:

    * the split contains at least one sample,
    * features and labels have the same number of rows, and
    * the feature dimension equals the reported ``num_features``.

    It also checks that at least two classes are reported.
    """
    splits, num_features, num_classes, _scaler = get_numpy_splits("face_orientation")

    for split_name in ("train", "val", "test"):
        # NOTE(review): assumes splits uses X_<name>/y_<name> keys for every
        # partition, matching the X_train/y_train/y_val/y_test keys used here.
        X = splits[f"X_{split_name}"]
        y = splits[f"y_{split_name}"]
        n_samples = len(y)

        # Every partition must contain at least one sample.
        assert n_samples > 0, f"{split_name} split is empty"
        # Features and labels must be row-aligned.
        assert X.shape[0] == n_samples, (
            f"{split_name}: X has {X.shape[0]} rows but y has {n_samples}"
        )
        # Feature dimension must match the reported num_features.
        assert X.shape[1] == num_features, (
            f"{split_name}: X has {X.shape[1]} features, expected {num_features}"
        )

    # A classification task needs at least two classes.
    assert num_classes >= 2