| | """ |
| | A main training script. |
| | """ |
| |
|
| |
|
| | |
| | import warnings |
| | warnings.filterwarnings('ignore') |
| | import logging |
| | import os |
| | from collections import OrderedDict |
| | import torch |
| | import uniperceiver.utils.comm as comm |
| | from uniperceiver.config import get_cfg, CfgNode |
| | from uniperceiver.engine import DefaultTrainer, default_argument_parser, default_setup, launch, build_engine, add_moe_arguments |
| |
|
| | |
| | from uniperceiver.engine import hooks |
| | from uniperceiver.modeling import add_config |
| | from uniperceiver.utils.env import init_distributed_mode, check_dist_portfile |
# DeepSpeed is optional: record availability so later code (argument
# parsing) can layer on its CLI options only when it is importable.
try:
    import deepspeed
    DEEPSPEED_INSTALLED = True
except Exception:
    # Narrowed from a bare `except:` so SystemExit / KeyboardInterrupt
    # still propagate; the import can fail for reasons beyond
    # ImportError (e.g. a broken CUDA toolchain), hence Exception.
    DEEPSPEED_INSTALLED = False
| |
|
| | import copy |
| |
|
def add_data_prefix(cfg):
    """Prepend the ``$DATA_PATH`` directory to relative data paths in *cfg*.

    Rewrites path-like entries in the top-level config, in every per-task
    config under ``cfg.TASKS``, and in every entry of ``cfg.SHARED_TARGETS``
    by joining them onto the ``DATA_PATH`` environment variable.  A value is
    left untouched when it is empty, already explicit ('.', '/'), points at
    a known local dir ('work_dirs', 'cluster'), is remote ('s3://'), or
    names a builtin resource (whitelist).

    Mutates *cfg* in place.  When ``DATA_PATH`` is unset this is a no-op,
    except that a ``None`` ``cfg.SHARED_TARGETS`` is still normalised to an
    empty list.
    """
    data_dir = os.getenv("DATA_PATH", None)

    # Values that name a resource rather than a filesystem path.
    whitelist = ["BERT", "CLIP", "CLIP_CAPTION"]

    def _should_prefix(value):
        # Only bare relative paths get the DATA_PATH prefix; this is the
        # single source of truth for a condition the original repeated
        # three times.
        return (value != ''
                and not value.startswith(('.', '/', 'work_dirs', 'cluster', 's3://'))
                and value not in whitelist)

    # (config node, attribute on that node, key path to the same attribute
    #  inside a per-task override dict)
    mapping_list = [
        [cfg.DATALOADER, 'FEATS_FOLDER', ['DATALOADER', ]],
        [cfg.DATALOADER, 'ANNO_FOLDER', ['DATALOADER', ]],
        [cfg.DATALOADER, 'CLASS_NAME_FILE', ['DATALOADER', ]],
        [cfg.INFERENCE, 'VOCAB', ['INFERENCE', ]],
        [cfg.INFERENCE, 'VAL_ANNFILE', ['INFERENCE', ]],
        [cfg.INFERENCE, 'TEST_ANNFILE', ['INFERENCE', ]],
        [cfg.MODEL, 'WEIGHTS', ['MODEL', ]],
    ]

    # Normalise regardless of DATA_PATH so downstream iteration is safe.
    if cfg.SHARED_TARGETS is None:
        cfg.SHARED_TARGETS = []

    if not data_dir:
        # Without a prefix directory, joining would be meaningless (and
        # os.path.join(None, ...) would raise).
        return

    # 1) Top-level config nodes.
    for node, attr, _ in mapping_list:
        if _should_prefix(node[attr]):
            setattr(node, attr, os.path.join(data_dir, node[attr]))

    # 2) Per-task override dicts: walk down the key path as far as it
    #    exists, then rewrite the leaf value if present.
    for task in cfg.TASKS:
        for _, item, key_list in mapping_list:
            config_tmp = task
            for key in key_list:
                if key in config_tmp:
                    config_tmp = config_tmp[key]
            if item in config_tmp and _should_prefix(config_tmp[item]):
                config_tmp[item] = os.path.join(data_dir, config_tmp[item])

    # 3) Shared-target entries (key path assumed present, as originally).
    shared_mapping_list = [
        ['', 'FILE_PATH', ['SHARED_TARGETS_CFG', ]],
    ]
    for share_targets in cfg.SHARED_TARGETS:
        for _, item, key_list in shared_mapping_list:
            config_tmp = share_targets
            for key in key_list:
                config_tmp = config_tmp[key]
            if item in config_tmp and _should_prefix(config_tmp[item]):
                config_tmp[item] = os.path.join(data_dir, config_tmp[item])
