| import numpy as np |
| import cv2 |
| import albumentations as A |
| from torch.utils.data import Dataset |
| from .data_utils import * |
|
|
class BaseDataset(Dataset):
    """Base dataset with shared augmentation and collage-building helpers.

    Concrete datasets are expected to override :meth:`_get_sample` and
    :meth:`__len__`; the defaults here only fail loudly.

    The helpers ``get_bbox_from_mask``, ``expand_bbox``, ``expand_image``,
    ``expand_image_mask`` and ``pad_to_square`` are not defined in this class;
    presumably they come from the ``data_utils`` star import -- TODO confirm.
    """

    # (width, height) passed to cv2.resize for the scene image, hint and masks.
    _SCENE_SIZE = (512, 512)
    # (width, height) passed to cv2.resize for reference-object patches.
    _REF_SIZE = (224, 224)

    def __init__(self):
        super().__init__()
        # Sample records; starts empty here.
        self.data = []

    def __getitem__(self, idx):
        """Return the sample at ``idx`` by delegating to :meth:`_get_sample`."""
        return self._get_sample(idx)

    def _get_sample(self, idx):
        """Build and return the sample at ``idx``.

        Raises:
            NotImplementedError: always; subclasses must override this hook.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement _get_sample()"
        )

    def __len__(self):
        """Return the number of samples.

        Raises:
            TypeError: always; subclasses must override this method.
        """
        # TypeError (not NotImplementedError) on purpose: it is what len()
        # already raised for the former ``pass`` stub, and CPython's length
        # hint machinery swallows TypeError, so built-ins that probe len()
        # before iterating keep their fallback path.
        raise TypeError(f"{type(self).__name__} does not define __len__()")

    def aug_data_mask(self, image, mask):
        """Apply the same random augmentation to an image and its mask.

        The pipeline is a brightness/contrast jitter (``p=0.5``) followed by
        a rotation with ``limit=30`` and constant border fill.

        Args:
            image: Image array; cast to ``uint8`` before augmentation.
            mask: Mask array passed through the same transform.

        Returns:
            Tuple ``(augmented_image, augmented_mask)``.
        """
        pipeline = A.Compose([
            A.RandomBrightnessContrast(p=0.5),
            A.Rotate(limit=30, border_mode=cv2.BORDER_CONSTANT),
        ])
        result = pipeline(image=image.astype(np.uint8), mask=mask)
        return result["image"], result["mask"]

    def aug_patch(self, patch):
        """Augment an object patch and re-composite it onto white.

        Pixels whose grayscale value is below 250 are treated as foreground.
        The foreground mask is augmented together with the patch, and every
        pixel outside the augmented mask is reset to 255.

        Args:
            patch: RGB patch (converted with ``cv2.COLOR_RGB2GRAY``).

        Returns:
            The augmented patch as a ``uint8`` array.
        """
        gray = cv2.cvtColor(patch, cv2.COLOR_RGB2GRAY)
        foreground = (gray < 250).astype(np.float32)[:, :, None]

        pipeline = A.Compose([
            A.HorizontalFlip(p=0.2),
            A.RandomBrightnessContrast(
                brightness_limit=0.1, contrast_limit=0.1, p=0.3
            ),
            A.Rotate(limit=15, border_mode=cv2.BORDER_REPLICATE, p=0.5),
        ])
        result = pipeline(image=patch.astype(np.uint8), mask=foreground)
        aug_img = result["image"]
        aug_mask = result["mask"]

        # Defensive: if the transform hands back the mask without its
        # singleton channel axis, restore it so it broadcasts over the image.
        if aug_mask.ndim == 2:
            aug_mask = aug_mask[:, :, None]

        composited = aug_img * aug_mask + 255 * (1 - aug_mask)
        return composited.astype(np.uint8)

    def sample_timestep(self, max_step=1000):
        """Draw a random timestep, biased towards the lower half of the range.

        With probability 0.3 the step is uniform over ``[0, max_step)``;
        otherwise it is uniform over ``[0, max_step // 2)``.

        Args:
            max_step: Exclusive upper bound of the full range; must be >= 1.

        Returns:
            A length-1 integer ``np.ndarray`` holding the sampled step.

        Raises:
            ValueError: if ``max_step`` is smaller than 1.
        """
        if max_step < 1:
            raise ValueError(f"max_step must be >= 1, got {max_step}")

        if np.random.rand() < 0.3:
            upper = max_step
        else:
            # max_step // 2 is 0 when max_step == 1, which would give
            # randint an empty range; clamp so step 0 stays drawable.
            upper = max(max_step // 2, 1)
        return np.array([np.random.randint(0, upper)])

    def get_patch(self, ref_image, ref_mask):
        """Extract a compact object patch and convert it to 224x224 RGBA.

        The object is composited onto white, cropped to the mask's bounding
        box, expanded by a random ratio in {1.1, 1.2, 1.3, 1.4}, padded to a
        square and resized. The alpha channel is the mask scaled to 0/255.

        Args:
            ref_image: Reference image (assumed H x W x 3 -- TODO confirm).
            ref_mask: 2-D mask with values in [0, 1].

        Returns:
            ``uint8`` array stacking the RGB patch and the alpha mask along
            the last axis.
        """
        y1, y2, x1, x2 = get_bbox_from_mask(ref_mask)

        # Keep the object, paint everything else white.
        mask_3c = np.stack([ref_mask, ref_mask, ref_mask], -1)
        white = np.ones_like(ref_image) * 255
        masked_ref = ref_image * mask_3c + white * (1 - mask_3c)

        masked_ref = masked_ref[y1:y2, x1:x2, :]
        mask_crop = ref_mask[y1:y2, x1:x2]

        ratio = np.random.randint(11, 15) / 10
        masked_ref, mask_crop = expand_image_mask(
            masked_ref, mask_crop, ratio=ratio
        )

        rgb = pad_to_square(masked_ref, pad_value=255)
        rgb = cv2.resize(rgb.astype(np.uint8), self._REF_SIZE)

        alpha = pad_to_square(mask_crop[:, :, None] * 255, pad_value=0)
        # Nearest-neighbour keeps the alpha channel strictly 0 or 255.
        alpha = cv2.resize(
            alpha.astype(np.uint8),
            self._REF_SIZE,
            interpolation=cv2.INTER_NEAREST,
        )

        return np.dstack((rgb.astype(np.uint8), alpha))

    def _prepare_reference(self, obj):
        """Expand, augment, square-pad and resize one reference object to [0, 1]."""
        ratio = np.random.randint(11, 15) / 10
        obj = expand_image(obj, ratio=ratio)
        obj = self.aug_patch(obj)
        obj = pad_to_square(obj, pad_value=255, random=False)
        obj = cv2.resize(obj.astype(np.uint8), self._REF_SIZE).astype(np.uint8)
        return obj / 255

    def _erase_object_box(self, background, union_mask, mask):
        """Zero the expanded bbox of ``mask`` in ``background`` and mark it in the masks.

        ``background`` and ``union_mask`` are modified in place; a fresh
        per-object box mask is returned.
        """
        box_mask = np.zeros_like(union_mask)
        bbox = get_bbox_from_mask(mask)
        y1, y2, x1, x2 = expand_bbox(mask, bbox, ratio=[1.1, 1.2])
        background[y1:y2, x1:x2, :] = 0
        box_mask[y1:y2, x1:x2, :] = 1.0
        union_mask[y1:y2, x1:x2, :] = 1.0
        return box_mask

    def _resize_box_mask(self, box_mask):
        """Square-pad a box mask with the sentinel value 2 and resize it (nearest)."""
        padded = pad_to_square(box_mask, pad_value=2, random=False)
        resized = cv2.resize(
            padded.astype(np.uint8),
            self._SCENE_SIZE,
            interpolation=cv2.INTER_NEAREST,
        )
        return resized.astype(np.float32)

    def _construct_collage(self, image, object_0, object_1, mask_0, mask_1):
        """Assemble one two-object training sample.

        Args:
            image: Scene image containing both objects.
            object_0: Reference patch for the first object.
            object_1: Reference patch for the second object.
            mask_0: Mask of the first object in ``image``.
            mask_1: Mask of the second object in ``image``.

        Returns:
            dict with keys:

            * ``'jpg'``: square-padded scene resized to 512x512 and scaled
              with ``x / 127.5 - 1``.
            * ``'ref0'`` / ``'ref1'``: augmented 224x224 reference patches
              divided by 255.
            * ``'hint'``: scene with both expanded object boxes zeroed
              (same scaling as ``'jpg'``) plus one mask channel that is 1
              inside either box, 0 elsewhere and -1 on the square padding.
            * ``'mask0'`` / ``'mask1'``: 2-D 512x512 per-object box masks
              with the padding set to 0.
            * ``'time_steps'``: result of :meth:`sample_timestep`.
            * ``'object_num'``: always 2.
        """
        # Copy before anything else: the boxes are erased in place below and
        # the caller's image must stay untouched.
        background = image.copy()

        target = pad_to_square(image, pad_value=0, random=False)
        target = cv2.resize(target.astype(np.uint8), self._SCENE_SIZE)
        target = target.astype(np.float32) / 127.5 - 1.0

        # Order matters for reproducibility: every step below consumes
        # random numbers in the same sequence as before the refactor.
        ref_0 = self._prepare_reference(object_0)
        ref_1 = self._prepare_reference(object_1)

        union_mask = background * 0.0
        box_mask_0 = self._erase_object_box(background, union_mask, mask_0)
        box_mask_1 = self._erase_object_box(background, union_mask, mask_1)

        background = pad_to_square(background, pad_value=0, random=False)
        background = cv2.resize(background.astype(np.uint8), self._SCENE_SIZE)
        background = background.astype(np.float32) / 127.5 - 1.0

        # Padding carries the sentinel 2 through the resize; in the hint it
        # becomes -1, in the per-object masks it becomes 0.
        union_mask = self._resize_box_mask(union_mask)
        union_mask[union_mask == 2] = -1
        hint = np.concatenate([background, union_mask[:, :, :1]], -1)

        object_masks = []
        for box_mask in (box_mask_0, box_mask_1):
            resized = self._resize_box_mask(box_mask)
            resized[resized == 2] = 0
            object_masks.append(resized[:, :, 0].copy())

        return {
            'jpg': target,
            'ref0': ref_0,
            'ref1': ref_1,
            'hint': hint,
            'mask0': object_masks[0],
            'mask1': object_masks[1],
            'time_steps': self.sample_timestep(),
            'object_num': 2,
        }
|
|