| """ |
| Utility functions. |
| """ |
| import os |
| import numpy as np |
|
|
| def generate_tr_val_te_subject_ids(subject_list, val_subject_id): |
| val_subject = subject_list[val_subject_id] |
| te_subject = subject_list[val_subject_id-1] |
| subject_list.remove(val_subject) |
| subject_list.remove(te_subject) |
| tr_subjects = subject_list |
| return tr_subjects, val_subject, te_subject |
|
|
| def generate_data_ids(data_dir, subject_list): |
| in_ids, out_ids = [], [] |
| vendor_list = [vendor for vendor in os.listdir(data_dir) if '.' not in vendor] |
| for vendor in vendor_list: |
| vendor_dir = os.path.join(data_dir, vendor) |
| view_list = [view for view in os.listdir(vendor_dir) if '.' not in view] |
| for view in view_list: |
| view_dir = os.path.join(vendor_dir, view) |
| subject_full_list = [subject for subject in os.listdir(view_dir) if '.' not in subject] |
| for subject in subject_full_list: |
| if subject in subject_list: |
| subject_dir = os.path.join(view_dir, subject) |
| org_data_dir = os.path.join(subject_dir, 'data_org') |
| org_data_id = os.path.join(org_data_dir, os.listdir(org_data_dir)[0]) |
| clutter_list = [clutter for clutter in os.listdir(subject_dir) |
| if clutter in ['data_NFClt', 'data_NFRvbClt', 'data_RvbClt'] |
| and '.' not in clutter] |
| for clutter in clutter_list: |
| clutter_dir = os.path.join(subject_dir, clutter) |
| clutter_ids = os.listdir(clutter_dir) |
| clutter_ids_dir = [os.path.join(clutter_dir, id_dir) for id_dir in clutter_ids if '.DS' not in id_dir] |
| in_ids += clutter_ids_dir |
| out_ids += [org_data_id]*len(os.listdir(clutter_dir)) |
| return in_ids, out_ids |
|
|
| def id_preparation(config): |
| tr_subjects, val_subject, te_subject = generate_tr_val_te_subject_ids( |
| subject_list=config["subject_list"], val_subject_id=config["CV"]["val_subject_id"]) |
| if config["tr_phase"]: |
| in_ids_tr, out_ids_tr = generate_data_ids(config["paths"]["data_path"], tr_subjects) |
| in_ids_val, out_ids_val = generate_data_ids(config["paths"]["data_path"], val_subject) |
| return in_ids_tr, in_ids_val, out_ids_tr, out_ids_val, val_subject, te_subject |
| else: |
| in_ids_te, out_ids_te = generate_data_ids(config["paths"]["data_path"], te_subject) |
| return in_ids_te, out_ids_te, te_subject, val_subject |
|
|
| def create_weight_dir(val_subject, te_subject, config): |
| weight_dir = os.path.join(config["paths"]["save_path"], |
| "Weights", f"ValTeIDs_{val_subject}_{te_subject}") |
| if not os.path.exists(weight_dir): |
| os.makedirs(weight_dir) |
| return weight_dir |