from nowcasting.config import cfg, cfg_from_file, load_latest_cfg, save_cfg
from nowcasting.utils import *
from nowcasting.utils import load_params
from nowcasting.ops import fc_layer, activation
from nowcasting.my_module import MyModule
from nowcasting.models.deconvolution_symbol import discriminator_symbol, generator_symbol
from nowcasting.hko_factory import HKONowcastingFactory

import os
import sys
import logging
import random
from collections import namedtuple
import mxnet as mx
import numpy as np


def construct_l2_loss(gt, pred, normalize_gt=False):
    """Construct the symbol of the L2 loss.

    Args:
        gt: ground-truth variable.
        pred: prediction variable (or real data during training).
        normalize_gt: if True, divide gt by 255.0.
    """
    if normalize_gt:
        gt = gt / 255.0

    if cfg.DATASET == "MOVINGMNIST":
        return mx.sym.MakeLoss(
            mx.sym.mean(mx.sym.square(gt - pred)),
            grad_scale=cfg.MODEL.L2_LAMBDA,
            name="mse")
    elif cfg.DATASET == "HKO":
        factory = HKONowcastingFactory(
            batch_size=cfg.MODEL.TRAIN.BATCH_SIZE,
            in_seq_len=cfg.HKO.BENCHMARK.IN_LEN,
            out_seq_len=cfg.HKO.BENCHMARK.OUT_LEN)
        return factory.loss_sym(pred=pred, target=gt)
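

# Illustrative sketch (never called): binds the MovingMNIST branch of
# construct_l2_loss with dummy shapes and evaluates it on random data.
# The (2, 1, 10, 64, 64) shape is a hypothetical NCDHW batch and assumes
# cfg.DATASET == "MOVINGMNIST".
def _l2_loss_sketch():
    gt = mx.sym.Variable('gt')
    pred = mx.sym.Variable('pred')
    loss_sym = construct_l2_loss(gt, pred)
    shape = (2, 1, 10, 64, 64)  # (batch, channel, seq_len, height, width)
    exe = loss_sym.simple_bind(ctx=mx.cpu(), gt=shape, pred=shape)
    outputs = exe.forward(is_train=False,
                          gt=mx.nd.uniform(shape=shape),
                          pred=mx.nd.uniform(shape=shape))
    # MakeLoss forwards the mean squared error itself; grad_scale only
    # rescales the gradient in the backward pass.
    return float(outputs[0].asnumpy())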


def construct_modules(args):
    """Construct modules for training or testing mode.

    Returns [generator_net, loss_net]. When cfg.MODEL.TESTING is set, the
    generator and discriminator symbols are built with momentum=1.
    """
    context = mx.sym.Variable('context')
    gt = mx.sym.Variable('gt')
    pred = mx.sym.Variable('pred')

    # The discriminator symbol is constructed alongside the generator but is
    # not wrapped in a module here.
    if cfg.MODEL.TESTING:
        sym_g = generator_symbol(context, momentum=1)
        sym_d = discriminator_symbol(context, pred, momentum=1)
    else:
        sym_g = generator_symbol(context)
        sym_d = discriminator_symbol(context, pred)

    sym_l2_loss = construct_l2_loss(gt, pred)

    modules = []
    module_names = []

    generator_net = MyModule(
        sym_g, data_names=('context', ), label_names=None, context=args.ctx)
    modules.append(generator_net)
    module_names.append("generator")

    loss_data_names = ['gt', 'pred']
    if cfg.DATASET == "HKO":
        loss_data_names.append('mask')

    # Besides the scaled L2 loss, the loss module reports the unscaled MSE of
    # the clipped prediction ("real_mse") for monitoring; BlockGrad keeps it
    # out of the backward pass.
    loss_net = MyModule(
        mx.sym.Group([
            sym_l2_loss,
            mx.sym.BlockGrad(
                mx.sym.mean(
                    mx.sym.square(mx.sym.clip(pred, a_min=0, a_max=1) - gt)),
                name="real_mse")
        ]),
        data_names=loss_data_names,
        label_names=None,
        context=args.ctx)
    modules.append(loss_net)
    module_names.append("loss")

    if cfg.DATASET == "MOVINGMNIST":
        IN_LEN = cfg.MOVINGMNIST.IN_LEN
        OUT_LEN = cfg.MOVINGMNIST.OUT_LEN
        IMG_SIZE = cfg.MOVINGMNIST.IMG_SIZE
    elif cfg.DATASET == "HKO":
        IN_LEN = cfg.HKO.BENCHMARK.IN_LEN
        OUT_LEN = cfg.HKO.BENCHMARK.OUT_LEN
        IMG_SIZE = cfg.HKO.ITERATOR.WIDTH

    # NCDHW layout: (batch, channel, sequence length, height, width).
    data_shapes = {
        'context':
        mx.io.DataDesc(
            name='context',
            shape=(cfg.MODEL.TRAIN.BATCH_SIZE, 1, IN_LEN, IMG_SIZE, IMG_SIZE),
            layout="NCDHW"),
        'gt':
        mx.io.DataDesc(
            name='gt',
            shape=(cfg.MODEL.TRAIN.BATCH_SIZE, 1, OUT_LEN, IMG_SIZE, IMG_SIZE),
            layout="NCDHW"),
        'pred':
        mx.io.DataDesc(
            name='pred',
            shape=(cfg.MODEL.TRAIN.BATCH_SIZE, 1, OUT_LEN, IMG_SIZE, IMG_SIZE),
            layout="NCDHW")
    }

    if cfg.DATASET == "HKO":
        data_shapes["mask"] = mx.io.DataDesc(
            name='mask',
            shape=(cfg.MODEL.TRAIN.BATCH_SIZE, 1, OUT_LEN, IMG_SIZE, IMG_SIZE),
            layout="NCDHW")

    label_shapes = {
        'label':
        mx.io.DataDesc(name='label', shape=(cfg.MODEL.TRAIN.BATCH_SIZE, 1))
    }

    init = mx.init.Xavier(rnd_type="gaussian", magnitude=1)

    for m, name in zip(modules, module_names):
        ds = [data_shapes[data_name] for data_name in m.data_names]
        ls = [label_shapes[label_name] for label_name in m.label_names]
        if len(ls) == 0:
            ls = None

        m.bind(data_shapes=ds, label_shapes=ls, inputs_need_grad=True)

        # When resuming, reload saved parameters for the generator; other
        # modules are freshly initialized.
        if not cfg.MODEL.RESUME or name not in ["generator", "gan"]:
            m.init_params(initializer=init)
        else:
            logging.info("Loading parameters of {} from {}, Iter = {}".format(
                name, os.path.realpath(cfg.MODEL.LOAD_DIR),
                cfg.MODEL.LOAD_ITER))
            arg_params, aux_params = load_params(
                prefix=os.path.join(cfg.MODEL.LOAD_DIR, name),
                epoch=cfg.MODEL.LOAD_ITER)
            m.init_params(
                arg_params=arg_params,
                aux_params=aux_params,
                allow_missing=False,
                force_init=True)
            logging.info("Loading complete!")

        lr_scheduler = mx.lr_scheduler.FactorScheduler(
            step=cfg.MODEL.TRAIN.LR_DECAY_ITER,
            factor=cfg.MODEL.TRAIN.LR_DECAY_FACTOR,
            stop_factor_lr=cfg.MODEL.TRAIN.MIN_LR)

        finetune = cfg.MODEL.TESTING and cfg.MODEL.TEST.FINETUNE
        if finetune:
            optimizer_name = cfg.MODEL.TEST.ONLINE.OPTIMIZER
        else:
            optimizer_name = cfg.MODEL.TRAIN.OPTIMIZER

        # During online fine-tuning, the ONLINE hyper-parameters override the
        # TRAIN ones and the learning-rate schedule is disabled.
        optimizer_params = {
            'learning_rate':
            cfg.MODEL.TEST.ONLINE.LR if finetune else cfg.MODEL.TRAIN.LR,
            'rescale_grad': 1.0,
            'lr_scheduler': None if finetune else lr_scheduler,
            'wd': cfg.MODEL.TEST.ONLINE.WD if finetune else cfg.MODEL.TRAIN.WD
        }
        if optimizer_name == "adam":
            optimizer_params['epsilon'] = (
                cfg.MODEL.TEST.ONLINE.EPS if finetune else cfg.MODEL.TRAIN.EPS)
            optimizer_params['beta1'] = (
                cfg.MODEL.TEST.ONLINE.BETA1
                if finetune else cfg.MODEL.TRAIN.BETA1)
        elif optimizer_name == "rmsprop":
            optimizer_params['epsilon'] = (
                cfg.MODEL.TEST.ONLINE.EPS if finetune else cfg.MODEL.TRAIN.EPS)
            optimizer_params['gamma1'] = (
                cfg.MODEL.TEST.ONLINE.GAMMA1
                if finetune else cfg.MODEL.TRAIN.GAMMA1)
        elif optimizer_name == "adagrad":
            pass
        else:
            raise NotImplementedError(
                "Unsupported optimizer: {}".format(optimizer_name))
        m.init_optimizer(
            optimizer=optimizer_name, optimizer_params=optimizer_params)

        m.summary()

    return modules
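

# Sketch (never called): how the return value of construct_modules is
# typically consumed. `args` is assumed to have been produced by the argument
# helpers below, with a valid configuration already loaded into cfg.
def _construct_modules_sketch(args):
    generator_net, loss_net = construct_modules(args)
    # generator_net consumes 'context'; loss_net consumes 'gt' and 'pred'
    # (plus 'mask' for HKO) and outputs the scaled L2 loss and the
    # monitoring MSE ("real_mse").
    logging.info("generator outputs: %s", generator_net.output_names)
    logging.info("loss outputs: %s", loss_net.output_names)
    return generator_net, loss_net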


def mode_args(parser):
    group = parser.add_argument_group('Mode',
                                      'Run in training or testing mode.')
    group.add_argument(
        '--test',
        help='Run testing code. Implies --resume.',
        action='store_true')
    group.add_argument(
        '--cfg',
        dest='cfg_file',
        help='Optional configuration file. '
        'Command-line options override defaults set in this configuration file.',
        type=str)
    group.add_argument('--save_dir', help='The saving directory.', type=str)
    group.add_argument(
        '--resume',
        help='Continue to train the previous model. This is implied by --test.',
        action='store_true',
        default=False)
    group.add_argument(
        '--load_dir',
        help='Load model parameters from load_dir to continue training the previous model. '
        'Only honoured if --resume is specified.',
        type=str)
    group.add_argument(
        '--load_iter',
        help='Load model parameters from the specified iteration.',
        type=int)
    group.add_argument(
        '--saving_postfix',
        help='The postfix of the saving directory.',
        type=str)
    group.add_argument(
        '--ctx',
        dest='ctx',
        help='Running context, e.g. `--ctx gpu`, `--ctx gpu0,gpu1` or `--ctx cpu`.',
        type=str,
        default='gpu')


def parse_mode_args(args):
    args.ctx = parse_ctx(args.ctx)
    if args.cfg_file:
        cfg_from_file(args.cfg_file, target=cfg)

    if args.test or cfg.MODEL.TESTING:
        cfg.MODEL.TESTING = True
        args.resume = True
    if args.resume:
        cfg.MODEL.RESUME = True
    if args.load_dir:
        cfg.MODEL.LOAD_DIR = args.load_dir
    if args.load_iter:
        cfg.MODEL.LOAD_ITER = args.load_iter


def training_args(parser):
    group = parser.add_argument_group(
        'Training', 'Configure the training/testing process.')
    group.add_argument(
        '--seed',
        help="Initialize the mxnet and numpy random state with this seed.",
        type=int)
    group.add_argument(
        '--batch_size',
        dest='batch_size',
        help="Batch size of the training process.",
        type=int)
    group.add_argument('--lr', dest='lr', help='learning rate', type=float)
    group.add_argument('--wd', dest='wd', help='weight decay', type=float)
    group.add_argument(
        '--grad_clip',
        dest='grad_clip',
        help='gradient clipping threshold',
        type=float)
    group.add_argument(
        '--optimizer', dest='optimizer', help='optimizer to use', type=str)
    group.add_argument(
        '--l2_lambda',
        dest='l2_lambda',
        help="Weight λ_l2 in the total loss GAN_loss * λ_gan + L2_loss * λ_l2.",
        type=float)
    group.add_argument(
        '--gan_lambda',
        dest='gan_lambda',
        help="Weight λ_gan in the total loss GAN_loss * λ_gan + L2_loss * λ_l2.",
        type=float)
    group.add_argument(
        '--original_gan_loss',
        dest='use_original_gan_loss',
        help="Use the original GAN loss formulation.",
        action="store_true")
    group.add_argument(
        '--label_smoothing_alpha',
        dest='label_smoothing_alpha',
        help="Change the one-sided label smoothing α.",
        type=float)
    group.add_argument(
        '--label_smoothing_beta',
        dest='label_smoothing_beta',
        help="Change the two-sided label smoothing β.",
        type=float)


def parse_training_args(args):
    if args.batch_size:
        cfg.MODEL.TRAIN.BATCH_SIZE = args.batch_size
    if args.lr:
        cfg.MODEL.TRAIN.LR = args.lr
    if args.wd:
        cfg.MODEL.TRAIN.WD = args.wd
    if args.grad_clip:
        cfg.MODEL.TRAIN.GRAD_CLIP = args.grad_clip
    if args.optimizer:
        cfg.MODEL.TRAIN.OPTIMIZER = args.optimizer
    if args.l2_lambda:
        cfg.MODEL.L2_LAMBDA = args.l2_lambda
    if args.seed:
        cfg.SEED = args.seed

    if cfg.SEED:
        logging.info("Fixing random seed to {}".format(cfg.SEED))
        random.seed(cfg.SEED)
        mx.random.seed(cfg.SEED)
        np.random.seed(cfg.SEED)


def model_args(parser):
    group = parser.add_argument_group('Model',
                                      'Configure the model architecture.')
    group.add_argument(
        '--use_2d',
        dest='use_2d',
        help="Use 2D convolutions / deconvolutions with the same number of parameters as the 3D model.",
        action="store_true")
    group.add_argument(
        '--encoder',
        dest='encoder',
        help="'concat', 'shared' or 'separate'. The way to encode context frames."
    )
    group.add_argument(
        '--no_bn',
        dest='bn',
        help="Disable batch norm everywhere.",
        action="store_false")
    group.add_argument(
        '--num_filter',
        dest='num_filter',
        help="Set the base number of filters.",
        type=int)


def parse_model_args(args):
    if args.use_2d:
        cfg.MODEL.DECONVBASELINE.USE_3D = not args.use_2d
    if args.encoder:
        assert args.encoder in ["concat", "shared", "separate"]
        cfg.MODEL.DECONVBASELINE.ENCODER = args.encoder
    # --no_bn stores False into args.bn (whose default is True), so only
    # override the config when the flag was actually given.
    if not args.bn:
        cfg.MODEL.DECONVBASELINE.BN = False
    if args.num_filter:
        cfg.MODEL.DECONVBASELINE.BASE_NUM_FILTER = args.num_filter
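

# Sketch (never called) of how the argument helpers above are typically
# wired together. It is illustrative only: `parse_args([])` parses an empty
# argument list, and a valid configuration is assumed to be loaded into cfg.
def _argument_pipeline_sketch():
    import argparse
    parser = argparse.ArgumentParser(description='Deconvolution baseline')
    mode_args(parser)
    training_args(parser)
    model_args(parser)
    args = parser.parse_args([])
    parse_mode_args(args)
    parse_training_args(args)
    parse_model_args(args)
    return args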


def get_base_dir(args):
    if args.save_dir:
        return args.save_dir
    return "conv3d" if cfg.MODEL.DECONVBASELINE.USE_3D else "conv2d"


def train_step(generator_net, loss_net, context_nd, gt_nd, mask_nd=None):
    """Fine-tune the encoder and forecaster for one step.

    Args:
        generator_net: the generator module.
        loss_net: the loss module.
        context_nd: NDArray of input context frames, layout NCDHW.
        gt_nd: NDArray of ground-truth future frames, layout NCDHW.
        mask_nd: optional NDArray mask (required for the HKO dataset).

    Returns:
        (forecast_target, avg_l2, avg_real_mse, generator_grad_norm)
    """
    # Forward the generator to obtain the prediction.
    generator_net.forward(
        is_train=True, data_batch=mx.io.DataBatch(data=[context_nd]))
    generator_outputs = dict(
        zip(generator_net.output_names, generator_net.get_outputs()))
    pred_nd = generator_outputs["pred_output"]

    # Forward-backward the loss module and read the gradient w.r.t. pred.
    loss_net.forward_backward(data_batch=mx.io.DataBatch(
        data=[gt_nd, pred_nd]
        if mask_nd is None else [gt_nd, pred_nd, mask_nd]))
    loss_input_grads = dict(
        zip(loss_net.data_names, loss_net.get_input_grads()))
    pred_grad = loss_input_grads["pred"]
    loss_out = dict(zip(loss_net.output_names, loss_net.get_outputs()))
    avg_l2 = float(loss_out["mse_output"].asnumpy())
    avg_real_mse = float(loss_out["real_mse_output"].asnumpy())

    # Backpropagate the loss gradient through the generator and update.
    generator_net.backward(out_grads=[pred_grad])
    generator_grad_norm = generator_net.clip_by_global_norm(
        max_norm=cfg.MODEL.TRAIN.GRAD_CLIP)
    generator_net.update()

    return (generator_outputs["forecast_target_output"], avg_l2, avg_real_mse,
            generator_grad_norm)
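

# Illustrative sketch (never called): one training step on random
# MovingMNIST-like data. The (4, 1, 10, 64, 64) shapes are hypothetical and
# must match whatever the modules were bound with in construct_modules.
def _train_step_sketch(generator_net, loss_net):
    context_nd = mx.nd.uniform(shape=(4, 1, 10, 64, 64))  # NCDHW
    gt_nd = mx.nd.uniform(shape=(4, 1, 10, 64, 64))
    _, avg_l2, avg_real_mse, grad_norm = train_step(
        generator_net=generator_net,
        loss_net=loss_net,
        context_nd=context_nd,
        gt_nd=gt_nd)
    logging.info("l2=%g, real_mse=%g, grad_norm=%g", avg_l2, avg_real_mse,
                 grad_norm)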


def test_step(generator_net, context_nd):
    """Returns generated frames.

    Returns:
        If cfg.MOVINGMNIST.OUT_LEN == 1, a numpy array of shape
        (cfg.MODEL.TRAIN.BATCH_SIZE, cfg.MOVINGMNIST.TESTING_LEN, 1,
         cfg.MOVINGMNIST.IMG_SIZE, cfg.MOVINGMNIST.IMG_SIZE);
        otherwise the raw prediction NDArray.
    """
    if cfg.DATASET != "MOVINGMNIST":
        raise NotImplementedError

    if cfg.MOVINGMNIST.OUT_LEN == 1:
        # Predict one frame at a time and roll it back into the context
        # window to generate TESTING_LEN frames autoregressively.
        frames = np.empty(
            shape=(cfg.MOVINGMNIST.TESTING_LEN, cfg.MODEL.TRAIN.BATCH_SIZE, 1,
                   cfg.MOVINGMNIST.IMG_SIZE, cfg.MOVINGMNIST.IMG_SIZE))

        for frame_num in range(cfg.MOVINGMNIST.TESTING_LEN):
            generator_net.forward(
                data_batch=mx.io.DataBatch(data=[context_nd]), is_train=False)
            generator_outputs = dict(
                zip(generator_net.output_names, generator_net.get_outputs()))
            pred_nd = generator_outputs["pred_output"]
            pred_np = pred_nd.asnumpy()

            # Slide the context window: drop the oldest frame and append the
            # newly predicted one along the time axis.
            context_np = context_nd.asnumpy()
            context_np = np.roll(a=context_np, shift=-1, axis=2)
            context_np[:, :, -1] = pred_np[:, :, -1]
            context_nd = mx.nd.array(context_np)

            frames[frame_num] = pred_np[:, :, -1]

        return np.moveaxis(frames, 0, 1)
    else:
        generator_net.forward(
            data_batch=mx.io.DataBatch(data=[context_nd]), is_train=False)
        generator_outputs = dict(
            zip(generator_net.output_names, generator_net.get_outputs()))
        pred_nd = generator_outputs["pred_output"]
        return pred_nd
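

# Toy demonstration (never called) of the context update used in test_step:
# rolling the time axis by -1 and overwriting the last slot drops the oldest
# frame and appends the newest prediction, like a fixed-length queue.
def _sliding_window_sketch():
    context = np.arange(5, dtype=np.float32).reshape(1, 1, 5, 1, 1)
    pred = np.full((1, 1, 1, 1, 1), 99.0, dtype=np.float32)
    context = np.roll(a=context, shift=-1, axis=2)  # [0,1,2,3,4] -> [1,2,3,4,0]
    context[:, :, -1] = pred[:, :, -1]              # -> [1,2,3,4,99]
    assert context[0, 0, :, 0, 0].tolist() == [1.0, 2.0, 3.0, 4.0, 99.0]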