| | import torch |
| | import requests |
| | import torchvision.transforms as transforms |
| | from math import ceil |
| | from PIL import Image |
| | import matplotlib.pyplot as plt |
| |
|
# Hard cap (in pixels) on the longer edge of any resize target.
MAX_RESOLUTION = 1024


def get_resize_output_image_size(
    image_size,
    fix_resolution=False,
    max_resolution: int = MAX_RESOLUTION,
    patch_size=32,
) -> tuple:
    """Compute a resize target whose sides are multiples of ``patch_size``.

    Args:
        image_size: ``(l1, l2)`` pair of input dimensions (e.g. PIL's
            ``(width, height)``); the output preserves this ordering.
        fix_resolution: if True, ignore the input size entirely and return
            the fixed ``(224, 224)`` resolution.
        max_resolution: upper bound applied to the longer output edge.
        patch_size: both output sides are rounded up to a multiple of this.

    Returns:
        ``(new_l1, new_l2)`` in the same order as ``image_size``, with both
        sides multiples of ``patch_size`` and each side at least
        ``patch_size`` (never zero).
    """
    if fix_resolution:
        return 224, 224

    l1, l2 = image_size
    short, long = (l2, l1) if l2 <= l1 else (l1, l2)

    # Round the long edge up to a patch multiple, but never past the cap.
    requested_new_long = min(ceil(long / patch_size) * patch_size, max_resolution)

    # Scale the short edge to preserve aspect ratio, then round it up to a
    # patch multiple.
    new_long = requested_new_long
    new_short = int(requested_new_long * short / long)
    new_short = ceil(new_short / patch_size) * patch_size

    # Bug fix: extreme aspect ratios (e.g. a 100000x1 image) used to round
    # the short edge down to 0; clamp to at least one patch.
    new_short = max(new_short, patch_size)

    return (new_long, new_short) if l2 <= l1 else (new_short, new_long)
| |
|
| |
|
def preprocess_image(
    image_tensor: torch.Tensor,
    patch_size=32
) -> torch.Tensor:
    """Split a channel-first (C, H, W) image tensor into square patches.

    Returns a tensor of shape ``(H // patch_size, W // patch_size, C,
    patch_size, patch_size)``. Any trailing rows/columns that do not fill a
    whole patch are silently dropped by ``unfold``.
    """
    rows = image_tensor.unfold(1, patch_size, patch_size)
    grid = rows.unfold(2, patch_size, patch_size)
    # Move the channel axis behind the two grid axes: (gy, gx, C, ph, pw).
    return grid.permute(1, 2, 0, 3, 4).contiguous()
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
def get_transform(height, width):
    """Build the preprocessing pipeline for a PIL image.

    Resizes to ``(height, width)``, converts to a float tensor, and applies
    ImageNet-style channel normalization.
    """
    return transforms.Compose([
        transforms.Resize((height, width)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ])
| |
|
def convert_image_to_patches(image, patch_size=32) -> torch.Tensor:
    """Resize a PIL image to patch-aligned dimensions and cut it into patches.

    The target size comes from ``get_resize_output_image_size`` (both sides
    multiples of ``patch_size``); the result is the patch grid produced by
    ``preprocess_image``.
    """
    orig_width, orig_height = image.size
    target_width, target_height = get_resize_output_image_size(
        (orig_width, orig_height), patch_size=patch_size, fix_resolution=False
    )
    tensor = get_transform(target_height, target_width)(image)
    return preprocess_image(tensor, patch_size=patch_size)
| |
|