| """Splits the model C cohort and patient days into stratified train and test sets. |
| |
| The train set retains these characteristics of the full data set: |
| - Exac days to non-exac days ratio (within 5%). Individual patients can only appear |
| in either train or test. |
| - Sex ratio (within 0.05). |
| - Age distribution (minimum p-value for Kolmogorov-Smirnov test = 0.9). |
| |
| This script also splits the train data into balanced folds for cross-validation. Patient |
| IDs for train, test and all data folds are stored for use in subsequent scripts. |
| |
| All data sets are divided into train and test and stored in separate folders. |
| """ |
|
|
| import numpy as np |
| import os |
| import pandas as pd |
| import pickle |
|
|
| from lenusml import splits |
|
|
# --- Configuration ----------------------------------------------------------
# Input location of the raw data set and output locations for the split data.
# '<YOUR_DATA_PATH>' is a placeholder to be replaced with a real directory.
data_dir = '<YOUR_DATA_PATH>/copd-dataset/'
output_train_data_dir = '<YOUR_DATA_PATH>/train_data'
output_test_data_dir = '<YOUR_DATA_PATH>/test_data'
cohort_info_dir = '../data/cohort_info/'

# When True, the train/test patient IDs and the cross-validation fold details
# are written to cohort_info_dir further down in this script.
save_cohort_info = True

# --- Load the per-patient, per-day data -------------------------------------
exac_data_path = os.path.join(data_dir, 'exac_data.pkl')
data = pd.read_pickle(exac_data_path)

# Parse dates of birth as timezone-aware (UTC) timestamps; the age calculation
# later in this script subtracts them from DateOfEvent.
data = data.assign(DateOfBirth=pd.to_datetime(data['DateOfBirth'], utc=True))
|
|
|
|
_SECONDS_PER_DAY = 86400.0
# Mean length of a Gregorian calendar year, in days.
_DAYS_PER_YEAR = 365.2425


def calculate_age_decimal(dob, date):
    """Return the age at ``date`` of a person born at ``dob``, in decimal years.

    Args:
        dob: Date of birth. Any datetime-like value (``datetime.datetime``,
            ``pandas.Timestamp``) that can be subtracted from ``date``.
        date: The moment at which the age is wanted. Must be compatible with
            ``dob`` (both timezone-aware or both naive).

    Returns:
        The elapsed time between ``dob`` and ``date`` as a float number of
        years, using a year of 365.2425 days. Negative when ``date`` is
        earlier than ``dob``.
    """
    age = date - dob
    # total_seconds() covers the whole interval, including the sub-second part
    # that the separate ``days``/``seconds`` attributes leave out.
    return age.total_seconds() / _SECONDS_PER_DAY / _DAYS_PER_YEAR
|
|
|
|
# Age of the patient, in decimal years, on the day of each event.
ages_at_event = data.apply(
    lambda row: calculate_age_decimal(row['DateOfBirth'], row['DateOfEvent']),
    axis=1,
)

# Store the age and discard DateOfBirth, which is not used again in this
# script.
data = data.assign(Age=ages_at_event).drop(columns=['DateOfBirth'])
|
|
| |
| |
| |
# Patient-level table: one COPD status record reference per StudyId.
patient_details_path = os.path.join(data_dir, 'CopdDatasetPatientDetails.txt')
patient_details = pd.read_csv(
    patient_details_path,
    delimiter="|",
    usecols=['StudyId', 'CopdStatusDetailsId'],
)

# COPD status table: the clinical fields to attach to every patient.
copd_status_path = os.path.join(data_dir, 'CopdDatasetCopdStatusDetails.txt')
copd_status = pd.read_csv(
    copd_status_path,
    delimiter="|",
    usecols=[
        'Id',
        'SmokingStatus',
        'RequiredAcuteNIV',
        'RequiredICUAdmission',
        'LungFunction_FEV1PercentPredicted',
        'LabsHighestEosinophilCount',
    ],
)

# FEV1 % predicted is read as text; strip any leading/trailing '%' characters
# and convert the remainder to a float.
fev1_column = 'LungFunction_FEV1PercentPredicted'
fev1_without_percent = copd_status[fev1_column].str.strip('%')
copd_status[fev1_column] = fev1_without_percent.astype('float')

