| """ | |
| Utility functions. | |
| """ | |
| import os | |
| import numpy as np | |
def generate_tr_val_te_subject_ids(subject_list, val_subject_id):
    """Split subjects into training, validation and test sets for one fold.

    The validation subject is ``subject_list[val_subject_id]``. The test
    subject is the element immediately before it. With ``val_subject_id == 0``,
    Python's negative indexing makes the test subject the last element of
    the list.

    Args:
        subject_list: Sequence of subject identifiers. It is NOT modified.
        val_subject_id: Index of the validation subject in ``subject_list``.

    Returns:
        ``(tr_subjects, val_subject, te_subject)``, where ``tr_subjects`` is a
        new list holding the remaining subjects in their original order.

    Raises:
        IndexError: If ``val_subject_id`` is out of range.
        ValueError: If fewer than two subjects are given, because the
            validation and test subjects would then be the same element.
    """
    val_subject = subject_list[val_subject_id]
    te_subject = subject_list[val_subject_id - 1]
    if len(subject_list) < 2:
        raise ValueError(
            "need at least two subjects to pick distinct validation and test subjects"
        )
    # Work on a copy. Removing from the caller's list (typically
    # config["subject_list"]) would corrupt any later split in the same run.
    tr_subjects = list(subject_list)
    tr_subjects.remove(val_subject)
    tr_subjects.remove(te_subject)
    return tr_subjects, val_subject, te_subject
def _is_hidden(name):
    """Return True for hidden/metadata entries such as ``.DS_Store``."""
    return name.startswith('.') or '.DS' in name


def _list_subdirs(parent):
    """Return the sorted names of sub-directories of *parent* without a '.' in the name."""
    return sorted(
        name for name in os.listdir(parent)
        if '.' not in name and os.path.isdir(os.path.join(parent, name))
    )


def _reference_file(subject_dir):
    """Return the path of the first (sorted) non-hidden file in ``subject_dir/data_org``."""
    org_data_dir = os.path.join(subject_dir, 'data_org')
    candidates = sorted(name for name in os.listdir(org_data_dir) if not _is_hidden(name))
    if not candidates:
        raise FileNotFoundError(f"no reference file found in {org_data_dir!r}")
    return os.path.join(org_data_dir, candidates[0])


def generate_data_ids(data_dir, subject_list):
    """Collect paired (cluttered input, clean reference) file paths.

    Walks ``data_dir/<vendor>/<view>/<subject>/``. For every subject in
    *subject_list*, it pairs each file in the clutter folders ``data_NFClt``,
    ``data_NFRvbClt`` and ``data_RvbClt`` with the reference file in that
    subject's ``data_org`` folder. Hidden entries (e.g. ``.DS_Store``) are
    skipped everywhere. Traversal is sorted, so the result is deterministic.

    Args:
        data_dir: Root directory of the dataset.
        subject_list: Iterable of subject directory names to include. A single
            name given as a ``str`` is treated as a one-element list.

    Returns:
        ``(in_ids, out_ids)``: two lists of equal length, where ``out_ids[i]``
        is the reference file paired with input ``in_ids[i]``.

    Raises:
        FileNotFoundError: If a selected subject has no ``data_org`` folder or
            that folder contains no usable file.
    """
    clutter_dirs = ('data_NFClt', 'data_NFRvbClt', 'data_RvbClt')
    # A bare string used with ``in`` would be a substring test
    # ("S1" in "S10" is True) and would silently pull in the wrong subjects.
    if isinstance(subject_list, str):
        wanted = {subject_list}
    else:
        wanted = set(subject_list)

    in_ids, out_ids = [], []
    for vendor in _list_subdirs(data_dir):
        vendor_dir = os.path.join(data_dir, vendor)
        for view in _list_subdirs(vendor_dir):
            view_dir = os.path.join(vendor_dir, view)
            for subject in _list_subdirs(view_dir):
                if subject not in wanted:
                    continue
                subject_dir = os.path.join(view_dir, subject)
                org_data_id = _reference_file(subject_dir)
                for clutter in clutter_dirs:
                    clutter_dir = os.path.join(subject_dir, clutter)
                    if not os.path.isdir(clutter_dir):
                        continue
                    clutter_ids = [
                        os.path.join(clutter_dir, name)
                        for name in sorted(os.listdir(clutter_dir))
                        if not _is_hidden(name)
                    ]
                    in_ids += clutter_ids
                    # Size the targets from the *filtered* inputs so the two
                    # lists stay index-aligned.
                    out_ids += [org_data_id] * len(clutter_ids)
    return in_ids, out_ids
def id_preparation(config):
    """Build the input/target path lists for the current cross-validation fold.

    Reads ``config["subject_list"]``, ``config["CV"]["val_subject_id"]``,
    ``config["tr_phase"]`` and ``config["paths"]["data_path"]``.
    ``config["subject_list"]`` is not modified.

    Returns:
        In the training phase (``config["tr_phase"]`` truthy)::

            (in_ids_tr, in_ids_val, out_ids_tr, out_ids_val, val_subject, te_subject)

        Otherwise::

            (in_ids_te, out_ids_te, te_subject, val_subject)

        Note that the test-phase tuple lists ``te_subject`` before
        ``val_subject``, the reverse of the training-phase tuple. This order
        is kept as-is for existing callers.
    """
    data_path = config["paths"]["data_path"]
    # Pass a copy so the split cannot shrink config["subject_list"] in place.
    tr_subjects, val_subject, te_subject = generate_tr_val_te_subject_ids(
        subject_list=list(config["subject_list"]),
        val_subject_id=config["CV"]["val_subject_id"])
    if config["tr_phase"]:
        in_ids_tr, out_ids_tr = generate_data_ids(data_path, tr_subjects)
        # Wrap single subjects in a list: generate_data_ids tests membership
        # with ``in``, which on a bare string is a substring check.
        in_ids_val, out_ids_val = generate_data_ids(data_path, [val_subject])
        return in_ids_tr, in_ids_val, out_ids_tr, out_ids_val, val_subject, te_subject
    in_ids_te, out_ids_te = generate_data_ids(data_path, [te_subject])
    return in_ids_te, out_ids_te, te_subject, val_subject
def create_weight_dir(val_subject, te_subject, config):
    """Return the weights directory for a fold, creating it if needed.

    The path is ``<config["paths"]["save_path"]>/Weights/ValTeIDs_<val>_<te>``.
    Intermediate directories are created as required. Calling the function
    again for an existing directory is harmless.
    """
    weight_dir = os.path.join(
        config["paths"]["save_path"],
        "Weights",
        f"ValTeIDs_{val_subject}_{te_subject}",
    )
    # exist_ok=True replaces the check-then-create pattern, which fails with
    # FileExistsError if the directory appears between the check and makedirs.
    os.makedirs(weight_dir, exist_ok=True)
    return weight_dir