# File size: 4,645 Bytes
import argparse
import os
import shlex
# Per-benchmark lookup tables used to assemble the main.py commands.

# Benchmark -> network config name (resolved as configs/networks/<name>.yml).
network_dict = dict(
    mnist='lenet',
    mnist6='lenet',
    cifar10='resnet18_32x32',
    cifar6='resnet18_32x32',
    cifar100='resnet18_32x32',
    cifar50='resnet18_32x32',
    imagenet='resnet50',
    tin20='resnet18_64x64',
)

# Benchmark -> trained checkpoint location.
checkpoint_dict = {
    # Closed-set benchmarks: full path of a single checkpoint file.
    'mnist': './results/checkpoints/mnist_lenet_acc98.50.ckpt',
    'cifar10': './results/checkpoints/cifar10_res18_acc95.24.ckpt',
    'cifar100': './results/checkpoints/cifar100_res18_acc77.10.ckpt',
    'imagenet': './results/checkpoints/imagenet_res50_acc76.17.pth',
    # Open-set (OSR) benchmarks: a path prefix only; the OSR branch of the
    # sweep appends '_seed{N}.ckpt' to select a per-seed checkpoint.
    **{
        osr_name: f'./results/checkpoints/osr/{osr_name}'
        for osr_name in ('mnist6', 'cifar6', 'cifar50', 'tin20')
    },
}

# Post-processor -> list of CLI override strings (None: no overrides listed).
# NOTE(review): this table is not referenced anywhere else in this script, so
# e.g. the ODIN temperatures are never applied — confirm whether intended.
method_dict = dict(
    msp=None,
    odin=[
        '--postprocessor.postprocessor_args.temperature 1',
        '--postprocessor.postprocessor_args.temperature 100',
        '--postprocessor.postprocessor_args.temperature 1000',
    ],
    mds=None,
    gram=None,
)
def make_args_list(benchmarks, methods, metrics):
    """Enumerate every (benchmark, method, metric) combination of the sweep.

    The benchmark varies slowest and the metric fastest, so all runs for one
    benchmark are adjacent in the result.

    Args:
        benchmarks: iterable of benchmark names.
        methods: iterable of post-processor method names.
        metrics: iterable of metric names.

    Returns:
        A list of ``[benchmark, method, metric]`` lists, in nested-loop order.
    """
    return [
        [bench, meth, met]
        for bench in benchmarks
        for meth in methods
        for met in metrics
    ]
def build_command(config_files, checkpoint, output_dir):
    """Return the ``main.py`` evaluation command for a single run.

    Every path argument goes through ``shlex.quote`` because the command is
    executed by a shell (``os.system``); ordinary paths made only of letters,
    digits, ``_``, ``-``, ``.`` and ``/`` are returned unchanged.

    Args:
        config_files: config file paths handed, in order, to ``--config``.
        checkpoint: value for ``--network.checkpoint``.
        output_dir: value for ``--output_dir``.

    Returns:
        The space-separated command string, without any launcher prefix.
    """
    parts = ['python main.py --config']
    parts.extend(shlex.quote(path) for path in config_files)
    parts += [
        '--network.checkpoint', shlex.quote(checkpoint),
        '--output_dir', shlex.quote(output_dir),
    ]
    return ' '.join(parts)


