| import io | |
| import IPython.display | |
| import PIL.Image | |
| import os | |
| from pprint import pformat | |
| import numpy as np | |
def imgrid(imarray, cols=4, pad=1, padval=255, row_major=True):
    """Lay out a [N, H, W, C] image array as a single image grid.

    Args:
        imarray: Array of shape ``[N, H, W, C]``.
        cols: Number of grid columns (must be >= 1). The number of rows is
            ``ceil(N / cols)``; cells left over after placing all N images are
            filled with ``padval``.
        pad: Width in pixels of the separator between neighbouring cells
            (must be >= 0). No separator is added around the outer border.
        padval: Value used for the separators and for unused cells.
        row_major: If True (default), image ``i`` is placed at row
            ``i // cols``, column ``i % cols``. If False, the grid is filled
            column by column: image ``i`` is placed at row ``i % rows``,
            column ``i // rows``.

    Returns:
        For N >= 1, an array of shape
        ``[rows*(H+pad) - pad, cols*(W+pad) - pad, C]`` with the same dtype
        as ``imarray``.

    Raises:
        ValueError: If ``pad`` is negative or ``cols`` is less than 1.
    """
    pad = int(pad)
    if pad < 0:
        raise ValueError('pad must be non-negative')
    cols = int(cols)
    if cols < 1:
        # Explicit check instead of `assert`, which is stripped under -O.
        raise ValueError('cols must be at least 1')
    N, H, W, C = imarray.shape
    rows = N // cols + int(N % cols != 0)  # ceil(N / cols)
    # Append blank images so the batch fills rows * cols cells, and `pad`
    # pixels after every image along H and W to form the separators.
    post_pad = [rows * cols - N, pad, pad, 0]
    pad_arg = [[0, p] for p in post_pad]
    imarray = np.pad(imarray, pad_arg, 'constant', constant_values=padval)
    H += pad
    W += pad
    if row_major:
        # [rows*cols, ...] -> [rows, cols, H, W, C] -> [rows, H, cols, W, C]
        cells = imarray.reshape(rows, cols, H, W, C).transpose(0, 2, 1, 3, 4)
    else:
        # [cols*rows, ...] -> [cols, rows, H, W, C] -> [rows, H, cols, W, C]
        cells = imarray.reshape(cols, rows, H, W, C).transpose(1, 2, 0, 3, 4)
    grid = cells.reshape(rows * H, cols * W, C)
    if pad:
        # Drop the separator that trails the last row and the last column.
        grid = grid[:-pad, :-pad]
    return grid
def interleave(*args):
    """Interleave equally shaped arrays along the leading (batch) axis.

    Given inputs ``a, b, ...`` of common shape ``[N, *rest]``, the result is
    ``[a[0], b[0], ..., a[1], b[1], ...]`` with shape
    ``[len(args) * N, *rest]``.

    Raises:
        ValueError: If no inputs are given, if their shapes differ, or if the
            inputs have no axes.
    """
    if not args:
        raise ValueError('At least one argument is required.')
    first_shape = args[0].shape
    for arr in args:
        if arr.shape != first_shape:
            raise ValueError('All inputs must have the same shape.')
    if len(first_shape) == 0:
        raise ValueError('Inputs must have at least one axis.')
    # Stacking on axis 1 yields [N, len(args), *rest]; merging the first two
    # axes then alternates between the inputs element by element.
    stacked = np.stack(args, axis=1)
    return stacked.reshape((-1,) + tuple(first_shape[1:]))
def imshow(a, format='png', jpeg_fallback=True):
    """Display an image array inline via IPython.

    The array is cast to ``uint8``, encoded with PIL in ``format`` and passed
    to ``IPython.display.display`` as an ``IPython.display.Image``.

    Args:
        a: Array-like image data that ``PIL.Image.fromarray`` accepts after
            the cast to ``uint8``. The cast uses ``astype`` with no clipping,
            so float inputs in [-1, 1] should go through ``image_to_uint8``
            first.
        format: PIL format name used to encode the image (default ``'png'``).
        jpeg_fallback: If True and displaying raises ``IOError``, warn and
            retry once with ``'jpeg'`` encoding (unless ``format`` is already
            jpeg).

    Returns:
        Whatever ``IPython.display.display`` returns.

    Raises:
        IOError: If displaying fails and no jpeg fallback applies.
    """
    # np.asarray first so plain array-likes (e.g. nested lists) are accepted;
    # ndarray inputs are cast exactly as before.
    a = np.asarray(a).astype(np.uint8)
    data = io.BytesIO()
    PIL.Image.fromarray(a).save(data, format)
    im_data = data.getvalue()
    try:
        disp = IPython.display.display(IPython.display.Image(im_data))
    except IOError:  # alias of OSError on Python 3
        # Case-insensitive so e.g. 'JPEG' does not trigger an identical retry.
        if jpeg_fallback and format.lower() != 'jpeg':
            # Format the message *before* printing: print() returns None, so
            # calling .format() on its result would raise AttributeError.
            print(('Warning: image was too large to display in format "{}"; '
                   'trying jpeg instead.').format(format))
            return imshow(a, format='jpeg')
        raise
    return disp
def image_to_uint8(x):
    """Map an array of values in [-1, 1] to uint8 values in [0, 255].

    The input is shifted by +1 and scaled by 128, so [-1, 1] spans [0, 256]
    and truncation to uint8 yields 256 equal-width bins. The scaled values
    are clipped to [0, 255] first, so exactly 1.0 maps to 255. Values outside
    [-1, 1] saturate at 0 or 255.
    """
    scaled = (np.asarray(x) + 1.) * (256. / 2.)
    return np.clip(scaled, 0, 255).astype(np.uint8)