# Attach the status fields to each patient, then drop the join keys, which are
# no longer needed once the tables are combined.
patient_details = patient_details.merge(
    copd_status,
    how='left',
    left_on='CopdStatusDetailsId',
    right_on='Id',
)
patient_details = patient_details.drop(columns=['CopdStatusDetailsId', 'Id'])

# Left join so that every row of the main data set is kept.
data = data.merge(patient_details, how='left', on='StudyId')
|
|
| |
| |
| |
|
|
print('Split data into train and test')

# Tolerance for the class ratio: 5% of the ratio between the fraction of rows
# with IsExac label 0 and the fraction with label 1 in the full data set.
class_fractions = data['IsExac'].value_counts(normalize=True)
class_ratio_tolerance = 0.05 * class_fractions[0] / class_fractions[1]
print("Class ratio tolerance: ", class_ratio_tolerance)

# Tolerance for the sex ratio: 5% of the M-to-F ratio in the full data set.
sex_fractions = data['Sex'].value_counts(normalize=True)
sex_ratio_tolerance = 0.05 * sex_fractions['M'] / sex_fractions['F']
print("Sex ratio tolerance: ", sex_ratio_tolerance)

# Split with StudyId as the ID column, asking for 85% of the data in train and
# passing the class and sex ratio tolerances computed above.
train_data, test_data, train_ids, test_ids = (
    splits.train_test_stratified_class_sex(
        data=data,
        id_column='StudyId',
        class_column='IsExac',
        sex_column='Sex',
        train_proportion=0.85,
        proportion_tolerance=0.05,
        class_ratio_tolerance=class_ratio_tolerance,
        sex_ratio_tolerance=sex_ratio_tolerance,
        random_seed=42,
    )
)
|
|
| |
| |
| |
# Divide the training data into 5 cross-validation folds, using StudyId as the
# ID column and the same class ratio tolerance as the train/test split.
fold_proportions, fold_class_ratios, fold_patients = (
    splits.group_kfold_class_balanced(
        data=train_data,
        id_column='StudyId',
        class_column='IsExac',
        K=5,
        fold_proportion_tolerance=0.05,
        fold_class_ratio_tolerance=class_ratio_tolerance,
        random_seed=42,
    )
)

# NOTE(review): the original indentation was lost; every statement that writes
# into cohort_info_dir is assumed to belong to this `if` body - confirm.
if save_cohort_info:
    os.makedirs(cohort_info_dir, exist_ok=True)

    # Patient IDs of the test and train sets, stored as plain lists.
    for pickle_name, patient_ids in (
        ("test_ids.pkl", test_ids),
        ("train_ids.pkl", train_ids),
    ):
        with open(os.path.join(cohort_info_dir, pickle_name), 'wb') as f:
            pickle.dump(list(patient_ids), f)
    print('Train and test patient IDs saved')

    # Per-fold summaries returned by the fold split, stored as plain lists.
    for pickle_name, fold_summary in (
        ("fold_proportions.pkl", fold_proportions),
        ("fold_class_ratios.pkl", fold_class_ratios),
    ):
        with open(os.path.join(cohort_info_dir, pickle_name), 'wb') as f:
            pickle.dump(list(fold_summary), f)

    # Patients assigned to each fold, saved with pickling enabled.
    fold_patients_path = os.path.join(cohort_info_dir, 'fold_patients.npy')
    np.save(fold_patients_path, fold_patients, allow_pickle=True)
    print('Cross validation fold information saved')
|
|
| |
| |
| |
|
|
| |
# Make sure both output folders exist before anything is written.
for output_dir in (output_train_data_dir, output_test_data_dir):
    os.makedirs(output_dir, exist_ok=True)

# Write the train and test portions of the data to their own folders.
train_data_path = os.path.join(output_train_data_dir, 'train_data.pkl')
test_data_path = os.path.join(output_test_data_dir, 'test_data.pkl')
train_data.to_pickle(train_data_path)
test_data.to_pickle(test_data_path)
print('Patient details/exac data saved')
|
|