Spaces:
Runtime error
Runtime error
| """ | |
| ========================================================================================= | |
| Trojan VQA | |
| Written by Matthew Walmer | |
| Helper scripts to check if a job has already been run to aid orchestrator.py. | |
| ========================================================================================= | |
| """ | |
| import os | |
| import numpy as np | |
def featfile_to_id(file_name):
    """Parse the integer id embedded in a feature file name.

    Two trailing extensions are stripped from ``file_name``; whatever follows
    the last underscore of the remainder is converted with ``int``.

    Raises:
        ValueError: if that trailing piece is not a valid integer literal.
    """
    stem = file_name
    # the name carries two extensions, so peel one off twice
    for _ in range(2):
        stem = os.path.splitext(stem)[0]
    return int(stem.rsplit('_', 1)[-1])
def check_feature_extraction(s, downstream=None, debug=False):
    """Report whether the feature cache for spec ``s`` is already complete.

    Args:
        s: mapping with 'feat_id' and 'detector' entries, used to build the
            paths under data/feature_cache.
        downstream: optional data spec id, or several ids joined by commas.
            When given, the train2014 directory is checked against the ids
            loaded from data/feature_reqs/<id>_reqs.npy instead of against a
            fixed file count.
        debug: when true, a missing req file makes the check print a notice
            and return False. When false, the missing path is handed to
            np.load unchanged.

    Returns:
        True when the train2014 check passes and the val2014 directory holds
        exactly 40504 entries, False otherwise.
    """
    cache_root = os.path.join('data', 'feature_cache', s['feat_id'], s['detector'])

    # train set features
    train_dir = os.path.join(cache_root, 'train2014')
    if not os.path.isdir(train_dir):
        return False
    if downstream is None:
        # no downstream spec: require the full train file count
        if len(os.listdir(train_dir)) != 82783:
            return False
    else:
        # collect the required ids of every downstream data spec
        # (split gives a one-item list when there is no comma)
        needed = set()
        for spec_id in downstream.split(','):
            req_path = os.path.join('data', 'feature_reqs', spec_id + '_reqs.npy')
            if debug and not os.path.isfile(req_path):
                print('DEBUG MODE: assuming req file is not complete')
                return False
            needed.update(np.load(req_path))
        # cross off every requirement that already has a feature file
        for entry in os.listdir(train_dir):
            needed.discard(featfile_to_id(entry))
        if needed:
            return False

    # val set features
    val_dir = os.path.join(cache_root, 'val2014')
    if not os.path.isdir(val_dir):
        return False
    return len(os.listdir(val_dir)) == 40504
def check_dataset_composition(s):
    """Report whether the composed dataset for spec ``s`` is fully in place.

    Args:
        s: mapping with 'data_id', 'detector' and 'nb' entries used to build
            the paths under data/<data_id>.

    Returns:
        True when the trainval tsv file exists and the openvqa train2014 and
        val2014 directories hold exactly 82783 and 40504 entries
        respectively. False otherwise, including when either directory is
        missing.
    """
    data_root = os.path.join('data', s['data_id'])

    # butd tsv file format
    tsv_path = os.path.join(data_root, 'trainval_%s_%s.tsv'%(s['detector'], s['nb']))
    if not os.path.isfile(tsv_path):
        return False

    # openvqa feature format: both split directories must exist before the
    # entry counts are compared
    openvqa_root = os.path.join(data_root, 'openvqa', s['detector'])
    counts = {}
    for split in ('train2014', 'val2014'):
        split_dir = os.path.join(openvqa_root, split)
        if not os.path.isdir(split_dir):
            return False
        counts[split] = len(os.listdir(split_dir))
    return counts['train2014'] == 82783 and counts['val2014'] == 40504
def check_vqa_model(s, model_type):
    """Report whether the final checkpoint of a trained VQA model exists.

    Args:
        s: mapping with a 'model_id' entry.
        model_type: 'butd_eff' or 'openvqa'; anything else fails the assert.

    Returns:
        True when the checkpoint file for the given model type is present.
    """
    assert model_type in ['butd_eff', 'openvqa']
    if model_type == 'butd_eff':
        parts = ('bottom-up-attention-vqa', 'saved_models', s['model_id'], 'model_19.pth')
    else:
        parts = ('openvqa', 'ckpts', 'ckpt_'+s['model_id'], 'epoch13.pkl')
    return os.path.isfile(os.path.join(*parts))
# check for models in the model_sets/v1/ location instead
def check_vqa_model_set(s, model_type):
    """Report whether a model's final checkpoint exists under model_sets/v1/.

    Args:
        s: mapping with a 'model_id' entry.
        model_type: 'butd_eff' or 'openvqa'; anything else fails the assert.

    Returns:
        True when the checkpoint file for the given model type is present.
    """
    assert model_type in ['butd_eff', 'openvqa']
    if model_type == 'butd_eff':
        parts = ('model_sets/v1/bottom-up-attention-vqa/saved_models', s['model_id'], 'model_19.pth')
    else:
        parts = ('model_sets/v1/openvqa/ckpts', 'ckpt_'+s['model_id'], 'epoch13.pkl')
    return os.path.isfile(os.path.join(*parts))
def check_vqa_train(s, model_type):
    """Report whether every exported eval file of a trained model exists.

    Args:
        s: mapping with 'feat_id' and 'model_id' entries. A 'feat_id' of
            'clean' needs only the 'clean' result file; any other value needs
            the 'clean', 'troj', 'troji' and 'trojq' files.
        model_type: 'butd_eff' or 'openvqa'; anything else fails the assert.

    Returns:
        True when all required result files are present, False as soon as
        one is missing.
    """
    assert model_type in ['butd_eff', 'openvqa']
    configs = ['clean'] if s['feat_id'] == 'clean' else ['clean', 'troj', 'troji', 'trojq']

    # each model type exports its results under its own directory and name
    if model_type == 'butd_eff':
        result_dir = ('bottom-up-attention-vqa', 'results')
        name_fmt = 'results_%s_%s.json'
    else:
        result_dir = ('openvqa', 'results', 'result_test')
        name_fmt = 'result_run_%s_%s.json'

    # check for exported eval files
    for tc in configs:
        result_file = os.path.join(*result_dir, name_fmt%(s['model_id'], tc))
        if not os.path.isfile(result_file):
            return False
    return True
def check_vqa_eval(s):
    """Report whether results/<model_id>.npy exists for spec ``s``.

    Args:
        s: mapping with a 'model_id' entry.
    """
    result_name = '%s.npy'%s['model_id']
    result_path = os.path.join('results', result_name)
    return os.path.isfile(result_path)
def check_butd_preproc(s):
    """Report whether both preprocessed hdf5 files exist for spec ``s``.

    Args:
        s: mapping with 'data_id', 'detector' and 'nb' entries.

    Returns:
        True when both the train and the val hdf5 file are present under
        data/<data_id>, False otherwise.
    """
    templates = ('train_%s_%s.hdf5', 'val_%s_%s.hdf5')
    # all() stops at the first missing file, train before val
    return all(
        os.path.isfile(os.path.join('data', s['data_id'], tmpl%(s['detector'], s['nb'])))
        for tmpl in templates
    )