| | import math |
| | import random |
| |
|
| | from PIL import Image |
| | import blobfile as bf |
| | from mpi4py import MPI |
| | import numpy as np |
| | from torch.utils.data import DataLoader, Dataset |
| |
|
| |
|
def load_data(
    *,
    data_dir,
    batch_size,
    image_size,
    class_cond=False,
    deterministic=False,
    random_crop=False,
    random_flip=True,
):
    """
    For a dataset, create an infinite generator over (images, kwargs) pairs.

    Each image batch is an NCHW float tensor, and the kwargs dict contains zero
    or more keys, each of which maps to a batched Tensor of its own.
    The kwargs dict can be used for class labels, in which case the key is "y"
    and the values are integer tensors of class labels.

    Images are split across MPI ranks: each process only sees its own shard.

    :param data_dir: a dataset directory, searched recursively for images.
    :param batch_size: the batch size of each returned pair. Incomplete final
                       batches are dropped.
    :param image_size: the side length of the square crops that images are
                       resized and cropped to.
    :param class_cond: if True, include a "y" key in returned dicts for class
                       label. The label is derived from the part of each file's
                       basename before the first underscore.
    :param deterministic: if True, yield results in a deterministic order.
    :param random_crop: if True, randomly crop the images for augmentation.
    :param random_flip: if True, randomly flip the images for augmentation.
    :raises ValueError: (on first iteration) if data_dir is empty, or if this
                        process's shard does not contain at least one full
                        batch of images. Without this check the infinite loop
                        would spin forever without yielding anything.
    """
    if not data_dir:
        raise ValueError("unspecified data directory")
    all_files = _list_image_files_recursively(data_dir)
    classes = None
    if class_cond:
        # Class names are the filename prefix before the first underscore
        # (the whole basename if there is none). They are mapped to integer
        # ids in sorted order.
        class_names = [bf.basename(path).split("_")[0] for path in all_files]
        sorted_classes = {x: i for i, x in enumerate(sorted(set(class_names)))}
        classes = [sorted_classes[x] for x in class_names]
    dataset = ImageDataset(
        image_size,
        all_files,
        classes=classes,
        shard=MPI.COMM_WORLD.Get_rank(),
        num_shards=MPI.COMM_WORLD.Get_size(),
        random_crop=random_crop,
        random_flip=random_flip,
    )
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=not deterministic,
        num_workers=1,
        drop_last=True,
    )
    # With drop_last=True an undersized shard yields zero batches; the
    # `while True` below would then loop forever without producing data.
    if len(loader) == 0:
        raise ValueError(
            f"not enough images for a single batch of size {batch_size}: "
            f"found {len(all_files)} image(s) under {data_dir!r}, "
            f"{len(dataset)} assigned to this process"
        )
    while True:
        yield from loader
| |
|
| |
|
def _list_image_files_recursively(data_dir):
    """
    Return the paths of all image files under ``data_dir``, recursing into
    subdirectories.

    Entries are visited in sorted name order at every level. A subdirectory's
    images are spliced in at the position where that subdirectory sorts, so
    the result order is deterministic. An entry counts as an image when its
    name contains a dot and its last extension (case-insensitive) is jpg,
    jpeg, png or gif. That check happens before the directory check, so a
    directory with such a name is returned as if it were a file.
    """
    image_exts = ("jpg", "jpeg", "png", "gif")
    found = []
    for name in sorted(bf.listdir(data_dir)):
        path = bf.join(data_dir, name)
        looks_like_image = "." in name and name.split(".")[-1].lower() in image_exts
        if looks_like_image:
            found.append(path)
            continue
        if bf.isdir(path):
            found += _list_image_files_recursively(path)
    return found