def main():
    """Parse the CLI and run every requested (benchmark, method, metric) job.

    Raises:
        ValueError: for an unknown metric (unreachable through the CLI,
            which restricts ``--metrics`` to the supported values).
        SystemExit: with status 1 if any launched command exited non-zero.
    """
    parser = argparse.ArgumentParser(description='Run a sweep')
    # Restrict to known benchmarks so a typo fails before any job starts
    # rather than with a KeyError part-way through the sweep.
    parser.add_argument('--benchmarks',
                        nargs='+',
                        choices=list(network_dict),
                        default=['mnist', 'cifar10', 'cifar100', 'imagenet'])
    parser.add_argument('--methods', nargs='+', default=['msp'])
    # Same reasoning: previously an unknown metric raised only after every
    # earlier experiment in the sweep had already been executed.
    parser.add_argument('--metrics',
                        nargs='+',
                        choices=['ood', 'fsood', 'osr', 'acc', 'ece'],
                        default=['acc'])
    parser.add_argument('--output-dir', type=str, default='./results/')
    parser.add_argument('--launcher',
                        default='local',
                        choices=['local', 'slurm'])
    # Cluster-specific srun settings; defaults reproduce the values that
    # used to be hard-coded.
    parser.add_argument('--slurm-partition', default='dsta')
    parser.add_argument('--slurm-node',
                        default='SG-IDC1-10-51-2-79',
                        help='node passed to srun -w; pass "" to omit -w')
    args = parser.parse_args()

    # Different command prefix per job scheduler. '.' is prepended to
    # PYTHONPATH (presumably so main.py can import the local package).
    command_prefix = "PYTHONPATH='.':$PYTHONPATH "
    if args.launcher == 'slurm':
        command_prefix += (f'srun -p {shlex.quote(args.slurm_partition)} '
                           '--mpi=pmi2 --gres=gpu:1 -n1 '
                           '--cpus-per-task=1 --ntasks-per-node=1 '
                           '--kill-on-bad-exit=1 ')
        if args.slurm_node:
            command_prefix += f'-w {shlex.quote(args.slurm_node)} '

    failed = []

    def run(command):
        """Execute one command through the shell; record it if it fails."""
        status = os.system(command_prefix + command)
        if status != 0:
            print(f'WARNING: exit status {status}: {command}', flush=True)
            failed.append(command)

    # NOTE(review): method_dict (e.g. the ODIN temperatures) is not applied
    # here; each method runs only with configs/postprocessors/<method>.yml.
    args_list = make_args_list(args.benchmarks, args.methods, args.metrics)
    print(f'{len(args_list)} experiments have been setup...', flush=True)
    for exp_id, [benchmark, method, metric] in enumerate(args_list):
        print(f'Experiment #{exp_id} Starts...', flush=True)
        print(f'Config: {benchmark}, {method}, {metric}', flush=True)
        network_cfg = f'configs/networks/{network_dict[benchmark]}.yml'
        preprocessor_cfg = 'configs/preprocessors/base_preprocessor.yml'
        postprocessor_cfg = f'configs/postprocessors/{method}.yml'
        if metric in ('ood', 'fsood'):
            # Fix: this command used to be built but never executed.
            run(build_command([
                f'configs/datasets/{benchmark}/{benchmark}.yml',
                f'configs/datasets/{benchmark}/{benchmark}_{metric}.yml',
                preprocessor_cfg,
                network_cfg,
                f'configs/pipelines/test/test_{metric}.yml',
                postprocessor_cfg,
            ], checkpoint_dict[benchmark], args.output_dir))
        elif metric == 'osr':
            # Five runs, one per seed; each seed has its own dataset configs
            # and its own checkpoint '<prefix>_seed<sid>.ckpt'.
            for sid in range(1, 6):
                print(f'5 OSR Exp, {sid} out of 5', flush=True)
                run(build_command([
                    f'configs/datasets/osr_{benchmark}/{benchmark}_seed{sid}.yml',
                    f'configs/datasets/osr_{benchmark}/{benchmark}_seed{sid}_osr.yml',
                    preprocessor_cfg,
                    network_cfg,
                    'configs/pipelines/test/test_osr.yml',
                    postprocessor_cfg,
                ], f'{checkpoint_dict[benchmark]}_seed{sid}.ckpt',
                    args.output_dir))
        elif metric in ('acc', 'ece'):
            run(build_command([
                f'configs/datasets/{benchmark}/{benchmark}.yml',
                preprocessor_cfg,
                network_cfg,
                f'configs/pipelines/test/test_{metric}.yml',
                postprocessor_cfg,
            ], checkpoint_dict[benchmark], args.output_dir))
        else:
            # Defensive guard; argparse already rejects unknown metrics.
            raise ValueError('Unexpected Metric...')

    # Surface failures instead of silently exiting 0 after a broken sweep.
    if failed:
        print(f'{len(failed)} command(s) failed:', flush=True)
        for command in failed:
            print(f'  {command}', flush=True)
        raise SystemExit(1)


if __name__ == '__main__':
    main()