| from typing import List |
|
|
| import PIL.Image |
| import PIL.ImageOps |
| from packaging import version |
| from PIL import Image |
|
|
|
|
# Map interpolation names to Pillow resampling filters. The lookup location depends on the
# installed Pillow release: versions below 9.1.0 read the constants directly from ``PIL.Image``,
# while 9.1.0 and later read them from ``PIL.Image.Resampling``. Only the base version is
# compared, so pre-/post-/dev-release suffixes do not affect the branch taken.
if version.parse(version.parse(PIL.__version__).base_version) < version.parse("9.1.0"):
    PIL_INTERPOLATION = {
        "linear": PIL.Image.LINEAR,
        "bilinear": PIL.Image.BILINEAR,
        "bicubic": PIL.Image.BICUBIC,
        "lanczos": PIL.Image.LANCZOS,
        "nearest": PIL.Image.NEAREST,
    }
else:
    PIL_INTERPOLATION = {
        "linear": PIL.Image.Resampling.BILINEAR,
        "bilinear": PIL.Image.Resampling.BILINEAR,
        "bicubic": PIL.Image.Resampling.BICUBIC,
        "lanczos": PIL.Image.Resampling.LANCZOS,
        "nearest": PIL.Image.Resampling.NEAREST,
    }
|
|
|
|
def pt_to_pil(images):
    """
    Turn a batch of torch image tensors into a list of PIL images.

    The input is rescaled with ``x / 2 + 0.5`` and clamped to ``[0, 1]``, moved to the CPU,
    permuted with axis order ``(0, 2, 3, 1)``, cast to float, converted to a numpy array and
    finally passed to ``numpy_to_pil``, whose result is returned.
    """
    # NOTE(review): the rescale suggests inputs in [-1, 1] -- confirm against callers.
    rescaled = (images / 2 + 0.5).clamp(0, 1)
    # Presumably (batch, channels, height, width) -> channels-last; the permute needs 4 axes.
    channels_last = rescaled.cpu().permute(0, 2, 3, 1)
    return numpy_to_pil(channels_last.float().numpy())
|
|
|
|
def numpy_to_pil(images):
    """
    Convert a numpy image or a batch of images to a list of PIL images.

    Args:
        images: A numpy array holding one image (3 axes) or a batch of images (a leading batch
            axis is added when the input has exactly 3 axes). The trailing axis is treated as
            the channel axis. Values are multiplied by 255, rounded and cast to ``uint8``.
            NOTE(review): this implies inputs are expected in ``[0, 1]`` -- confirm with callers.

    Returns:
        A list with one PIL image per batch entry. Inputs whose trailing axis has length 1 are
        converted to single-channel ``"L"`` images.
    """
    if images.ndim == 3:
        # Promote a single image to a batch of one so the code below is uniform.
        images = images[None, ...]
    images = (images * 255).round().astype("uint8")
    if images.shape[-1] == 1:
        # Special case for grayscale (single channel) images. Drop ONLY the channel axis:
        # a bare ``squeeze()`` would also remove a height or width of 1, which transposes
        # height-1 images (a 1-D array is read as a single column) and breaks 1x1 images.
        pil_images = [Image.fromarray(image.squeeze(axis=-1), mode="L") for image in images]
    else:
        pil_images = [Image.fromarray(image) for image in images]

    return pil_images
|
|
|
|
def make_image_grid(images: List[PIL.Image.Image], rows: int, cols: int, resize: int = None) -> PIL.Image.Image:
    """
    Paste a list of images into a single ``rows`` x ``cols`` RGB grid, filled row by row.

    Args:
        images: The images to arrange; there must be exactly ``rows * cols`` of them.
        rows: Number of grid rows.
        cols: Number of grid columns.
        resize: When not ``None``, every image is first resized to ``(resize, resize)``.

    Returns:
        A new RGB image whose cell size is taken from the first (possibly resized) image.
    """
    assert rows * cols == len(images)

    if resize is not None:
        target_size = (resize, resize)
        images = [picture.resize(target_size) for picture in images]

    # The first image defines the cell size used for every grid position.
    cell_w, cell_h = images[0].size
    canvas = Image.new("RGB", size=(cols * cell_w, rows * cell_h))

    for index, picture in enumerate(images):
        # Row-major placement: the index advances along a row before moving to the next one.
        row, col = divmod(index, cols)
        canvas.paste(picture, box=(col * cell_w, row * cell_h))
    return canvas
|
|