| | |
| |
|
| | import sys |
| | import os |
| | import torch |
| |
|
| | |
# Put the current working directory on the import path so the project-local
# packages imported below (opt, architecture, train_code) resolve when the
# script is launched from the repository root.
root_path = os.path.abspath(os.curdir)
sys.path.append(root_path)
| | from opt import opt |
| | from architecture.grl import GRL |
| | from architecture.discriminator import UNetDiscriminatorSN, MultiScaleDiscriminator |
| | from train_code.train_master import train_master |
| |
|
| |
|
| |
|
class train_grlgan(train_master):
    """Trainer for the GRL generator paired with an adversarial discriminator.

    Model variants are chosen from the global ``opt`` dict
    (``opt['model_size']``, ``opt['scale']``, ``opt['discriminator_type']``);
    the generic training loop is inherited from ``train_master``.
    """

    def __init__(self, options, args) -> None:
        # "grlgan" and True are forwarded to train_master; presumably the model
        # name and a "train with GAN" flag -- TODO confirm in train_master.
        super().__init__(options, args, "grlgan", True)

    def loss_init(self):
        """Load the pixel loss and the GAN-related losses via train_master helpers."""
        # Pixel-wise reconstruction loss; presumably sets ``self.cri_pix``
        # used in calculate_loss -- TODO confirm in train_master.
        self.pixel_loss_load()

        # GAN-related losses; presumably sets ``self.cri_gan`` and the
        # perceptual criteria used in calculate_loss -- TODO confirm.
        self.GAN_loss_load()

    def call_model(self):
        """Build ``self.generator`` and ``self.discriminator`` on CUDA and put both in train mode.

        Raises:
            NotImplementedError: if ``opt['model_size']`` or
                ``opt['discriminator_type']`` is not one of the supported values.
        """
        # Only these GRL arguments differ between the supported variants;
        # every other argument is shared (see the GRL(...) call below).
        size_settings = {
            "small": {"img_size": 144, "embed_dim": 128, "upsampler": "pixelshuffle"},
            "tiny": {"img_size": 64, "embed_dim": 64, "upsampler": "pixelshuffledirect"},
            "tiny2": {"img_size": 64, "embed_dim": 64, "upsampler": "nearest+conv"},
        }
        model_size = opt['model_size']
        if model_size not in size_settings:
            raise NotImplementedError("We don't support such model size in GRL model")

        self.generator = GRL(
            upscale=opt['scale'],
            window_size=8,
            depths=[4, 4, 4, 4],
            num_heads_window=[2, 2, 2, 2],
            num_heads_stripe=[2, 2, 2, 2],
            mlp_ratio=2,
            qkv_proj_type="linear",
            anchor_proj_type="avgpool",
            anchor_window_down_factor=2,
            out_proj_type="linear",
            conv_type="1conv",
            **size_settings[model_size],
        ).cuda()

        # The argument 3 is presumably the number of input image channels
        # (RGB) -- TODO confirm against the discriminator constructors.
        discriminator_type = opt['discriminator_type']
        if discriminator_type == "PatchDiscriminator":
            self.discriminator = MultiScaleDiscriminator(3).cuda()
        elif discriminator_type == "UNetDiscriminator":
            self.discriminator = UNetDiscriminatorSN(3).cuda()
        else:
            # Previously this fell through and crashed with an AttributeError on
            # self.discriminator.train(); fail loudly with the offending value.
            raise NotImplementedError(
                "We don't support such discriminator type: {}".format(discriminator_type)
            )

        self.generator.train()
        self.discriminator.train()

    def run(self):
        """Start training using the loop implemented in ``train_master.master_run``."""
        self.master_run()

    def calculate_loss(self, gen_hr, imgs_hr):
        """Add the generator's pixel, perceptual and adversarial losses to ``self.generator_loss``.

        Args:
            gen_hr: generator output (presumably the super-resolved batch).
            imgs_hr: the matching ground-truth high-resolution batch.

        Each term is also stored in ``self.weight_store`` under ``"pixel_loss"``,
        ``"perceptual_loss"`` and ``"gan_loss"`` for ``tensorboard_report``.
        ``self.generator_loss`` is assumed to be initialised by the caller
        (train_master) before this method runs -- TODO confirm.
        """
        # Pixel-wise reconstruction loss.
        l_g_pix = self.cri_pix(gen_hr, imgs_hr)
        self.generator_loss += l_g_pix
        self.weight_store["pixel_loss"] = l_g_pix

        # Perceptual loss: sum of the Danbooru-feature and VGG-feature terms.
        l_g_percep_danbooru = self.cri_danbooru_perceptual(gen_hr, imgs_hr)
        l_g_percep_vgg = self.cri_vgg_perceptual(gen_hr, imgs_hr)
        l_g_percep = l_g_percep_danbooru + l_g_percep_vgg
        self.generator_loss += l_g_percep
        self.weight_store["perceptual_loss"] = l_g_percep

        # Adversarial term for the generator: the discriminator's predictions on
        # generated images are scored against the "real" target (is_disc=False).
        fake_g_preds = self.discriminator(gen_hr)
        l_g_gan = self.cri_gan(fake_g_preds, True, is_disc=False)
        self.generator_loss += l_g_gan
        self.weight_store["gan_loss"] = l_g_gan

    def tensorboard_report(self, iteration):
        """Log the current generator loss terms to TensorBoard at step ``iteration``."""
        self.writer.add_scalar('Loss/train-Generator_Loss-Iteration', self.generator_loss, iteration)
        self.writer.add_scalar('Loss/train-Pixel_Loss-Iteration', self.weight_store["pixel_loss"], iteration)
        self.writer.add_scalar('Loss/train-Perceptual_Loss-Iteration', self.weight_store["perceptual_loss"], iteration)
        # NOTE: despite the tag name, this is the generator's adversarial loss
        # from calculate_loss (is_disc=False), not the discriminator's own loss.
        # The tag is kept unchanged so existing TensorBoard runs stay comparable.
        self.writer.add_scalar('Loss/train-Discriminator_Loss-Iteration', self.weight_store["gan_loss"], iteration)
| |
|
| |
|