import torch
import numpy as np
#from mpi4py import MPI
import socket
import argparse
import os
import json
import subprocess
from hps import Hyperparams, parse_args_and_update_hparams, add_vae_arguments
from utils import (logger,
                   local_mpi_rank,
                   mpi_size,
                   maybe_download,
                   mpi_rank)
from data import mkdir_p
from contextlib import contextmanager
import torch.distributed as dist
#from apex.optimizers import FusedAdam as AdamW
from vae import VAE
from torch.nn.parallel.distributed import DistributedDataParallel
# setup_mpi and setup_save_dirs are assumed to live in train_helpers (as in the
# original vdvae codebase); they are called in set_up_hyperparams below but were
# not imported here.
from train_helpers import restore_params, setup_mpi, setup_save_dirs

def set_up_hyperparams(s=None):
    """Parse CLI/VAE arguments into H, set up MPI state, save dirs, logging, and RNG seeds."""
    H = Hyperparams()
    parser = argparse.ArgumentParser()
    parser = add_vae_arguments(parser)
    parse_args_and_update_hparams(H, parser, s=s)
    setup_mpi(H)
    setup_save_dirs(H)
    logprint = logger(H.logdir)
    for k in sorted(H):
        logprint(type='hparam', key=k, value=H[k])
    np.random.seed(H.seed)
    torch.manual_seed(H.seed)
    torch.cuda.manual_seed(H.seed)
    logprint('training model', H.desc, 'on', H.dataset)
    return H, logprint

def set_up_data(H):
    """Set dataset-specific hyperparameters on H and return (H, preprocess_func)."""
    # Shift/scale that map [0, 255] pixels to [-1, 1] for the loss target.
    shift_loss = -127.5
    scale_loss = 1. / 127.5

    #trX, vaX, teX = imagenet64(H.data_root)
    H.image_size = 64
    H.image_channels = 3
    # Hard-coded ImageNet-64 pixel statistics (negative mean, reciprocal std)
    # used to normalize the encoder input.
    shift = -115.92961967
    scale = 1. / 69.37404
    
    #if H.test_eval:
    #    print('DOING TEST')
    #    eval_dataset = teX
    #else:
    #    eval_dataset = vaX

    shift = torch.tensor([shift]).cuda().view(1, 1, 1, 1)
    scale = torch.tensor([scale]).cuda().view(1, 1, 1, 1)
    shift_loss = torch.tensor([shift_loss]).cuda().view(1, 1, 1, 1)
    scale_loss = torch.tensor([scale_loss]).cuda().view(1, 1, 1, 1)
    
    #train_data = TensorDataset(torch.as_tensor(trX))
    #valid_data = TensorDataset(torch.as_tensor(eval_dataset))
    #untranspose = False

    def preprocess_func(x):
        """Takes in a data example and returns the preprocessed input
        as well as the input processed for the loss."""
        nonlocal shift
        nonlocal scale
        nonlocal shift_loss
        nonlocal scale_loss
        #untranspose = False
        #if untranspose:
        #    x[0] = x[0].permute(0, 2, 3, 1)
        inp = x.cuda(non_blocking=True).float()
        out = inp.clone()
        inp.add_(shift).mul_(scale)              # normalize with dataset statistics for the encoder
        out.add_(shift_loss).mul_(scale_loss)    # rescale to [-1, 1] for the reconstruction loss
        return inp, out
    
    return H, preprocess_func
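
# Sketch (illustrative, not part of the original file): how preprocess_func is
# meant to be consumed, following the commented-out TensorDataset lines above.
# The DataLoader settings, H.n_batch, and the N x 64 x 64 x 3 uint8 layout are
# assumptions here, not guarantees from this file.
#
#   from torch.utils.data import DataLoader, TensorDataset
#   train_data = TensorDataset(torch.as_tensor(trX))   # trX: uint8 images, N x 64 x 64 x 3
#   loader = DataLoader(train_data, batch_size=H.n_batch, pin_memory=True, shuffle=True)
#   for (x,) in loader:
#       inp, target = preprocess_func(x)   # inp feeds the encoder, target feeds the loss
#       ...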

def load_vaes(H, logprint=None):
    """Build the EMA VAE, optionally restore its weights, and move it to the local GPU."""
    log = logprint if logprint is not None else print
    ema_vae = VAE(H)
    if H.restore_ema_path:
        log(f'Restoring ema vae from {H.restore_ema_path}')
        restore_params(ema_vae, H.restore_ema_path, map_cpu=True, local_rank=H.local_rank, mpi_size=H.mpi_size)
    else:
        # No checkpoint given: keep the freshly initialized weights. The previous
        # code copied from an undefined `vae` (left over from the stripped-out
        # non-EMA model), which raised a NameError.
        log('No restore_ema_path given; using randomly initialized EMA VAE weights')
    ema_vae.requires_grad_(False)    # inference only: freeze all parameters
    ema_vae = ema_vae.cuda(H.local_rank)

    #vae = DistributedDataParallel(vae, device_ids=[H.local_rank], output_device=H.local_rank)

    #if len(list(vae.named_parameters())) != len(list(vae.parameters())):
    #    raise ValueError('Some params are not named. Please name all params.')
    #total_params = 0
    #for name, p in vae.named_parameters():
    #    total_params += np.prod(p.shape)

    #print(total_params=total_params, readable=f'{total_params:,}')
    return ema_vae
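
# ---------------------------------------------------------------------------
# Minimal end-to-end sketch (illustrative, not part of the original file):
# wiring the helpers above together to load a pretrained EMA VAE and run one
# forward pass on a dummy batch. Assumptions: --restore_ema_path is one of the
# flags registered by add_vae_arguments, images are NHWC uint8 as in the vdvae
# data pipeline, and VAE.forward(input, target) returns a dict of scalar loss
# statistics as in the original vdvae code.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    H, logprint = set_up_hyperparams()     # parse CLI flags, set seeds, set up logging
    H, preprocess_func = set_up_data(H)    # fix ImageNet-64 sizes/statistics, get preprocessing
    ema_vae = load_vaes(H, logprint)       # frozen EMA VAE on the local GPU

    # Dummy uint8 batch standing in for real ImageNet-64 data.
    x = torch.randint(0, 256, (4, H.image_size, H.image_size, H.image_channels), dtype=torch.uint8)
    inp, target = preprocess_func(x)
    with torch.no_grad():
        stats = ema_vae(inp, target)
    print({k: float(v) for k, v in stats.items()})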