# IntegrationTest / tests/test_data_preparation.py
# Commit 24a5e7e by Yingtao-Zheng: Add other files and folders, including
# data-related files, notebooks, tests, and evaluation.
import os
import sys
import numpy as np
# Put the parent of the directory containing this file at the front of
# sys.path so the `data_preparation` package is importable no matter which
# working directory the test runner is launched from. The check avoids
# inserting a duplicate entry when the root is already on the path.
PROJECT_ROOT = os.path.abspath(
    os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir)
)
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)
from data_preparation.prepare_dataset import (
SELECTED_FEATURES,
_generate_synthetic_data,
get_numpy_splits,
)
def test_generate_synthetic_data_shape():
    """Synthetic ``face_orientation`` data has the expected dimensions.

    ``_generate_synthetic_data`` must return a feature matrix and a label
    vector with 500 entries each, and the matrix must have one column per
    feature listed in ``SELECTED_FEATURES["face_orientation"]``.
    """
    modality = "face_orientation"
    features, labels = _generate_synthetic_data(modality)

    n_rows = features.shape[0]
    assert n_rows == 500
    assert labels.shape[0] == 500
    assert features.shape[1] == len(SELECTED_FEATURES[modality])
def test_get_numpy_splits_consistency():
    """Check that ``get_numpy_splits`` returns non-empty, shape-consistent splits.

    For each of the train/val/test partitions:
      * the split contains at least one sample,
      * its feature matrix has exactly one row per label, and
      * its feature width equals the reported ``num_features``.
    The reported ``num_classes`` must describe at least a binary problem.
    """
    # The fourth return value (the fitted scaler) is not under test here;
    # unpacking all four still pins the function's return arity.
    splits, num_features, num_classes, _scaler = get_numpy_splits("face_orientation")

    for part in ("train", "val", "test"):
        X_part = splits[f"X_{part}"]
        y_part = splits[f"y_{part}"]

        # Every partition must contain at least one sample.
        assert len(y_part) > 0, f"{part} split is empty"
        # Features and labels must stay aligned row-for-row.
        assert X_part.shape[0] == len(y_part), (
            f"{part}: {X_part.shape[0]} feature rows vs {len(y_part)} labels"
        )
        # Every split (not just train) must expose the reported feature width.
        assert X_part.shape[1] == num_features, (
            f"{part}: feature dim {X_part.shape[1]} != num_features {num_features}"
        )

    assert num_classes >= 2