| | import argparse |
| | import copy |
| |
|
| | import warnings |
| | import tensorflow as tf |
| | tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) |
| | import warnings |
| | warnings.filterwarnings('ignore', category=FutureWarning) |
| | warnings.filterwarnings('ignore', category=DeprecationWarning) |
| | import sys, getopt, os |
| |
|
| | import numpy as np |
| | import dnnlib |
| | from dnnlib import EasyDict |
| | import dnnlib.tflib as tflib |
| | from dnnlib.tflib import tfutil |
| | from dnnlib.tflib.autosummary import autosummary |
| |
|
| | from training import misc |
| | import pickle |
| | import argparse |
| |
|
def create_model(config_id='config-f', gamma=None, height=512, width=512, cond=None, label_size=0):
    """Configure a StyleGAN2 G/D pair and submit a diagnostic run for it.

    The run executes ``training.diagnostic.create_initial_pkl`` through
    ``dnnlib.submit_diagnostic``. Judging by the name, it presumably writes
    the freshly constructed networks to a pickle (TODO confirm against
    ``training.diagnostic``).

    Preset handling follows the ``config-a`` ... ``config-f`` scheme:

    * every preset other than ``config-f`` sets ``fmap_base`` to ``8 << 10``;
    * ``config-e*`` presets may pick architectures through tags in the name
      (``Gorig``/``Gskip``/``Gresnet``/``Dorig``/``Dskip``/``Dresnet``);
    * ``config-a``..``config-d`` use ``G_synthesis_stylegan_revised`` and
      ``D_stylegan``;
    * ``config-a`` replaces G/D with the original ``networks_stylegan``
      models.

    Args:
        config_id: Preset name, e.g. ``'config-f'`` or
            ``'config-e-Gskip-Dresnet'``.
        gamma: Optional override for ``D_loss.gamma`` (see the NOTE below).
        height: Forwarded to G, D and the run as ``resolution_h``.
        width: Forwarded to G, D and the run as ``resolution_w``.
        cond: If truthy, appends ``-cond`` to the run description.
        label_size: Forwarded to the run as ``label_size``.

    Returns:
        The string ``'network-initial-config-f-{height}x{width}-{label_size}.pkl'``.
    """
    train = EasyDict(run_func_name='training.diagnostic.create_initial_pkl')
    G = EasyDict(func_name='training.networks_stylegan2.G_main')
    D = EasyDict(func_name='training.networks_stylegan2.D_stylegan2')
    D_loss = EasyDict(func_name='training.loss.D_logistic_r1')
    sched = EasyDict()
    sc = dnnlib.SubmitConfig()
    tf_config = {'rnd.np_random_seed': 1000}

    # NOTE(review): `sched`, `D_loss` (including the `gamma` override below) and
    # `dataset_args` are configured here but never added to `kwargs`, so they do
    # not reach the diagnostic run. Confirm whether they should be forwarded.
    sched.minibatch_size_base = 192
    sched.minibatch_gpu_base = 3
    D_loss.gamma = 10
    desc = 'stylegan2'

    dataset_args = EasyDict()
    if cond:
        desc += '-cond'
        dataset_args.max_label_size = 'full'

    desc += '-' + config_id

    # Presets other than config-f override the feature-map base.
    if config_id != 'config-f':
        G.fmap_base = D.fmap_base = 8 << 10

    # config-e*: raise gamma and let tags in the preset name select architectures.
    if config_id.startswith('config-e'):
        D_loss.gamma = 100
        if 'Gorig' in config_id:
            G.architecture = 'orig'
        if 'Gskip' in config_id:
            G.architecture = 'skip'
        if 'Gresnet' in config_id:
            G.architecture = 'resnet'
        if 'Dorig' in config_id:
            D.architecture = 'orig'
        if 'Dskip' in config_id:
            D.architecture = 'skip'
        if 'Dresnet' in config_id:
            D.architecture = 'resnet'

    # config-a..d: progressive-growing style schedule and revised synthesis network.
    if config_id in ['config-a', 'config-b', 'config-c', 'config-d']:
        sched.lod_initial_resolution = 8
        sched.G_lrate_base = sched.D_lrate_base = 0.001
        sched.G_lrate_dict = sched.D_lrate_dict = {128: 0.0015, 256: 0.002, 512: 0.003, 1024: 0.003}
        sched.minibatch_size_base = 32
        sched.minibatch_size_dict = {8: 256, 16: 128, 32: 64, 64: 32}
        sched.minibatch_gpu_base = 4
        sched.minibatch_gpu_dict = {8: 32, 16: 16, 32: 8, 64: 4}
        G.synthesis_func = 'G_synthesis_stylegan_revised'
        D.func_name = 'training.networks_stylegan2.D_stylegan'

    # (An unused `G_loss` assignment for config-a..c was removed: it was never read.)

    if config_id in ['config-a', 'config-b']:
        train.lazy_regularization = False

    # config-a: replace G and D entirely. This must come after the overrides above.
    if config_id == 'config-a':
        G = EasyDict(func_name='training.networks_stylegan.G_style')
        D = EasyDict(func_name='training.networks_stylegan.D_basic')

    if gamma is not None:
        D_loss.gamma = gamma

    G.update(resolution_h=height)
    G.update(resolution_w=width)
    D.update(resolution_h=height)
    D.update(resolution_w=width)

    sc.submit_target = dnnlib.SubmitTarget.DIAGNOSTIC
    sc.local.do_not_copy_source_files = True
    kwargs = EasyDict(train)
    kwargs.update(G_args=G, D_args=D, tf_config=tf_config, config_id=config_id,
                  resolution_h=height, resolution_w=width, label_size=label_size)
    kwargs.submit_config = copy.deepcopy(sc)
    kwargs.submit_config.run_desc = desc
    dnnlib.submit_diagnostic(**kwargs)
    # NOTE(review): the name always says 'config-f' even when config_id differs;
    # confirm it matches the file create_initial_pkl actually writes.
    return f'network-initial-config-f-{height}x{width}-{label_size}.pkl'
