| |
|
| |
|
| |
|
| | import json |
| | import torchvision.transforms as transforms |
| | from torch.utils.data.dataset import Dataset |
| | |
| | from PIL import Image |
| | import os |
| | import torch |
| | import torchvision.transforms.functional as F |
def tokenize_captions( caption, tokenizer):
    """Tokenize one caption string into padded token ids.

    The caption is wrapped in a single-element batch and encoded with
    max-length padding and truncation, so the returned ids always have a
    fixed sequence length of ``tokenizer.model_max_length``.

    Args:
        caption: caption text for one sample.
        tokenizer: HuggingFace-style tokenizer; must expose
            ``model_max_length`` and return an object carrying
            ``input_ids`` when called.

    Returns:
        The ``input_ids`` for the one-caption batch (shape ``(1, L)``).
    """
    encoded = tokenizer(
        [caption],
        max_length=tokenizer.model_max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    return encoded.input_ids
| |
|
| |
|
| |
|
| |
|
class SquarePad:
    """Pad a PIL image with white borders so it becomes exactly square.

    The shorter dimension is grown to match the longer one; the deficit
    is split between the two sides, with the odd leftover pixel placed
    on the right/bottom.
    """

    def __call__(self, image):
        w, h = image.size
        max_wh = max(w, h)
        # Bug fix: the old code used int((max_wh - d) / 2) for BOTH sides,
        # so an odd deficit lost one pixel and the result was not square.
        # Split the deficit exactly: left + right == max_wh - w, etc.
        left = (max_wh - w) // 2
        top = (max_wh - h) // 2
        right = max_wh - w - left
        bottom = max_wh - h - top
        # torchvision F.pad takes (left, top, right, bottom);
        # fill=(255, 255, 255) pads RGB inputs with white.
        return F.pad(image, (left, top, right, bottom), fill=(255, 255, 255), padding_mode='constant')
| |
|
class NormalSegDataset(Dataset):
    """RGB / surface-normal / segmentation triplet dataset for text-conditioned training.

    Each item is ``(image, conditioning_image, prompt, text_prompt)`` where
    ``image`` is the normalized RGB crop, ``conditioning_image`` stacks the
    normal map (3 channels) and segmentation mask (1 channel) into a
    4-channel tensor cropped identically to the RGB image, ``prompt`` is
    the tokenized caption and ``text_prompt`` the raw caption string.
    """

    def __init__(self, args, path, tokenizer, cfg_prob):
        # Shared geometric transform for all three modalities; the crop is
        # made identical across them in __getitem__ by replaying the RNG
        # state. args.resolution sets the square output size.
        self.image_transforms = transforms.Compose(
            [
                transforms.RandomResizedCrop(args.resolution, scale=(0.9, 1.0), interpolation=2, ),
                transforms.ToTensor(),
            ]
        )

        # Applied to the RGB target only: maps ToTensor's [0, 1] to [-1, 1].
        self.additional_image_transforms = transforms.Compose(
            [transforms.Normalize([0.5], [0.5]),]
        )

        # Metadata layout (from the json read below): {'keys': [...],
        # 'data': {key: {'rgb', 'normal', 'seg', 'caption'}}}.
        # NOTE(review): the stored paths are opened directly, so they are
        # presumably absolute or cwd-relative — verify against the
        # metadata generator.
        meta_path = os.path.join(path, 'meta_train_seg.json')

        with open(meta_path, 'r') as f:
            self.meta = json.load(f)

        # self.meta is rebound to the inner 'data' mapping; the key order
        # of self.keys defines the dataset indexing.
        self.keys = self.meta['keys']
        self.meta = self.meta['data']

        self.tokenizer = tokenizer

        # Classifier-free-guidance dropout probability; drawn
        # independently for the caption and for the conditioning images
        # in __getitem__.
        self.cfg_prob = cfg_prob

    def __len__(self):
        # One sample per metadata key.
        return len(self.keys)

    def __getitem__(self, index):
        meta_data = self.meta[self.keys[index]]

        rgb_path = meta_data['rgb']
        normal_path = meta_data['normal']
        seg_path = meta_data['seg']
        # Only the first caption is used; any additional ones are ignored.
        text_prompt = meta_data['caption'][0]

        # CFG dropout of the text prompt (independent of the image-side
        # dropout below).
        rand = torch.rand(1).item()
        if rand < self.cfg_prob:
            text_prompt = ""

        image = Image.open(rgb_path).convert("RGB")
        # Snapshot the RNG state so the exact same RandomResizedCrop
        # parameters can be replayed on the normal map and the mask.
        state = torch.get_rng_state()
        image = self.image_transforms(image)

        # CFG dropout of the conditioning: substitute a blank white
        # normal map and an all-zero segmentation mask.
        rand = torch.rand(1).item()
        if rand < self.cfg_prob:
            # NOTE(review): image is now a (C, H, W) tensor, so this
            # passes (H, W) where PIL.Image.new expects (width, height).
            # Harmless while the crop is square (args.resolution), but
            # swapped otherwise — confirm.
            normal_image = Image.new('RGB', (image.shape[1], image.shape[2]), (255, 255, 255))

            seg_image = Image.new('L', (image.shape[1], image.shape[2]), (0))
        else:
            normal_image = Image.open(normal_path).convert("RGB")
            seg_image = Image.open(seg_path).convert("L")
        # Replay the saved RNG state so the normal map gets the same crop
        # as the RGB image.
        torch.set_rng_state(state)
        normal_image = self.image_transforms(normal_image)

        # Replay once more for the mask.
        torch.set_rng_state(state)
        seg_image = self.image_transforms(seg_image)

        # 3-channel normal + 1-channel mask -> 4-channel conditioning input.
        conditioning_image = torch.cat([normal_image, seg_image], dim=0)

        # Normalize only the RGB target; the conditioning stays in [0, 1].
        image = self.additional_image_transforms(image)

        prompt = text_prompt

        prompt = tokenize_captions(prompt, self.tokenizer)

        return image, conditioning_image, prompt, text_prompt
| |
|
| |
|
| |
|