| |
|
| |
|
| |
|
def add_default_setting_for_multitask_config(cfg):
    """Expand each entry of ``cfg.TASKS`` into a complete per-task config.

    Each ``cfg.TASKS`` entry initially holds only task-specific overrides.
    This replaces every entry with a deep copy of the global config merged
    with those overrides, converted to a plain dict-like object, so each
    task carries a full, self-contained configuration.  Mutates *cfg* in
    place.
    """
    tasks_config_temp = cfg.TASKS
    num_tasks = len(tasks_config_temp)
    # Remove TASKS before copying so the per-task copies do not recursively
    # contain the task list themselves.
    cfg.pop('TASKS', None)

    cfg.TASKS = [copy.deepcopy(cfg) for _ in range(num_tasks)]

    for i, task_config in enumerate(tasks_config_temp):
        # Overlay the task-specific overrides on the copied global config.
        cfg.TASKS[i].merge_from_other_cfg(CfgNode(task_config))
        cfg.TASKS[i] = cfg.TASKS[i].to_dict_object()
| |
|
| |
|
def setup(args):
    """
    Create configs and perform basic setups.

    Order matters here: model/task-specific keys are registered first,
    then file values are merged, then command-line overrides, and only
    after task expansion is the config frozen.
    """
    cfg = get_cfg()
    # Pre-load the file once to discover which extra config keys need to
    # be registered before the real merge below will accept them.
    tmp_cfg = cfg.load_from_file_tmp(args.config_file)
    add_config(cfg, tmp_cfg)

    cfg.merge_from_file(args.config_file)
    # Resolve relative data paths against $DATA_PATH (if set).
    add_data_prefix(cfg)

    # Command-line `opts` take precedence over file values.
    cfg.merge_from_list(args.opts)

    add_default_setting_for_multitask_config(cfg)
    cfg.freeze()
    default_setup(cfg, args)
    return cfg
| |
|
def _run_eval(trainer, args, data_loader, evaluator):
    """Evaluate one data split, preferring EMA weights when requested.

    Uses ``trainer.model_ema.ema`` when an EMA model exists and
    ``--eval-ema`` was passed; otherwise falls back to the master model
    (with a notice).  Returns whatever ``trainer.test`` returns, which
    may differ across distributed ranks.
    """
    if trainer.model_ema is not None and args.eval_ema:
        if comm.is_main_process():
            print('using ema model for evaluation')
        return trainer.test(trainer.cfg, trainer.model_ema.ema, data_loader, evaluator, epoch=-1)
    if args.eval_ema and comm.is_main_process():
        print('no ema model exists! using master model for evaluation')
    return trainer.test(trainer.cfg, trainer.model, data_loader, evaluator, epoch=-1)


def main(args):
    """Per-process entry point: build the engine, then train or evaluate.

    With ``--eval-only`` this evaluates the validation and/or test loaders
    (whichever exist) and returns the last result; otherwise it runs the
    full training loop.
    """
    cfg = setup(args)

    """
    If you'd like to do anything fancier than the standard training logic,
    consider writing your own training loop (see plain_train_net.py) or
    subclassing the trainer.
    """
    trainer = build_engine(cfg)
    trainer.resume_or_load(resume=args.resume)
    trainer.cast_layers()

    if args.eval_only:
        print('---------------------------')
        print('eval model only')
        print('---------------------------\n')
        res = None
        if trainer.val_data_loader is not None:
            res = _run_eval(trainer, args, trainer.val_data_loader, trainer.val_evaluator)
            if comm.is_main_process():
                print(res)

        if trainer.test_data_loader is not None:
            res = _run_eval(trainer, args, trainer.test_data_loader, trainer.test_evaluator)
            if comm.is_main_process():
                print(res)
        # Return on every rank and for every loader combination so that
        # eval-only never falls through into training.
        return res

    return trainer.train()
| |
|
def get_args_parser():
    """Build the command-line parser and return the parsed arguments.

    Starts from the project's default parser, layers on the DeepSpeed
    options when DeepSpeed is installed, adds the MoE options, then the
    launcher-specific flags used by the ``__main__`` dispatch below.
    """
    cli = default_argument_parser()
    if DEEPSPEED_INSTALLED:
        cli = deepspeed.add_config_arguments(cli)
    cli = add_moe_arguments(cli)

    cli.add_argument('--init_method', default='slurm', type=str)
    cli.add_argument('--local_rank', default=0, type=int)
    cli.add_argument("--eval-ema", action="store_true", help="perform evaluation using ema")

    return cli.parse_args()
| |
|
if __name__ == "__main__":
    args = get_args_parser()
    print("Command Line Args:", args)
    if args.init_method == 'slurm':
        # SLURM launch: derive rank/world-size from the scheduler
        # environment before entering main().
        check_dist_portfile()
        init_distributed_mode(args)
        main(args)
    elif args.init_method == 'pytorch':
        # torchrun / torch.distributed.launch already set up the
        # distributed environment; just run.
        main(args)
    else:
        # Fallback: spawn one process per GPU ourselves via mp.spawn.
        print('using \'mp.spawn\' for dist init! ')
        launch(
            main,
            args.num_gpus,
            num_machines=args.num_machines,
            machine_rank=args.machine_rank,
            dist_url=args.dist_url,
            args=(args,),
        )
| |
|