| |
|
def _str_to_bool(v):
    """Interpret a command-line flag value as a boolean.

    Actual ``bool`` values pass through untouched. Strings are compared
    case-insensitively against the accepted spellings:

    * ``yes``/``true``/``t``/``y``/``1`` map to ``True``;
    * ``no``/``false``/``f``/``n``/``0`` map to ``False``.

    Raises:
        argparse.ArgumentTypeError: if the string matches neither group.
    """
    if isinstance(v, bool):
        return v
    normalized = v.lower()
    truthy = ('yes', 'true', 't', 'y', '1')
    falsy = ('no', 'false', 'f', 'n', '0')
    if normalized in truthy:
        return True
    if normalized in falsy:
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')
| |
|
def _parse_comma_sep(s):
    """Turn a comma-separated CLI string into a list of its pieces.

    ``None``, the empty string and ``'none'`` in any letter case all produce
    an empty list. Pieces are returned exactly as split, without stripping
    whitespace or dropping empty entries.
    """
    is_blank = s is None or s.lower() == 'none' or s == ''
    return [] if is_blank else s.split(',')
| |
|
def copy_weights(source_pkl, target_pkl, output_pkl):
    """Transfer compatible trainable weights between two network pickles.

    Both pickles are expected to hold a ``(G, D, Gs)`` tuple. For each
    network, ``copy_compatible_trainables_from`` pulls weights from the source
    network into the target network. The updated target tuple is then pickled
    to ``output_pkl``.

    Args:
        source_pkl: Path of the pickle that supplies the weights.
        target_pkl: Path of the pickle whose networks receive the weights.
        output_pkl: Output file name, joined onto ``'./'`` (an absolute path
            is used as-is by ``os.path.join``).

    Security:
        ``pickle.load`` can execute arbitrary code. Only pass trusted files.
    """
    tflib.init_tf()

    with tf.Session():
        with tf.device('/gpu:0'):
            # Use context managers so the input file handles are closed promptly
            # (the previous `pickle.load(open(...))` leaked them).
            with open(source_pkl, 'rb') as src_file:
                sourceG, sourceD, sourceGs = pickle.load(src_file)
            with open(target_pkl, 'rb') as tgt_file:
                targetG, targetD, targetGs = pickle.load(tgt_file)

            targetG.copy_compatible_trainables_from(sourceG)
            targetD.copy_compatible_trainables_from(sourceD)
            targetGs.copy_compatible_trainables_from(sourceGs)

            # Keep the dump inside the session block. Presumably pickling a
            # tflib Network reads variable values from the default session,
            # which is the one the weights were just copied in (TODO confirm).
            with open(os.path.join('./', output_pkl), 'wb') as file:
                pickle.dump((targetG, targetD, targetGs), file, protocol=pickle.HIGHEST_PROTOCOL)