| import os |
|
|
| import torch |
|
|
| import numpy as np |
| from PIL import Image, ImageOps |
| from .utils import BIGMAX |
| from .logger import logger |
|
|
|
|
class LoadImagesFromDirectory:
    """Node that loads the image files of a directory as one IMAGE batch.

    Outputs the concatenated images, a stack of per-image masks, and the number
    of images that were loaded.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "directory": ("STRING", {"default": ""}),
            },
            "optional": {
                "image_load_cap": ("INT", {"default": 0, "min": 0, "max": BIGMAX, "step": 1}),
                "start_index": ("INT", {"default": 0, "min": 0, "max": BIGMAX, "step": 1}),
            },
        }

    RETURN_TYPES = ("IMAGE", "MASK", "INT")
    FUNCTION = "load_images"

    CATEGORY = ""

    @staticmethod
    def _load_image(path):
        """Decode one file into a (1, H, W, 3) float32 image and an optional (H, W) mask.

        The mask is ``1 - alpha`` (alpha scaled to [0, 1]) when the file carries
        transparency (an "A" band, or a palette image with a transparency entry);
        otherwise it is None. Raises OSError (including PIL's
        UnidentifiedImageError) for files PIL cannot open or decode.
        """
        # The context manager closes the underlying file even for multi-frame
        # formats, where PIL would otherwise keep the handle open.
        with Image.open(path) as opened:
            pil_image = ImageOps.exif_transpose(opened)
            rgb = np.array(pil_image.convert("RGB")).astype(np.float32) / 255.0
            image = torch.from_numpy(rgb)[None,]

            alpha = None
            if "A" in pil_image.getbands():
                alpha = pil_image.getchannel("A")
            elif pil_image.mode == "P" and "transparency" in pil_image.info:
                # Palette transparency is not exposed as a band; expand to RGBA first.
                alpha = pil_image.convert("RGBA").getchannel("A")

            if alpha is None:
                return image, None
            mask = 1.0 - torch.from_numpy(np.array(alpha).astype(np.float32) / 255.0)
        return image, mask

    @staticmethod
    def _fill_missing_masks(images, masks):
        """Return *masks* with every None entry replaced by an all-zero mask.

        If no image supplied a mask, each entry becomes a 64x64 zero placeholder.
        Otherwise each missing mask is sized to its own image (H, W), so that
        torch.stack sees consistent shapes when a batch mixes images with and
        without transparency.
        """
        if all(mask is None for mask in masks):
            return [torch.zeros((64, 64), dtype=torch.float32, device="cpu") for _ in masks]
        return [
            torch.zeros(tuple(image.shape[1:3]), dtype=torch.float32, device="cpu") if mask is None else mask
            for image, mask in zip(images, masks)
        ]

    def load_images(self, directory: str, image_load_cap: int = 0, start_index: int = 0):
        """Load the image files in *directory* as a single batch.

        Args:
            directory: Folder to read. Entries are visited in sorted filename
                order; subdirectories are ignored.
            image_load_cap: Maximum number of images to load; 0 means no limit.
            start_index: Number of non-directory entries to skip before loading.

        Returns:
            ``(images, masks, count)``: images concatenated along dim 0, masks
            stacked along dim 0, and the number of images loaded.

        Raises:
            FileNotFoundError: if *directory* does not exist, is empty, or
                yields no loadable image.
        """
        if not os.path.isdir(directory):
            raise FileNotFoundError(f"Directory '{directory}' cannot be found.")
        entries = sorted(os.listdir(directory))
        if not entries:
            raise FileNotFoundError(f"No files in directory '{directory}'.")

        # Drop subdirectories before applying start_index so the offset counts
        # only files, not folders that would be skipped anyway.
        paths = [os.path.join(directory, name) for name in entries]
        paths = [path for path in paths if not os.path.isdir(path)][start_index:]

        images = []
        masks = []
        for path in paths:
            if image_load_cap > 0 and len(images) >= image_load_cap:
                break
            try:
                image, mask = self._load_image(path)
            except OSError as err:
                # Non-image or corrupt files (e.g. stray text/metadata files) are
                # skipped rather than aborting the whole batch.
                logger.warning(f"Skipping '{path}': {err}")
                continue
            images.append(image)
            masks.append(mask)

        if not images:
            raise FileNotFoundError(f"No images could be loaded from directory '{directory}'.")

        masks = self._fill_missing_masks(images, masks)
        return (torch.cat(images, dim=0), torch.stack(masks, dim=0), len(images))
|
|