| import os |
| import json |
| import datetime |
| import torchvision |
| import numpy as np |
| import torch |
|
|
| from omegaconf import OmegaConf |
| from PIL import Image |
| from torch.utils.data import DataLoader, Dataset |
| from torchvision.datasets import ImageFolder |
| from torchvision import transforms |
| from torchvision.transforms.functional import hflip |
| from accelerate.logging import get_logger |
| from safetensors.torch import load_file |
| from .sampler_utils import get_train_sampler |
|
|
|
|
# Module-level logger obtained from accelerate's get_logger, configured at INFO level.
| logger = get_logger(__name__, log_level="INFO") |
|
|
|
|
def center_crop_arr(pil_image, image_size):
    """
    Resize a PIL image and cut out its central ``image_size`` x ``image_size`` square.

    Center cropping implementation from ADM.
    https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126

    Args:
        pil_image: source PIL image.
        image_size: side length, in pixels, of the square crop to return.

    Returns:
        A new PIL image holding the centered square crop.
    """
    # Stage 1: halve both sides with a box filter for as long as the shorter
    # side is still at least twice the target size.
    while True:
        width, height = pil_image.size
        if min(width, height) < 2 * image_size:
            break
        pil_image = pil_image.resize(
            (width // 2, height // 2), resample=Image.Resampling.BOX
        )

    # Stage 2: bicubic resize so that the shorter side lands on image_size.
    width, height = pil_image.size
    ratio = image_size / min(width, height)
    pil_image = pil_image.resize(
        (round(width * ratio), round(height * ratio)),
        resample=Image.Resampling.BICUBIC,
    )

    # Stage 3: slice the centered square out of the pixel array
    # (axis 0 is rows, axis 1 is columns).
    pixels = np.array(pil_image)
    top = (pixels.shape[0] - image_size) // 2
    left = (pixels.shape[1] - image_size) // 2
    window = pixels[top: top + image_size, left: left + image_size]
    return Image.fromarray(window)
|
|
class ImagenetDictWrapper(Dataset):
    """Adapter that exposes a pair-yielding dataset as a dict-yielding one.

    Each item of the wrapped dataset is unpacked into two values and returned
    as ``{"image": first, "label": second}``.
    """

    def __init__(self, dataset):
        """Keep a reference to the wrapped ``dataset`` (indexable, with ``len``)."""
        super().__init__()
        self.dataset = dataset

    def __len__(self):
        """Return the number of items in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, i):
        """Return item ``i`` of the wrapped dataset as an image/label dict."""
        image, label = self.dataset[i]
        return dict(image=image, label=label)
|
|
class ImagenetLatentDataset(Dataset):
    """Dataset pairing every image file with a precomputed ``.safetensors`` file.

    ``image_dir`` is scanned for sub-directories (one per class folder); each
    file inside them becomes one sample. The matching latent file is looked up
    under ``latent_dir`` in a folder of the same name, using the image file
    name up to its first ``.`` plus the ``.safetensors`` extension.

    Each ``__getitem__`` call loads both files, builds the un-flipped and the
    horizontally flipped version of the image, and randomly returns one of the
    two variants together with the latent stored at the same index.
    """

    def __init__(self, latent_dir, image_dir, image_size):
        """Index the image/latent pairs and build the image transform.

        Args:
            latent_dir: root directory holding the ``.safetensors`` files.
            image_dir: root directory holding one sub-directory per class.
            image_size: side length of the square center crop.
        """
        super().__init__()
        # Probability of selecting index 0 (the un-flipped variant) in
        # __getitem__; at 0.5 both variants are equally likely.
        self.RandomHorizontalFlipProb = 0.5

        # Built once instead of once per crop per sample.
        to_tensor = transforms.ToTensor()
        self.transform = transforms.Compose([
            transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, image_size)),
            # Index 0 is the original crop, index 1 its horizontal mirror.
            transforms.Lambda(lambda pil_image: (pil_image, hflip(pil_image))),
            transforms.Lambda(lambda crops: torch.stack([to_tensor(crop) for crop in crops])),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ])

        self.dataset = self._index_pairs(latent_dir, image_dir)

    @staticmethod
    def _index_pairs(latent_dir, image_dir):
        """Return the list of ``{'latent': path, 'image': path}`` sample records.

        Directory listings are sorted because ``os.listdir`` returns entries in
        arbitrary order; sorting keeps the index -> sample mapping identical
        across runs and processes, which seeded samplers and step-based resume
        rely on.
        """
        pairs = []
        for class_folder in sorted(os.listdir(image_dir)):
            image_class_folder = os.path.join(image_dir, class_folder)
            # Only directories hold samples; skip stray files and anything
            # else that cannot be listed.
            if not os.path.isdir(image_class_folder):
                continue
            latent_class_folder = os.path.join(latent_dir, class_folder)
            for file in sorted(os.listdir(image_class_folder)):
                # NOTE(review): the stem is everything before the FIRST dot, so
                # "a.b.JPEG" maps to "a.safetensors" — kept as-is because it
                # presumably mirrors how the latent files were named; confirm
                # against the script that wrote them.
                stem = file.split('.')[0]
                pairs.append(
                    dict(
                        latent=os.path.join(latent_class_folder, stem + '.safetensors'),
                        image=os.path.join(image_class_folder, file),
                    )
                )
        return pairs

    def __len__(self):
        """Return the number of indexed image/latent pairs."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return one randomly chosen flip variant.

        Returns:
            dict with keys ``'latent'``, ``'image'`` and ``'label'``. The
            latent and the image are taken from the same index (0 or 1) of
            their respective containers; ``'label'`` is passed through from
            the safetensors file unchanged.
        """
        entry = self.dataset[idx]
        data = load_file(entry['latent'])
        # The context manager releases the file handle as soon as the pixels
        # have been converted, instead of waiting for garbage collection.
        with Image.open(entry['image']) as raw_image:
            image = self.transform(raw_image.convert("RGB"))

        # image[0] is the un-flipped crop and image[1] the mirrored one;
        # data['latent'] is presumably stored in the same order — TODO confirm
        # against the latent extraction script.
        variant = 0 if torch.rand(1) < self.RandomHorizontalFlipProb else 1
        return {
            'latent': data['latent'][variant],
            'image': image[variant],
            'label': data['label'],
        }
|
|
|
|
|
|
class C2ILoader():
    """Builds the training dataset from a config and hands out DataLoaders.

    The config must provide ``dataloader.batch_size``,
    ``dataloader.num_workers``, ``data_type`` (``'image'`` or ``'latent'``)
    and ``dataset`` (keyword arguments forwarded to the dataset class).
    ``dataloader.shuffle`` is optional and defaults to ``False``; it is only
    used by ``val_dataloader``.
    """

    # Class-level default so val_dataloader never hits a missing attribute,
    # even on instances whose config carries no shuffle setting.
    shuffle = False

    def __init__(self, data_config):
        """Read loader settings and instantiate the training dataset.

        Raises:
            NotImplementedError: if ``data_config.data_type`` is neither
                ``'image'`` nor ``'latent'``.
        """
        super().__init__()

        self.batch_size = data_config.dataloader.batch_size
        self.num_workers = data_config.dataloader.num_workers
        # Optional key: a missing attribute falls back to False.
        self.shuffle = bool(getattr(data_config.dataloader, "shuffle", False))

        self.data_type = data_config.data_type

        if self.data_type == 'image':
            dataset_cls = ImagenetDictWrapper
        elif self.data_type == 'latent':
            dataset_cls = ImagenetLatentDataset
        else:
            raise NotImplementedError(
                f"Unsupported data_type: {self.data_type!r} (expected 'image' or 'latent')"
            )
        self.train_dataset = dataset_cls(**OmegaConf.to_container(data_config.dataset))

        # No separate evaluation splits are built by this loader.
        self.test_dataset = None
        self.val_dataset = None

    def train_len(self):
        """Return the number of samples in the training dataset."""
        return len(self.train_dataset)

    def train_dataloader(self, rank, world_size, global_batch_size, max_steps, resume_steps, seed):
        """Return the training DataLoader driven by the project's train sampler.

        All arguments are forwarded unchanged to ``get_train_sampler`` together
        with the training dataset; the sampler's behavior is defined in
        ``sampler_utils`` and is not visible here.
        """
        sampler = get_train_sampler(
            self.train_dataset, rank, world_size, global_batch_size, max_steps, resume_steps, seed
        )
        loader_kwargs = dict(
            batch_size=self.batch_size,
            sampler=sampler,
            num_workers=self.num_workers,
            pin_memory=True,
            drop_last=True,
        )
        # PyTorch >= 2.0 rejects an explicit prefetch_factor when loading runs
        # in the main process (num_workers == 0), so only pass it with workers.
        if self.num_workers > 0:
            loader_kwargs["prefetch_factor"] = 2
        return DataLoader(self.train_dataset, **loader_kwargs)

    def test_dataloader(self):
        """No test split is available; always returns ``None``."""
        return None

    def val_dataloader(self):
        """Return a DataLoader over the training dataset for validation.

        ``val_dataset`` is never populated by this class, so the training
        dataset is iterated instead; incomplete final batches are dropped.
        """
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=self.shuffle,
            num_workers=self.num_workers,
            pin_memory=True,
            drop_last=True,
        )
|
|
|
|
|
|
|
|
|
|