| import itertools |
| import numpy as np |
| import matplotlib.pyplot as plt |
| import torch |
| from torch.nn import functional as F |
| import distutils.util |
|
|
def _first_sample_as_image(batch):
    """Turn the first sample of ``batch`` into an array ``imshow`` can draw.

    The sample is mapped through ``x * 0.5 + 0.5``, clipped to ``[0, 1]`` and
    its first axis is moved to the end.

    NOTE(review): this presumably expects an (N, C, H, W) tensor normalised
    to [-1, 1] -- confirm against the data pipeline.
    """
    # Index before moving to the CPU so only one sample is transferred;
    # detach() keeps .numpy() from failing on tensors that require grad.
    sample = batch[0].detach().cpu().numpy()
    return np.transpose(np.clip(sample * 0.5 + 0.5, 0, 1), [1, 2, 0])


def show_result(num_epoch, G_net, imgs_lr, imgs_hr):
    """Save a three-panel figure: input, generator output and target.

    Args:
        num_epoch: Epoch number, used in the caption and the file name.
        G_net: Callable applied to ``imgs_lr`` (under ``torch.no_grad()``).
        imgs_lr: Batch fed to ``G_net``; its first sample is drawn on the left.
        imgs_hr: Batch whose first sample is drawn on the right.

    The figure is written to
    ``results/train_out/epoch_<num_epoch>_results.png``; the directory is
    created when it does not exist yet.
    """
    import os

    with torch.no_grad():
        test_images = G_net(imgs_lr)

    out_path = "results/train_out/epoch_" + str(num_epoch) + "_results.png"
    # Without this, savefig raises FileNotFoundError on a fresh checkout.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)

    fig, ax = plt.subplots(1, 3)
    try:
        for axis, batch in zip(ax, (imgs_lr, test_images, imgs_hr)):
            axis.imshow(_first_sample_as_image(batch))
            # Hide ticks and tick labels; only the images matter here.
            axis.get_xaxis().set_visible(False)
            axis.get_yaxis().set_visible(False)

        label = 'Epoch {0}'.format(num_epoch)
        fig.text(0.5, 0.04, label, ha='center')
        fig.savefig(out_path)
    finally:
        # Closes every open pyplot figure (as before), now also when
        # drawing or saving fails, so figures do not pile up across epochs.
        plt.close('all')
|
|
| |
| |
| |
| |
def cvtColor(image):
    """Return ``image`` as-is when it is already 3-D with 3 channels last.

    Anything else is passed through ``image.convert('RGB')`` and the result
    is returned (presumably a PIL image -- verify against callers).
    """
    dims = np.shape(image)
    is_three_channel = len(dims) == 3 and dims[2] == 3
    return image if is_three_channel else image.convert('RGB')
|
|
def preprocess_input(image, mean, std):
    """Scale ``image`` by 1/255, then standardise it with ``mean`` and ``std``.

    Returns ``(image / 255 - mean) / std``; the input is not modified.
    """
    scaled = image / 255
    centered = scaled - mean
    return centered / std
|
|
def get_lr(optimizer):
    """Return the ``'lr'`` entry of the optimizer's first parameter group.

    Returns ``None`` when ``optimizer.param_groups`` is empty.
    """
    rates = (group['lr'] for group in optimizer.param_groups)
    return next(rates, None)
|
|
def print_arguments(args):
    """Print every attribute of ``args`` as ``name: value``, sorted by name.

    The listing is framed by a header line and a footer line.
    """
    settings = vars(args)
    print("----------- Configuration Arguments -----------")
    for name in sorted(settings):
        print(f"{name}: {settings[name]!s}")
    print("------------------------------------------------")
|
|
|
|
def _strtobool(value):
    """Parse a command-line truth word into a ``bool``.

    Accepts, case-insensitively, ``y/yes/t/true/on/1`` as true and
    ``n/no/f/false/off/0`` as false -- the spellings that
    ``distutils.util.strtobool`` accepted.

    Raises:
        ValueError: for any other word (argparse reports it as a usage error).
    """
    word = value.lower()
    if word in ("y", "yes", "t", "true", "on", "1"):
        return True
    if word in ("n", "no", "f", "false", "off", "0"):
        return False
    raise ValueError("invalid truth value %r" % (word,))


def add_arguments(argname, type, default, help, argparser, **kwargs):
    """Register the option ``--<argname>`` on ``argparser``.

    Args:
        argname: Option name without the leading dashes.
        type: Converter for the option value; ``bool`` is replaced by a
            truth-word parser, because ``bool("false")`` would be ``True``.
        default: Default value of the option.
        help: Help text; a suffix showing the default value is appended.
        argparser: The ``argparse`` parser to extend.
        **kwargs: Forwarded unchanged to ``argparser.add_argument``.
    """
    # distutils (the former home of strtobool) was removed in Python 3.12,
    # so the boolean parsing lives in this module instead.
    converter = _strtobool if type is bool else type
    argparser.add_argument("--" + argname,
                           default=default,
                           type=converter,
                           help=help + ' 默认: %(default)s.',
                           **kwargs)
|
|
|
|