| from torchvision.transforms import ( |
| Normalize, |
| Compose, |
| RandomResizedCrop, |
| InterpolationMode, |
| ToTensor, |
| Resize, |
| CenterCrop, |
| ) |
|
|
|
|
| def _convert_to_rgb(image): |
| return image.convert("RGB") |
|
|
|
|
def image_transform(
    image_size: int,
    is_train: bool,
    mean=(0.48145466, 0.4578275, 0.40821073),
    std=(0.26862954, 0.26130258, 0.27577711),
):
    """Build the image preprocessing pipeline as a torchvision ``Compose``.

    Args:
        image_size: Target size passed to the crop/resize transforms.
        is_train: When true, use a random resized crop (bicubic,
            ``scale=(0.9, 1.0)``). Otherwise use a deterministic bicubic
            resize followed by a center crop.
        mean: Per-channel means handed to ``Normalize`` (presumably the
            CLIP/OpenAI statistics — confirm).
        std: Per-channel standard deviations handed to ``Normalize``.

    Returns:
        A ``Compose`` that applies the geometric step(s), then RGB
        conversion, ``ToTensor`` and ``Normalize``, in that order.

    NOTE(review): the geometric transforms run before the RGB conversion.
    If Pillow resamples palette ("P") images with nearest-neighbour
    regardless of the requested filter, such inputs lose the bicubic
    smoothing. Confirm whether converting first is wanted.
    """
    bicubic = InterpolationMode.BICUBIC

    if is_train:
        geometry = [
            RandomResizedCrop(image_size, scale=(0.9, 1.0), interpolation=bicubic),
        ]
    else:
        geometry = [
            Resize(image_size, interpolation=bicubic),
            CenterCrop(image_size),
        ]

    # Steps shared by both modes: force 3 channels, convert to tensor, normalize.
    tail = [_convert_to_rgb, ToTensor(), Normalize(mean=mean, std=std)]
    return Compose(geometry + tail)
|
|