| import torch |
| from PIL import Image |
| import torchvision.transforms as transforms |
|
|
|
|
def load_image(image_path, device, image_size=356):
    """Load an image file as a batched float tensor on ``device``.

    The image is converted to RGB, resized to ``image_size`` x ``image_size``
    (the aspect ratio is NOT preserved), converted with
    ``transforms.ToTensor``, and given a leading batch dimension of size 1.

    Args:
        image_path: Path or file object accepted by ``PIL.Image.open``.
        device: Target device passed to ``Tensor.to``.
        image_size: Side length in pixels of the square output image.
            Defaults to 356, the value that was previously hard-coded.

    Returns:
        A tensor of shape ``(1, 3, image_size, image_size)`` on ``device``.
    """
    loader = transforms.Compose(
        [
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
        ]
    )

    # Use a context manager so the underlying file handle is released
    # promptly instead of lingering until garbage collection.
    with Image.open(image_path) as img:
        # Normalise the mode. Grayscale, palette ("P") or RGBA inputs would
        # otherwise yield tensors with 1 or 4 channels instead of 3.
        # convert() loads the pixel data, so the result stays valid after
        # the file is closed.
        rgb = img.convert("RGB")

    image = loader(rgb).unsqueeze(0)  # add the batch dimension
    return image.to(device)
|
|
|
|
def tensor_to_image(tensor):
    """Turn an image tensor back into a PIL image.

    The input is copied and detached from autograd first, so the caller's
    tensor is left untouched. Dimension 0 is dropped if it has size 1, values
    are clamped into ``[0, 1]``, and the data is moved to the CPU before
    conversion with ``transforms.ToPILImage``.
    """
    to_pil = transforms.ToPILImage()
    unbatched = tensor.clone().detach().squeeze(0)
    bounded = torch.clamp(unbatched, 0, 1)
    return to_pil(bounded.cpu())
|
|