| import json |
| import cv2 |
| import numpy as np |
| import os |
| import random |
| from glob import glob |
|
|
| from torch.utils.data import Dataset |
| from torchvision import transforms |
| from PIL import Image, ImageDraw |
| import torch |
| import albumentations as A |
|
|
|
|
class ZalandoDataset(Dataset):
    """Paired try-on dataset in the Zalando/VITON-HD directory layout.

    Each item couples a flat cloth reference photo with the person image
    wearing it, plus an upper-body garment mask derived from the
    human-parsing map. Expected structure under ``root``::

        image/            person photos   (*.jpg)
        cloth/            cloth photos    (*.jpg)
        image-parse-v3/   parsing maps    (*.png, pixel value = class id)
    """

    def __init__(self, transform, root="/tmp/zalando/train/", width=512, height=512):
        """
        Args:
            transform: truthy flag enabling random augmentation of the cloth
                reference; the value itself is never called, only tested.
            root: dataset root directory; must end with a path separator.
            width: output width of the person image and mask.
            height: output height of the person image and mask.
        """
        self.root = root
        self.transform = transform
        self.width = width
        self.height = height
        # Sorted so the three lists stay index-aligned; assumes matching
        # file stems across the three folders -- TODO confirm upstream.
        self.image_paths = sorted(glob(f'{self.root}image/*.jpg'))
        self.ref_paths = sorted(glob(f'{self.root}cloth/*.jpg'))
        self.parse_paths = sorted(glob(f"{self.root}image-parse-v3/*.png"))
        self.prompts = ["", "a professional, detailed, high-quality image", "shirt"]
        # Coarse label id -> [name, list of raw 20-class parse indices].
        self.labels = {
            0: ['background', [0, 10]],
            1: ['hair', [1, 2]],
            2: ['face', [4, 13]],
            3: ['upper', [5, 6, 7]],
            4: ['bottom', [9, 12]],
            5: ['left_arm', [14]],
            6: ['right_arm', [15]],
            7: ['left_leg', [16]],
            8: ['right_leg', [17]],
            9: ['left_shoe', [18]],
            10: ['right_shoe', [19]],
            11: ['socks', [8]],
            12: ['noise', [3, 11]],
        }
        self.random_trans = A.Compose([
            A.HorizontalFlip(p=0.5),
            A.Rotate(limit=20),
            A.Blur(p=0.3),
        ])

    def img_segment(self, parse_img, wanted_label=3):
        """Extract a binary mask for one coarse label from a parsing map.

        Args:
            parse_img: PIL image whose pixel values are raw parse class ids.
            wanted_label: key into ``self.labels`` (default 3 = 'upper').

        Returns:
            (height, width) uint8 array with 255 where the label is present,
            0 elsewhere.
        """
        # NEAREST (interpolation=0) preserves discrete label ids while
        # resizing. Size follows the configured output resolution instead of
        # the previous hard-coded 512x512, so masks always match the target.
        im_parse_pil = transforms.Resize((self.height, self.width), interpolation=0)(parse_img)
        parse = torch.from_numpy(np.array(im_parse_pil)[None]).long()
        # One-hot encode the 20 raw classes, then collapse to the 13 coarse
        # groups defined in self.labels.
        parse_map = torch.zeros(20, self.height, self.width)
        parse_map = parse_map.scatter_(0, parse, 1.0)
        new_parse_map = torch.zeros(13, self.height, self.width)
        for i in range(len(self.labels)):
            for label in self.labels[i][1]:
                new_parse_map[i] += parse_map[label]

        # Threshold before scaling: a plain astype('uint8') * 255 would wrap
        # around if overlapping raw labels ever summed above 1.
        mask = new_parse_map[wanted_label].numpy() > 0.5
        return mask.astype(np.uint8) * 255

    def add_noise(self, image):
        """Corrupt a BGR mask image by thickening one of its contours.

        Picks a random external contour, draws and dilates it, keeps the
        band outside the original mask, and splats thick dots along the
        contour points. This makes the garment mask intentionally sloppy so
        downstream models cannot rely on an exact silhouette.

        Args:
            image: BGR uint8-compatible mask image.

        Returns:
            The noisy BGR uint8 mask (unchanged if no contour is found).
        """
        image = image.astype(np.uint8)
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

        contours, _ = cv2.findContours(gray, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        if contours:
            random_contour = contours[np.random.randint(len(contours))]

            # Draw the chosen contour thickly, then dilate it into a band.
            canvas = np.zeros_like(gray)
            cv2.drawContours(canvas, [random_contour], 0, 255, thickness=10)
            kernel = np.ones((15, 15), np.uint8)
            canvas = cv2.dilate(canvas, kernel, iterations=1)

            # Keep only the part of the band that differs from the mask,
            # i.e. the halo outside the original silhouette.
            boundary = cv2.absdiff(canvas, gray)

            # Contour points come as (N, 1, 2); flatten to (N, 2) (x, y).
            points_on_boundary = np.array([pt[0] for pt in random_contour])

            for point in points_on_boundary:
                thickness = 30
                # NOTE(review): length=0.1 collapses to the start point after
                # int truncation, so each "line" is really a 30px round dot;
                # increase `length` if actual whiskers are wanted. Kept as-is
                # to preserve existing behavior.
                length = 0.1
                angle = np.random.randint(0, 360)
                endpoint = (int(point[0] + length * np.cos(angle * np.pi / 180)),
                            int(point[1] + length * np.sin(angle * np.pi / 180)))
                cv2.line(boundary, tuple(point), endpoint, 255, thickness)

            image = cv2.bitwise_or(image, cv2.cvtColor(boundary, cv2.COLOR_GRAY2BGR))
        return image

    def __len__(self):
        """Dataset size = number of person images (lists assumed aligned)."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Assemble one training sample.

        Returns:
            dict with keys:
                jpg: person image, float32 (H, W, 3) RGB in [-1, 1]
                txt: randomly chosen caption string
                hint: cloth reference, float32 (224, 224, 3) RGB in [0, 1]
                mask: noisy garment mask, uint8 (H, W, 1)
                masked_image: person image in [0, 1] with the (noisy)
                    garment region zeroed out
                path: cloth file path, for tracing/debugging
        """
        source_filename = self.ref_paths[idx]
        target_filename = self.image_paths[idx]
        parse_filename = self.parse_paths[idx]

        prompt = random.choice(self.prompts)

        source = cv2.imread(source_filename)
        source = cv2.resize(source, (224, 224))
        if self.transform:
            source = self.random_trans(image=source)["image"]

        target = cv2.imread(target_filename)
        target = cv2.resize(target, (self.width, self.height))

        parse = Image.open(parse_filename).resize((self.width, self.height))
        mask = self.img_segment(parse, 3)

        # img_segment returns a single-channel array, so expand with
        # GRAY2BGR. (The previous RGB2BGR raises cv2.error on 1-channel
        # input.)
        mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)

        mask = self.add_noise(mask)
        mask_gray = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
        mask_gray = np.expand_dims(mask_gray, axis=-1)

        # cv2 loads BGR; convert to RGB for the model.
        source = cv2.cvtColor(source, cv2.COLOR_BGR2RGB)
        target = cv2.cvtColor(target, cv2.COLOR_BGR2RGB)

        mask = mask.astype(np.float32) / 255.0
        source = source.astype(np.float32) / 255.0
        target0 = target.astype(np.float32) / 255.0
        # Zero the garment region: (mask < 0.5) is True (1.0) outside it.
        masked_image = target0 * (mask < 0.5)

        # ControlNet-style target normalization to [-1, 1].
        target_normalized = (target.astype(np.float32) / 127.5) - 1.0

        return dict(jpg=target_normalized, txt=prompt, hint=source,
                    mask=mask_gray, masked_image=masked_image,
                    path=source_filename)
|
|