from mpi4py import MPI  # required by mpi_size() / mpi_rank() below
import os
import json
import tempfile
import numpy as np
import torch
import time
import subprocess
import torch.distributed as dist
def allreduce(x, average):
    # Sum a tensor across all ranks; optionally divide by world size to average.
    if mpi_size() > 1:
        dist.all_reduce(x, dist.ReduceOp.SUM)
    return x / mpi_size() if average else x


def get_cpu_stats_over_ranks(stat_dict):
    # Average a dict of scalar stats over all ranks and return plain Python floats on the CPU.
    keys = sorted(stat_dict.keys())
    allreduced = allreduce(torch.stack([torch.as_tensor(stat_dict[k]).detach().cuda().float() for k in keys]), average=True).cpu()
    return {k: allreduced[i].item() for (i, k) in enumerate(keys)}
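
# Illustrative usage sketch, not part of the original file: under an initialized
# torch.distributed process group with one CUDA device per rank, per-rank scalar
# stats can be averaged across workers. The keys ('elbo', 'recon') are made up.
#
#   stats = {'elbo': 3.21, 'recon': 1.07}
#   avg_stats = get_cpu_stats_over_ranks(stats)   # same keys, values averaged over ranks
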
class Hyperparams(dict):
    # A dict whose entries can also be read and written as attributes;
    # missing keys read as None instead of raising.
    def __getattr__(self, attr):
        try:
            return self[attr]
        except KeyError:
            return None

    def __setattr__(self, attr, value):
        self[attr] = value
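
# Illustrative usage sketch, not part of the original file; the field names are made up:
#
#   H = Hyperparams(lr=3e-4)
#   H.width = 384            # attribute writes go into the dict
#   H['lr'], H.width         # -> 0.0003, 384
#   H.not_set_yet            # -> None rather than raising AttributeError
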
def logger(log_prefix):
    'Prints the arguments out to stdout, .txt, and .jsonl files'
    jsonl_path = f'{log_prefix}.jsonl'
    txt_path = f'{log_prefix}.txt'

    def log(*args, pprint=False, **kwargs):
        # Only rank 0 writes logs.
        if mpi_rank() != 0:
            return
        t = time.ctime()
        argdict = {'time': t}
        if len(args) > 0:
            argdict['message'] = ' '.join([str(x) for x in args])
        argdict.update(kwargs)

        txt_str = []
        args_iter = sorted(argdict) if pprint else argdict
        for k in args_iter:
            # Convert numpy types so everything is JSON-serializable.
            val = argdict[k]
            if isinstance(val, np.ndarray):
                val = val.tolist()
            elif isinstance(val, np.integer):
                val = int(val)
            elif isinstance(val, np.floating):
                val = float(val)
            argdict[k] = val
            if isinstance(val, float):
                val = f'{val:.5f}'
            txt_str.append(f'{k}: {val}')
        txt_str = ', '.join(txt_str)

        if pprint:
            json_str = json.dumps(argdict, sort_keys=True)
            txt_str = json.dumps(argdict, sort_keys=True, indent=4)
        else:
            json_str = json.dumps(argdict)

        print(txt_str, flush=True)
        with open(txt_path, "a+") as f:
            print(txt_str, file=f, flush=True)
        with open(jsonl_path, "a+") as f:
            print(json_str, file=f, flush=True)

    return log
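
# Illustrative usage sketch, not part of the original file; the prefix and keyword names
# are made up. On rank 0 each call prints one line and appends it to train.txt and
# train.jsonl; on other ranks it is a no-op:
#
#   logprint = logger('train')
#   logprint('starting epoch', epoch=0, lr=3e-4)
#   logprint(type='eval', elbo=3.21, pprint=True)   # pretty-printed JSON with sorted keys
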
def maybe_download(path, filename=None):
    '''If a path is a gsutil path, download it and return the local path,
    otherwise return the path unchanged'''
    if not path.startswith('gs://'):
        return path
    if filename:
        out_path = f'/tmp/{filename}'
        if os.path.isfile(out_path):
            return out_path
        subprocess.check_output(['gsutil', '-m', 'cp', '-R', path, out_path])
        return out_path
    else:
        local_dest = tempfile.mkstemp()[1]
        subprocess.check_output(['gsutil', '-m', 'cp', path, local_dest])
        return local_dest
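
# Illustrative usage sketch, not part of the original file; the bucket path is made up.
# Local paths are returned unchanged, gs:// paths are first copied down with gsutil:
#
#   ckpt = maybe_download('gs://some-bucket/model.th', filename='model.th')   # -> '/tmp/model.th'
#   path = maybe_download('./data/train.npy')                                 # -> './data/train.npy'
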
def tile_images(images, d1=4, d2=4, border=1):
    # Arrange d1*d2 equally-sized HxWxC images into one grid with a white border between tiles.
    id1, id2, c = images[0].shape
    out = np.ones([d1 * id1 + border * (d1 + 1),
                   d2 * id2 + border * (d2 + 1),
                   c], dtype=np.uint8)
    out *= 255
    if len(images) != d1 * d2:
        raise ValueError('Wrong num of images')
    for imgnum, im in enumerate(images):
        num_d1 = imgnum // d2
        num_d2 = imgnum % d2
        start_d1 = num_d1 * id1 + border * (num_d1 + 1)
        start_d2 = num_d2 * id2 + border * (num_d2 + 1)
        out[start_d1:start_d1 + id1, start_d2:start_d2 + id2, :] = im
    return out
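
# Illustrative usage sketch, not part of the original file. With sixteen 32x32x3 uint8
# images and the default 4x4 grid with a 1-pixel border, the output side length is
# 4*32 + 5*1 = 133:
#
#   images = [np.zeros((32, 32, 3), dtype=np.uint8) for _ in range(16)]
#   grid = tile_images(images, d1=4, d2=4, border=1)   # grid.shape == (133, 133, 3)
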
def mpi_size():
    return MPI.COMM_WORLD.Get_size()


def mpi_rank():
    return MPI.COMM_WORLD.Get_rank()


def num_nodes():
    # Assumes at most 8 GPUs (ranks) per node.
    nn = mpi_size()
    if nn % 8 == 0:
        return nn // 8
    return nn // 8 + 1


def gpus_per_node():
    size = mpi_size()
    if size > 1:
        return max(size // num_nodes(), 1)
    return 1


def local_mpi_rank():
    return mpi_rank() % gpus_per_node()
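
# Worked example, not part of the original file, using the 8-GPUs-per-node assumption above:
# with mpi_size() == 16, num_nodes() == 2 and gpus_per_node() == 8, so global ranks 0-15
# map to local_mpi_rank() values 0-7 on each node.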