| | import numpy as np |
| | from torchvision import transforms |
| |
|
| | class CIFARTransform: |
| | MEAN = [0.5071, 0.4866, 0.4409] |
| | STD = [0.2675, 0.2565, 0.2761] |
| | |
| | common_trfs = [transforms.ToTensor(), |
| | transforms.Normalize(mean=MEAN, std=STD)] |
| | |
| | resnet_train_transform = transforms.Compose([ |
| | transforms.RandomCrop(32, padding=4), |
| | transforms.RandomHorizontalFlip(), |
| | transforms.ColorJitter(brightness=63 / 255), |
| | *common_trfs |
| | ]) |
| | |
| | resnet_test_transform = transforms.Compose([*common_trfs]) |
| | |
| | |
| | |
| |
|
| | |
| | dset_mean = (0., 0., 0.) |
| | dset_std = (1., 1., 1.) |
| | vit_train_transform = transforms.Compose([ |
| | transforms.RandomResizedCrop(224), |
| | transforms.RandomHorizontalFlip(), |
| | transforms.ToTensor(), |
| | transforms.Normalize(dset_mean, dset_std)]) |
| | |
| | vit_test_transform = transforms.Compose([ |
| | transforms.Resize(224), |
| | transforms.ToTensor(), |
| | transforms.Normalize(dset_mean, dset_std)]) |
| |
|
| | |
| | mean=[x/255 for x in [125.3,123.0,113.9]] |
| | std=[x/255 for x in [63.0,62.1,66.7]] |
| |
|
| | alexnet_train_transform = transforms.Compose([ |
| | transforms.ToTensor(), |
| | transforms.Normalize(mean,std)]) |
| |
|
| | alexnet_test_transform = transforms.Compose([ |
| | transforms.ToTensor(), |
| | transforms.Normalize(mean,std)]) |
| |
|
| | @staticmethod |
| | def get_transform(model_type, mode): |
| | if model_type == 'resnet': |
| | if mode == 'train': |
| | return CIFARTransform.resnet_train_transform |
| | elif mode == 'test': |
| | return CIFARTransform.resnet_test_transform |
| | elif model_type == 'vit': |
| | if mode == 'train': |
| | return CIFARTransform.vit_train_transform |
| | elif mode == 'test': |
| | return CIFARTransform.vit_test_transform |
| | elif model_type == 'alexnet': |
| | if mode == 'train': |
| | return CIFARTransform.alexnet_train_transform |
| | elif mode == 'test': |
| | return CIFARTransform.alexnet_test_transform |
| | else: |
| | raise ValueError("Unsupported model type") |
| | |
| | class ImageNetTransform: |
| | MEAN=[0.4914, 0.4822, 0.4465] |
| | STD=[0.2023, 0.1994, 0.2010] |
| | |
| | common_trfs = [transforms.ToTensor(), |
| | transforms.Normalize(mean=MEAN, std=STD)] |
| | |
| | resnet_train_transform = transforms.Compose([ |
| | transforms.RandomResizedCrop(224), |
| | transforms.RandomHorizontalFlip(), |
| | transforms.ColorJitter(brightness=63 / 255), |
| | *common_trfs |
| | ]) |
| | |
| | resnet_test_transform = transforms.Compose([ |
| | transforms.Resize(256), |
| | transforms.CenterCrop(224), |
| | *common_trfs |
| | ]) |
| | |
| | |
| | dset_mean = (0., 0., 0.) |
| | dset_std = (1., 1., 1.) |
| | vit_train_transform = transforms.Compose([ |
| | transforms.RandomResizedCrop(224), |
| | transforms.RandomHorizontalFlip(), |
| | transforms.ToTensor(), |
| | transforms.Normalize(dset_mean, dset_std), |
| | ]) |
| | |
| | vit_test_transform = transforms.Compose([ |
| | transforms.Resize(256), |
| | transforms.CenterCrop(224), |
| | transforms.ToTensor(), |
| | transforms.Normalize(dset_mean, dset_std), |
| | ]) |
| | |
| | @staticmethod |
| | def get_transform(model_type, mode): |
| | if model_type == 'resnet': |
| | if mode == 'train': |
| | return ImageNetTransform.resnet_train_transform |
| | elif mode == 'test': |
| | return ImageNetTransform.resnet_test_transform |
| | elif model_type == 'vit': |
| | if mode == 'train': |
| | return ImageNetTransform.vit_train_transform |
| | elif mode == 'test': |
| | return ImageNetTransform.vit_test_transform |
| | else: |
| | raise ValueError("Unsupported model type") |
| | |
| | class ImageNetRTransform: |
| | mean = [0.4914, 0.4822, 0.4465] |
| | std = [0.2023, 0.1994, 0.2010] |
| | |
| | common_trfs = [transforms.ToTensor(), |
| | transforms.Normalize(mean, std)] |
| | |
| | resnet_train_transform = transforms.Compose([ |
| | transforms.RandomResizedCrop(224), |
| | transforms.RandomHorizontalFlip(), |
| | transforms.ColorJitter(brightness=63 / 255), |
| | *common_trfs]) |
| | |
| | resnet_test_transform = transforms.Compose([ |
| | transforms.Resize(256), |
| | transforms.CenterCrop(224), |
| | *common_trfs]) |
| |
|
| | mean = [0., 0., 0.] |
| | std = [1., 1., 1.] |
| |
|
| | common_trfs = [transforms.ToTensor(), |
| | transforms.Normalize(mean, std)] |
| |
|
| | vit_train_transform = transforms.Compose([ |
| | transforms.RandomResizedCrop(224), |
| | transforms.RandomHorizontalFlip(), |
| | *common_trfs]) |
| | |
| | vit_test_transform = transforms.Compose([ |
| | transforms.Resize(256), |
| | transforms.CenterCrop(224), |
| | *common_trfs]) |
| | |
| |
|
| | |
| | mean=[x/255 for x in [125.3,123.0,113.9]] |
| | std=[x/255 for x in [63.0,62.1,66.7]] |
| |
|
| | alexnet_train_transform = transforms.Compose([ |
| | transforms.RandomResizedCrop(32), |
| | transforms.ToTensor(), |
| | transforms.Normalize(mean,std)]) |
| |
|
| | alexnet_test_transform = transforms.Compose([ |
| | transforms.Resize(256), |
| | transforms.CenterCrop(32), |
| | transforms.ToTensor(), |
| | transforms.Normalize(mean,std)]) |
| |
|
| | @staticmethod |
| | def get_transform(model_type, mode): |
| | if model_type == 'resnet': |
| | if mode == 'train': |
| | return ImageNetRTransform.resnet_train_transform |
| | elif mode == 'test': |
| | return ImageNetRTransform.resnet_test_transform |
| | elif model_type == 'vit': |
| | if mode == 'train': |
| | return ImageNetRTransform.vit_train_transform |
| | elif mode == 'test': |
| | return ImageNetRTransform.vit_test_transform |
| | elif model_type == 'alexnet': |
| | if mode == 'train': |
| | return ImageNetRTransform.alexnet_train_transform |
| | elif mode == 'test': |
| | return ImageNetRTransform.alexnet_test_transform |
| | else: |
| | raise ValueError("Unsupported model type") |
| |
|
| | class TinyImageNetTransform: |
| | |
| | MEAN = [0.485, 0.456, 0.406] |
| | STD = [0.229, 0.224, 0.225] |
| |
|
| | common_trfs = [transforms.ToTensor(), |
| | transforms.Normalize(mean=MEAN, std=STD)] |
| |
|
| | |
| | resnet_train_transform = transforms.Compose([ |
| | transforms.RandomResizedCrop(64), |
| | transforms.RandomHorizontalFlip(), |
| | transforms.ColorJitter(brightness=63 / 255), |
| | *common_trfs |
| | ]) |
| |
|
| | resnet_test_transform = transforms.Compose([ |
| | transforms.Resize(64), |
| | transforms.CenterCrop(64), |
| | *common_trfs |
| | ]) |
| |
|
| | |
| | dset_mean = (0., 0., 0.) |
| | dset_std = (1., 1., 1.) |
| |
|
| | vit_train_transform = transforms.Compose([ |
| | transforms.RandomResizedCrop(224), |
| | transforms.RandomHorizontalFlip(), |
| | transforms.ToTensor(), |
| | transforms.Normalize(dset_mean, dset_std) |
| | ]) |
| |
|
| | vit_test_transform = transforms.Compose([ |
| | transforms.Resize(256), |
| | transforms.CenterCrop(224), |
| | transforms.ToTensor(), |
| | transforms.Normalize(dset_mean, dset_std) |
| | ]) |
| |
|
| | |
| | mean=[x/255 for x in [125.3,123.0,113.9]] |
| | std=[x/255 for x in [63.0,62.1,66.7]] |
| |
|
| | alexnet_train_transform = transforms.Compose([ |
| | transforms.RandomResizedCrop(32), |
| | transforms.ToTensor(), |
| | transforms.Normalize(mean,std)]) |
| |
|
| | alexnet_test_transform = transforms.Compose([ |
| | transforms.Resize(256), |
| | transforms.CenterCrop(32), |
| | transforms.ToTensor(), |
| | transforms.Normalize(mean,std)]) |
| |
|
| | @staticmethod |
| | def get_transform(model_type, mode): |
| | if model_type == 'resnet': |
| | if mode == 'train': |
| | return TinyImageNetTransform.resnet_train_transform |
| | elif mode == 'test': |
| | return TinyImageNetTransform.resnet_test_transform |
| | elif model_type == 'vit': |
| | if mode == 'train': |
| | return TinyImageNetTransform.vit_train_transform |
| | elif mode == 'test': |
| | return TinyImageNetTransform.vit_test_transform |
| | elif model_type == 'alexnet': |
| | if mode == 'train': |
| | return TinyImageNetTransform.alexnet_train_transform |
| | elif mode == 'test': |
| | return TinyImageNetTransform.alexnet_test_transform |
| | else: |
| | raise ValueError("Unsupported model type") |
| |
|
| |
|
| | class FiveDatasetsTransform: |
| | MEAN = [0.5071, 0.4866, 0.4409] |
| | STD = [0.2675, 0.2565, 0.2761] |
| | |
| | common_trfs = [transforms.ToTensor(), |
| | transforms.Normalize(mean=MEAN, std=STD)] |
| | |
| | resnet_train_transform = transforms.Compose([ |
| | transforms.RandomCrop(32, padding=4), |
| | transforms.RandomHorizontalFlip(), |
| | transforms.ColorJitter(brightness=63 / 255), |
| | *common_trfs |
| | ]) |
| | |
| | resnet_test_transform = transforms.Compose([ |
| | transforms.Resize(32), |
| | *common_trfs |
| | ]) |
| |
|
| | |
| | dset_mean = (0., 0., 0.) |
| | dset_std = (1., 1., 1.) |
| | vit_train_transform = transforms.Compose([ |
| | transforms.RandomResizedCrop(224), |
| | transforms.RandomHorizontalFlip(), |
| | transforms.ToTensor(), |
| | transforms.Normalize(dset_mean, dset_std)]) |
| | |
| | vit_test_transform = transforms.Compose([ |
| | transforms.Resize(224), |
| | transforms.ToTensor(), |
| | transforms.Normalize(dset_mean, dset_std)]) |
| |
|
| | |
| | mean=[x/255 for x in [125.3,123.0,113.9]] |
| | std=[x/255 for x in [63.0,62.1,66.7]] |
| |
|
| | alexnet_train_transform = transforms.Compose([ |
| | transforms.Resize(32), |
| | transforms.ToTensor(), |
| | transforms.Normalize(mean,std)]) |
| |
|
| | alexnet_test_transform = transforms.Compose([ |
| | transforms.Resize(32), |
| | transforms.ToTensor(), |
| | transforms.Normalize(mean,std)]) |
| |
|
| | @staticmethod |
| | def get_transform(model_type, mode): |
| | if model_type == 'resnet': |
| | if mode == 'train': |
| | return FiveDatasetsTransform.resnet_train_transform |
| | elif mode == 'test': |
| | return FiveDatasetsTransform.resnet_test_transform |
| | elif model_type == 'vit': |
| | if mode == 'train': |
| | return FiveDatasetsTransform.vit_train_transform |
| | elif mode == 'test': |
| | return FiveDatasetsTransform.vit_test_transform |
| | elif model_type == 'alexnet': |
| | if mode == 'train': |
| | return FiveDatasetsTransform.alexnet_train_transform |
| | elif mode == 'test': |
| | return FiveDatasetsTransform.alexnet_test_transform |
| | else: |
| | raise ValueError("Unsupported model type") |
| |
|
| | transform_classes = { |
| | 'cifar': CIFARTransform, |
| | 'imagenet': ImageNetTransform, |
| | 'imagenet-r': ImageNetRTransform, |
| | 'tiny-imagenet': TinyImageNetTransform, |
| | '5-datasets': FiveDatasetsTransform |
| | } |