| | import os |
| | from io import BytesIO |
| | from pathlib import Path |
| |
|
| | import lmdb |
| | from PIL import Image |
| | from torch.utils.data import Dataset |
| | from torchvision import transforms |
| | from torchvision.datasets import CIFAR10, LSUNClass |
| | import torch |
| | import pandas as pd |
| |
|
| | import torchvision.transforms.functional as Ftrans |
| |
|
| |
|
class ImageDataset(Dataset):
    """Images found under `folder`, returned as ``{'img', 'index'}`` dicts.

    Args:
        folder: root directory to scan for image files.
        image_size: output side length after resize + center crop.
        exts: file extensions to collect (default: jpg only).
        do_augment: append a random horizontal flip.
        do_transform: convert the PIL image to a tensor.
        do_normalize: map tensor values into [-1, 1].
        sort_names: sort paths for a deterministic ordering.
        has_subdir: search recursively when True, otherwise only the top
            level of `folder`.
    """
    def __init__(
        self,
        folder,
        image_size,
        exts=('jpg', ),  # immutable default: avoids the mutable-default pitfall
        do_augment: bool = True,
        do_transform: bool = True,
        do_normalize: bool = True,
        sort_names=False,
        has_subdir: bool = True,
    ):
        super().__init__()
        self.folder = folder
        self.image_size = image_size

        # Paths are kept relative to `folder` so __getitem__ can re-join
        # them; only the glob pattern differs between the two modes.
        pattern = '**/*.{}' if has_subdir else '*.{}'
        self.paths = [
            p.relative_to(folder) for ext in exts
            for p in Path(f'{folder}').glob(pattern.format(ext))
        ]
        if sort_names:
            self.paths = sorted(self.paths)

        transform = [
            transforms.Resize(image_size),
            transforms.CenterCrop(image_size),
        ]
        if do_augment:
            transform.append(transforms.RandomHorizontalFlip())
        if do_transform:
            transform.append(transforms.ToTensor())
        if do_normalize:
            transform.append(
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
        self.transform = transforms.Compose(transform)

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        path = os.path.join(self.folder, self.paths[index])
        # Close the file handle promptly; convert() copies the pixel data,
        # so the resulting image does not need the file to stay open.
        with Image.open(path) as opened:
            img = opened.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return {'img': img, 'index': index}
| |
|
| |
|
class SubsetDataset(Dataset):
    """A view exposing only the first `size` items of a wrapped dataset."""
    def __init__(self, dataset, size):
        # the wrapped dataset must be at least as large as the view
        assert len(dataset) >= size
        self.size = size
        self.dataset = dataset

    def __len__(self):
        return self.size

    def __getitem__(self, index):
        # indices past the view's end are a caller error
        assert index < self.size
        return self.dataset[index]
| |
|
| |
|
class BaseLMDB(Dataset):
    """Read-only image store backed by an LMDB file.

    Images are stored under keys of the form
    ``'<original_resolution>-<zero-padded index>'`` (see __getitem__);
    the special key ``'length'`` holds the item count as a decimal string.
    """
    def __init__(self, path, original_resolution, zfill: int = 5):
        # resolution used as the key prefix when the database was written
        self.original_resolution = original_resolution
        # zero-padding width of the index portion of each key
        self.zfill = zfill
        # NOTE(review): the environment is opened in the constructing
        # process; lmdb handles are generally not fork-safe, so DataLoader
        # workers that fork after construction may need a reopen — confirm
        # against how this dataset is used.
        self.env = lmdb.open(
            path,
            max_readers=32,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
        )

        if not self.env:
            raise IOError('Cannot open lmdb dataset', path)

        # read the dataset size stored under the literal key b'length'
        with self.env.begin(write=False) as txn:
            self.length = int(
                txn.get('length'.encode('utf-8')).decode('utf-8'))

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        # one read transaction per item; fetches the raw encoded bytes
        with self.env.begin(write=False) as txn:
            key = f'{self.original_resolution}-{str(index).zfill(self.zfill)}'.encode(
                'utf-8')
            img_bytes = txn.get(key)

        # decode from an in-memory buffer; the returned PIL image keeps a
        # reference to `buffer`, so the buffer outlives this frame
        buffer = BytesIO(img_bytes)
        img = Image.open(buffer)
        return img
| |
|
| |
|
def make_transform(
    image_size,
    flip_prob=0.5,
    crop_d2c=False,
):
    """Build the standard training preprocessing pipeline.

    Resizes (with either a d2c crop or a center crop), randomly flips
    with probability `flip_prob`, converts to a tensor, and normalizes
    into [-1, 1].
    """
    if crop_d2c:
        ops = [d2c_crop(), transforms.Resize(image_size)]
    else:
        ops = [
            transforms.Resize(image_size),
            transforms.CenterCrop(image_size),
        ]
    ops += [
        transforms.RandomHorizontalFlip(p=flip_prob),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
    return transforms.Compose(ops)
| |
|
| |
|
class FFHQlmdb(Dataset):
    """FFHQ images read from an LMDB file, with optional train/test split."""
    def __init__(self,
                 path=os.path.expanduser('datasets/ffhq256.lmdb'),
                 image_size=256,
                 original_resolution=256,
                 split=None,
                 as_tensor: bool = True,
                 do_augment: bool = True,
                 do_normalize: bool = True,
                 **kwargs):
        self.original_resolution = original_resolution
        self.data = BaseLMDB(path, original_resolution, zfill=5)
        self.length = len(self.data)

        if split is None:
            self.offset = 0
        elif split == 'train':
            # everything after the first 10k images is the training split
            self.length = self.length - 10000
            self.offset = 10000
        elif split == 'test':
            # the first 10k images are held out for testing
            self.length = 10000
            self.offset = 0
        else:
            raise NotImplementedError()

        ops = [transforms.Resize(image_size)]
        if do_augment:
            ops.append(transforms.RandomHorizontalFlip())
        if as_tensor:
            ops.append(transforms.ToTensor())
        if do_normalize:
            ops.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
        self.transform = transforms.Compose(ops)

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        assert index < self.length
        # shift into the chosen split before hitting the backing store
        shifted = index + self.offset
        img = self.data[shifted]
        if self.transform is not None:
            img = self.transform(img)
        return {'img': img, 'index': shifted}
| |
|
| |
|
class Crop:
    """Fixed-window crop usable inside a torchvision transform pipeline."""
    def __init__(self, x1, x2, y1, y2):
        self.x1 = x1
        self.x2 = x2
        self.y1 = y1
        self.y2 = y2

    def __call__(self, img):
        # Ftrans.crop takes (top, left, height, width); here x1 is passed
        # as the top and y1 as the left. NOTE(review): the x/y naming is
        # swapped relative to the usual image convention — verify against
        # d2c_crop, which constructs these coordinates.
        height = self.x2 - self.x1
        width = self.y2 - self.y1
        return Ftrans.crop(img, self.x1, self.y1, height, width)

    def __repr__(self):
        return (f"{self.__class__.__name__}"
                f"(x1={self.x1}, x2={self.x2}, y1={self.y1}, y2={self.y2})")
| |
|
| |
|
def d2c_crop(cx=89, cy=121, half_size=64):
    """Return the face crop used by the D2C paper.

    Generalized: the crop center ``(cx, cy)`` and the half window size can
    be overridden; the defaults reproduce the original hard-coded
    128x128 crop. Returns a `Crop` transform.
    """
    x1 = cy - half_size
    x2 = cy + half_size
    y1 = cx - half_size
    y2 = cx + half_size
    return Crop(x1, x2, y1, y2)
| |
|
| |
|
class CelebAlmdb(Dataset):
    """CelebA images read from an LMDB file.

    Also supports the d2c crop.
    """
    def __init__(self,
                 path,
                 image_size,
                 original_resolution=128,
                 split=None,
                 as_tensor: bool = True,
                 do_augment: bool = True,
                 do_normalize: bool = True,
                 crop_d2c: bool = False,
                 **kwargs):
        self.original_resolution = original_resolution
        self.data = BaseLMDB(path, original_resolution, zfill=7)
        self.length = len(self.data)
        self.crop_d2c = crop_d2c

        # no split handling is implemented for this dataset
        if split is not None:
            raise NotImplementedError()
        self.offset = 0

        if crop_d2c:
            ops = [d2c_crop(), transforms.Resize(image_size)]
        else:
            ops = [
                transforms.Resize(image_size),
                transforms.CenterCrop(image_size),
            ]
        if do_augment:
            ops.append(transforms.RandomHorizontalFlip())
        if as_tensor:
            ops.append(transforms.ToTensor())
        if do_normalize:
            ops.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
        self.transform = transforms.Compose(ops)

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        assert index < self.length
        shifted = index + self.offset
        img = self.data[shifted]
        if self.transform is not None:
            img = self.transform(img)
        return {'img': img, 'index': shifted}
| |
|
| |
|
class Horse_lmdb(Dataset):
    """LSUN-horse images read from an LMDB file."""
    def __init__(self,
                 path=os.path.expanduser('datasets/horse256.lmdb'),
                 image_size=128,
                 original_resolution=256,
                 do_augment: bool = True,
                 do_transform: bool = True,
                 do_normalize: bool = True,
                 **kwargs):
        self.original_resolution = original_resolution
        print(path)
        self.data = BaseLMDB(path, original_resolution, zfill=7)
        self.length = len(self.data)

        ops = [
            transforms.Resize(image_size),
            transforms.CenterCrop(image_size),
        ]
        if do_augment:
            ops.append(transforms.RandomHorizontalFlip())
        if do_transform:
            ops.append(transforms.ToTensor())
        if do_normalize:
            ops.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
        self.transform = transforms.Compose(ops)

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        img = self.data[index]
        if self.transform is not None:
            img = self.transform(img)
        return {'img': img, 'index': index}
| |
|
| |
|
class Bedroom_lmdb(Dataset):
    """LSUN-bedroom images read from an LMDB file."""
    def __init__(self,
                 path=os.path.expanduser('datasets/bedroom256.lmdb'),
                 image_size=128,
                 original_resolution=256,
                 do_augment: bool = True,
                 do_transform: bool = True,
                 do_normalize: bool = True,
                 **kwargs):
        self.original_resolution = original_resolution
        print(path)
        self.data = BaseLMDB(path, original_resolution, zfill=7)
        self.length = len(self.data)

        transform = [
            transforms.Resize(image_size),
            transforms.CenterCrop(image_size),
        ]
        if do_augment:
            transform.append(transforms.RandomHorizontalFlip())
        if do_transform:
            transform.append(transforms.ToTensor())
        if do_normalize:
            transform.append(
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
        self.transform = transforms.Compose(transform)

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        img = self.data[index]
        # guard added for consistency with the sibling LMDB datasets
        # (Horse_lmdb etc.), which all tolerate transform being None
        if self.transform is not None:
            img = self.transform(img)
        return {'img': img, 'index': index}
| |
|
| |
|
class CelebAttrDataset(Dataset):
    """CelebA images on disk paired with their 40 binary attributes.

    The attribute file indexes images by their original ".jpg" names, so
    file names found under `folder` (which may use a different `ext`)
    are rewritten to ".jpg" before joining against the dataframe.
    Labels are in {-1, 1} per attribute, 0 meaning "not present in row".
    """

    id_to_cls = [
        '5_o_Clock_Shadow', 'Arched_Eyebrows', 'Attractive', 'Bags_Under_Eyes',
        'Bald', 'Bangs', 'Big_Lips', 'Big_Nose', 'Black_Hair', 'Blond_Hair',
        'Blurry', 'Brown_Hair', 'Bushy_Eyebrows', 'Chubby', 'Double_Chin',
        'Eyeglasses', 'Goatee', 'Gray_Hair', 'Heavy_Makeup', 'High_Cheekbones',
        'Male', 'Mouth_Slightly_Open', 'Mustache', 'Narrow_Eyes', 'No_Beard',
        'Oval_Face', 'Pale_Skin', 'Pointy_Nose', 'Receding_Hairline',
        'Rosy_Cheeks', 'Sideburns', 'Smiling', 'Straight_Hair', 'Wavy_Hair',
        'Wearing_Earrings', 'Wearing_Hat', 'Wearing_Lipstick',
        'Wearing_Necklace', 'Wearing_Necktie', 'Young'
    ]
    cls_to_id = {v: k for k, v in enumerate(id_to_cls)}

    def __init__(self,
                 folder,
                 image_size=64,
                 attr_path=os.path.expanduser(
                     'datasets/celeba_anno/list_attr_celeba.txt'),
                 ext='png',
                 only_cls_name: str = None,
                 only_cls_value: int = None,
                 do_augment: bool = False,
                 do_transform: bool = True,
                 do_normalize: bool = True,
                 d2c: bool = False):
        super().__init__()
        self.folder = folder
        self.image_size = image_size
        self.ext = ext

        # collect available files, relative to `folder`
        paths = [
            str(p.relative_to(folder))
            for p in Path(f'{folder}').glob(f'**/*.{ext}')
        ]
        # os.path.splitext strips only the final extension, so relative
        # paths whose directories contain dots survive intact
        # (str.split('.')[0] would truncate them)
        paths = [os.path.splitext(each)[0] + '.jpg' for each in paths]

        if d2c:
            transform = [
                d2c_crop(),
                transforms.Resize(image_size),
            ]
        else:
            transform = [
                transforms.Resize(image_size),
                transforms.CenterCrop(image_size),
            ]
        if do_augment:
            transform.append(transforms.RandomHorizontalFlip())
        if do_transform:
            transform.append(transforms.ToTensor())
        if do_normalize:
            transform.append(
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
        self.transform = transforms.Compose(transform)

        with open(attr_path) as f:
            # the first line holds the entry count; the header row follows
            f.readline()
            # sep=r'\s+' is the documented replacement for the deprecated
            # delim_whitespace=True
            self.df = pd.read_csv(f, sep=r'\s+')
        # keep only the attribute rows whose image file actually exists
        self.df = self.df[self.df.index.isin(paths)]

        if only_cls_name is not None:
            self.df = self.df[self.df[only_cls_name] == only_cls_value]

    def pos_count(self, cls_name):
        """Number of rows where `cls_name` is positive (+1)."""
        return (self.df[cls_name] == 1).sum()

    def neg_count(self, cls_name):
        """Number of rows where `cls_name` is negative (-1)."""
        return (self.df[cls_name] == -1).sum()

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        row = self.df.iloc[index]
        # the dataframe indexes by ".jpg" name; map back to the on-disk ext
        stem = os.path.splitext(row.name)[0]
        name = f'{stem}.{self.ext}'

        path = os.path.join(self.folder, name)
        img = Image.open(path)

        labels = [0] * len(self.id_to_cls)
        for k, v in row.items():
            labels[self.cls_to_id[k]] = int(v)

        if self.transform is not None:
            img = self.transform(img)

        return {'img': img, 'index': index, 'labels': torch.tensor(labels)}
| |
|
| |
|
class CelebD2CAttrDataset(CelebAttrDataset):
    """CelebAttrDataset variant matching the D2C paper's setup.

    Identical to the parent except that the d2c crop and the original
    ".jpg" files are used by default.
    """
    def __init__(self,
                 folder,
                 image_size=64,
                 attr_path=os.path.expanduser(
                     'datasets/celeba_anno/list_attr_celeba.txt'),
                 ext='jpg',
                 only_cls_name: str = None,
                 only_cls_value: int = None,
                 do_augment: bool = False,
                 do_transform: bool = True,
                 do_normalize: bool = True,
                 d2c: bool = True):
        # delegate everything; this subclass only changes the defaults
        super().__init__(folder,
                         image_size,
                         attr_path,
                         ext=ext,
                         only_cls_name=only_cls_name,
                         only_cls_value=only_cls_value,
                         do_augment=do_augment,
                         do_transform=do_transform,
                         do_normalize=do_normalize,
                         d2c=d2c)
| |
|
| |
|
class CelebAttrFewshotDataset(Dataset):
    """Few-shot CelebA attribute subset loaded from a pre-sampled csv."""
    def __init__(
        self,
        cls_name,
        K,
        img_folder,
        img_size=64,
        ext='png',
        seed=0,
        only_cls_name: str = None,
        only_cls_value: int = None,
        all_neg: bool = False,
        do_augment: bool = False,
        do_transform: bool = True,
        do_normalize: bool = True,
        d2c: bool = False,
    ) -> None:
        self.cls_name = cls_name
        self.K = K
        self.img_folder = img_folder
        self.ext = ext

        # csv is selected by shot count K, class, seed, and whether the
        # sample was drawn all-negative
        if all_neg:
            csv_path = f'data/celeba_fewshots/K{K}_allneg_{cls_name}_{seed}.csv'
        else:
            csv_path = f'data/celeba_fewshots/K{K}_{cls_name}_{seed}.csv'
        self.df = pd.read_csv(csv_path, index_col=0)
        if only_cls_name is not None:
            self.df = self.df[self.df[only_cls_name] == only_cls_value]

        if d2c:
            ops = [d2c_crop(), transforms.Resize(img_size)]
        else:
            ops = [
                transforms.Resize(img_size),
                transforms.CenterCrop(img_size),
            ]
        if do_augment:
            ops.append(transforms.RandomHorizontalFlip())
        if do_transform:
            ops.append(transforms.ToTensor())
        if do_normalize:
            ops.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
        self.transform = transforms.Compose(ops)

    def pos_count(self, cls_name):
        """Number of rows where `cls_name` is positive (+1)."""
        return (self.df[cls_name] == 1).sum()

    def neg_count(self, cls_name):
        """Number of rows where `cls_name` is negative (-1)."""
        return (self.df[cls_name] == -1).sum()

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        row = self.df.iloc[index]
        # map the csv's ".jpg"-style index name onto the on-disk extension
        stem = row.name.split('.')[0]
        img = Image.open(os.path.join(self.img_folder, f'{stem}.{self.ext}'))

        # label in {-1, 1}, shaped (1,) for the binary classifier head
        label = torch.tensor(int(row[self.cls_name])).unsqueeze(-1)

        if self.transform is not None:
            img = self.transform(img)

        return {'img': img, 'index': index, 'labels': label}
| |
|
| |
|
class CelebD2CAttrFewshotDataset(CelebAttrFewshotDataset):
    """Few-shot variant using the D2C crop and ".jpg" files by default."""
    def __init__(self,
                 cls_name,
                 K,
                 img_folder,
                 img_size=64,
                 ext='jpg',
                 seed=0,
                 only_cls_name: str = None,
                 only_cls_value: int = None,
                 all_neg: bool = False,
                 do_augment: bool = False,
                 do_transform: bool = True,
                 do_normalize: bool = True,
                 is_negative=False,
                 d2c: bool = True) -> None:
        # delegate to the parent; only the defaults and the extra
        # `is_negative` flag differ
        super().__init__(cls_name,
                         K,
                         img_folder,
                         img_size,
                         ext=ext,
                         seed=seed,
                         only_cls_name=only_cls_name,
                         only_cls_value=only_cls_value,
                         all_neg=all_neg,
                         do_augment=do_augment,
                         do_transform=do_transform,
                         do_normalize=do_normalize,
                         d2c=d2c)
        self.is_negative = is_negative
| |
|
| |
|
class CelebHQAttrDataset(Dataset):
    """CelebA-HQ images from LMDB joined with the 40 binary attributes.

    Images are fetched from the LMDB store by the stem of the name in
    the attribute file; labels are in {-1, 1} per attribute.
    """
    id_to_cls = [
        '5_o_Clock_Shadow', 'Arched_Eyebrows', 'Attractive', 'Bags_Under_Eyes',
        'Bald', 'Bangs', 'Big_Lips', 'Big_Nose', 'Black_Hair', 'Blond_Hair',
        'Blurry', 'Brown_Hair', 'Bushy_Eyebrows', 'Chubby', 'Double_Chin',
        'Eyeglasses', 'Goatee', 'Gray_Hair', 'Heavy_Makeup', 'High_Cheekbones',
        'Male', 'Mouth_Slightly_Open', 'Mustache', 'Narrow_Eyes', 'No_Beard',
        'Oval_Face', 'Pale_Skin', 'Pointy_Nose', 'Receding_Hairline',
        'Rosy_Cheeks', 'Sideburns', 'Smiling', 'Straight_Hair', 'Wavy_Hair',
        'Wearing_Earrings', 'Wearing_Hat', 'Wearing_Lipstick',
        'Wearing_Necklace', 'Wearing_Necktie', 'Young'
    ]
    cls_to_id = {v: k for k, v in enumerate(id_to_cls)}

    def __init__(self,
                 path=os.path.expanduser('datasets/celebahq256.lmdb'),
                 image_size=None,
                 attr_path=os.path.expanduser(
                     'datasets/celeba_anno/CelebAMask-HQ-attribute-anno.txt'),
                 original_resolution=256,
                 do_augment: bool = False,
                 do_transform: bool = True,
                 do_normalize: bool = True):
        super().__init__()
        self.image_size = image_size
        self.data = BaseLMDB(path, original_resolution, zfill=5)

        transform = [
            transforms.Resize(image_size),
            transforms.CenterCrop(image_size),
        ]
        if do_augment:
            transform.append(transforms.RandomHorizontalFlip())
        if do_transform:
            transform.append(transforms.ToTensor())
        if do_normalize:
            transform.append(
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
        self.transform = transforms.Compose(transform)

        with open(attr_path) as f:
            # the first line holds the entry count; the header row follows
            f.readline()
            # sep=r'\s+' is the documented replacement for the deprecated
            # delim_whitespace=True
            self.df = pd.read_csv(f, sep=r'\s+')

    def pos_count(self, cls_name):
        """Number of rows where `cls_name` is positive (+1)."""
        return (self.df[cls_name] == 1).sum()

    def neg_count(self, cls_name):
        """Number of rows where `cls_name` is negative (-1)."""
        return (self.df[cls_name] == -1).sum()

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        row = self.df.iloc[index]
        # splitext tolerates extra dots in the name, unlike the two-way
        # unpack of str.split('.') which raises on them
        img_idx = os.path.splitext(row.name)[0]
        img = self.data[img_idx]

        labels = [0] * len(self.id_to_cls)
        for k, v in row.items():
            labels[self.cls_to_id[k]] = int(v)

        if self.transform is not None:
            img = self.transform(img)
        return {'img': img, 'index': index, 'labels': torch.tensor(labels)}
| |
|
| |
|
class CelebHQAttrFewshotDataset(Dataset):
    """Few-shot CelebA-HQ attribute subset: LMDB images + pre-sampled csv."""
    def __init__(self,
                 cls_name,
                 K,
                 path,
                 image_size,
                 original_resolution=256,
                 do_augment: bool = False,
                 do_transform: bool = True,
                 do_normalize: bool = True):
        super().__init__()
        self.image_size = image_size
        self.cls_name = cls_name
        self.K = K
        self.data = BaseLMDB(path, original_resolution, zfill=5)

        transform = [
            transforms.Resize(image_size),
            transforms.CenterCrop(image_size),
        ]
        if do_augment:
            transform.append(transforms.RandomHorizontalFlip())
        if do_transform:
            transform.append(transforms.ToTensor())
        if do_normalize:
            transform.append(
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
        self.transform = transforms.Compose(transform)

        self.df = pd.read_csv(f'data/celebahq_fewshots/K{K}_{cls_name}.csv',
                              index_col=0)

    def pos_count(self, cls_name):
        """Number of rows where `cls_name` is positive (+1)."""
        return (self.df[cls_name] == 1).sum()

    def neg_count(self, cls_name):
        """Number of rows where `cls_name` is negative (-1)."""
        return (self.df[cls_name] == -1).sum()

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        row = self.df.iloc[index]
        # splitext tolerates extra dots in the name, unlike the two-way
        # unpack of str.split('.') which raises on them
        img_idx = os.path.splitext(row.name)[0]
        img = self.data[img_idx]

        # label in {-1, 1}, shaped (1,) for the binary classifier head
        label = torch.tensor(int(row[self.cls_name])).unsqueeze(-1)

        if self.transform is not None:
            img = self.transform(img)

        return {'img': img, 'index': index, 'labels': label}
| |
|
| |
|
class Repeat(Dataset):
    """Cycle a dataset so it appears to contain `new_len` items."""
    def __init__(self, dataset, new_len) -> None:
        super().__init__()
        self.dataset = dataset
        self.original_len = len(dataset)
        self.new_len = new_len

    def __len__(self):
        return self.new_len

    def __getitem__(self, index):
        # wrap around so indices past the true end revisit the start
        return self.dataset[index % self.original_len]
| |
|