| |
|
| |
|
class ImageDataset(Dataset):
    """
    Map-style dataset over one shard of a list of image paths.

    Each item is a ``(chw_array, out_dict)`` pair. ``chw_array`` is a float32
    array laid out channels-first (3 x resolution x resolution), with uint8
    pixel values rescaled to [-1, 1]. ``out_dict`` holds an int64 ``"y"``
    label when ``classes`` was supplied.
    """

    def __init__(
        self,
        resolution,
        image_paths,
        classes=None,
        shard=0,
        num_shards=1,
        random_crop=False,
        random_flip=True,
    ):
        super().__init__()

        def take_shard(items):
            # Every num_shards-th element, starting at index `shard`.
            return items[shard:][::num_shards]

        self.resolution = resolution
        self.random_crop = random_crop
        self.random_flip = random_flip
        self.local_images = take_shard(image_paths)
        self.local_classes = take_shard(classes) if classes is not None else None

    def __len__(self):
        return len(self.local_images)

    def __getitem__(self, idx):
        with bf.BlobFile(self.local_images[idx], "rb") as fh:
            img = Image.open(fh)
            # Decode fully while the underlying file handle is still open.
            img.load()
        img = img.convert("RGB")

        cropper = random_crop_arr if self.random_crop else center_crop_arr
        pixels = cropper(img, self.resolution)

        # Horizontal flip (reverse the width axis) with probability 0.5.
        do_flip = self.random_flip and random.random() < 0.5
        if do_flip:
            pixels = pixels[:, ::-1]

        # Map uint8 [0, 255] to float32 [-1, 1], then HWC -> CHW.
        pixels = pixels.astype(np.float32) / 127.5 - 1
        chw = np.transpose(pixels, [2, 0, 1])

        extras = {}
        if self.local_classes is not None:
            extras["y"] = np.array(self.local_classes[idx], dtype=np.int64)
        return chw, extras
| |
|
| |
|
def center_crop_arr(pil_image, image_size):
    """
    Resize ``pil_image`` so its shorter side is ``image_size``, then return the
    centered ``image_size`` x ``image_size`` crop as a numpy array (rows first).

    While the shorter side is at least twice the target, the image is first
    halved repeatedly with BOX resampling. A single BICUBIC resize then brings
    it to the exact size. This is a manual multi-step downscale rather than a
    single large bicubic resize.
    """
    while True:
        width, height = pil_image.size
        if min(width, height) < 2 * image_size:
            break
        pil_image = pil_image.resize((width // 2, height // 2), resample=Image.BOX)

    factor = image_size / min(pil_image.size)
    width, height = pil_image.size
    pil_image = pil_image.resize(
        (round(width * factor), round(height * factor)), resample=Image.BICUBIC
    )

    pixels = np.array(pil_image)
    top = (pixels.shape[0] - image_size) // 2
    left = (pixels.shape[1] - image_size) // 2
    return pixels[top : top + image_size, left : left + image_size]
| |
|
| |
|
def random_crop_arr(pil_image, image_size, min_crop_frac=0.8, max_crop_frac=1.0):
    """
    Randomly rescale ``pil_image`` and return a random square crop of side
    ``image_size`` as a numpy array (rows first).

    The shorter side is resized to a random integer drawn uniformly from
    ``[ceil(image_size / max_crop_frac), ceil(image_size / min_crop_frac)]``
    (inclusive). The crop therefore covers roughly ``min_crop_frac`` to
    ``max_crop_frac`` of that side. As in center cropping, large images are
    first halved repeatedly with BOX resampling, then resized once with
    BICUBIC. The crop offset along each axis is chosen uniformly.
    """
    shortest = math.ceil(image_size / max_crop_frac)
    longest = math.ceil(image_size / min_crop_frac)
    target_short_side = random.randrange(shortest, longest + 1)

    while True:
        width, height = pil_image.size
        if min(width, height) < 2 * target_short_side:
            break
        pil_image = pil_image.resize((width // 2, height // 2), resample=Image.BOX)

    factor = target_short_side / min(pil_image.size)
    width, height = pil_image.size
    pil_image = pil_image.resize(
        (round(width * factor), round(height * factor)), resample=Image.BICUBIC
    )

    pixels = np.array(pil_image)
    top = random.randrange(pixels.shape[0] - image_size + 1)
    left = random.randrange(pixels.shape[1] - image_size + 1)
    return pixels[top : top + image_size, left : left + image_size]
| |
|