| from argparse import Namespace |
| import os |
| import time |
| from tqdm import tqdm |
| from PIL import Image |
| import numpy as np |
| import torch |
| from torch.utils.data import DataLoader |
|
|
| import sys |
| sys.path.append(".") |
| sys.path.append("..") |
|
|
| from configs import data_configs |
| from datasets.inference_dataset import InferenceDataset |
| from datasets.augmentations import AgeTransformer |
| from utils.common import tensor2im, log_image |
| from options.test_options import TestOptions |
| from models.psp import pSp |
|
|
|
|
def run():
    """Run age-transformation inference and save side-by-side result images.

    Parses the test options, loads the checkpoint named by
    ``checkpoint_path`` and merges the options stored in it with the test
    options (test options win). Every input image is then passed through the
    network once per comma-separated value in ``opts.target_age``. For each
    input image a single row is saved to
    ``<exp_dir>/inference_side_by_side/<image basename>``: the input image
    followed by one output per target age, concatenated horizontally.

    At most ``opts.n_images`` images are processed; when it is ``None`` the
    whole dataset is used.
    """
    test_opts = TestOptions().parse()

    out_path_results = os.path.join(test_opts.exp_dir, 'inference_side_by_side')
    os.makedirs(out_path_results, exist_ok=True)

    # Start from the options stored in the checkpoint and override them with
    # the options given at test time.
    # NOTE(review): torch.load unpickles the file -- only load checkpoints
    # that come from a trusted source.
    ckpt = torch.load(test_opts.checkpoint_path, map_location='cpu')
    opts = ckpt['opts']
    opts.update(vars(test_opts))
    opts = Namespace(**opts)

    net = pSp(opts)
    net.eval()
    net.cuda()

    # One transformer per requested target age; each yields one output column.
    age_transformers = [AgeTransformer(target_age=age) for age in opts.target_age.split(',')]

    print(f'Loading dataset for {opts.dataset_type}')
    dataset_args = data_configs.DATASETS[opts.dataset_type]
    transforms_dict = dataset_args['transforms'](opts).get_transforms()
    dataset = InferenceDataset(root=opts.data_path,
                               transform=transforms_dict['transform_inference'],
                               opts=opts,
                               return_path=True)
    dataloader = DataLoader(dataset,
                            batch_size=opts.test_batch_size,
                            shuffle=False,
                            num_workers=int(opts.test_workers),
                            drop_last=False)

    if opts.n_images is None:
        opts.n_images = len(dataset)

    # Every tile of the output row is resized to this size before concatenation.
    resize_amount = (256, 256) if opts.resize_outputs else (1024, 1024)

    global_i = 0  # number of images saved so far
    for input_batch, image_paths in tqdm(dataloader):
        remaining = opts.n_images - global_i
        if remaining <= 0:
            break
        # Trim the last batch so that no more than n_images images are
        # processed and saved when n_images is not a multiple of the batch
        # size. Assumes the default collate output (a tensor batch and a
        # sliceable sequence of paths).
        input_batch = input_batch[:remaining]
        image_paths = image_paths[:remaining]

        # Maps image path -> row image (input followed by one result per age).
        batch_results = {}
        with torch.no_grad():
            for age_transformer in age_transformers:
                input_age_batch = torch.stack([age_transformer(img.cpu()) for img in input_batch])
                input_cuda = input_age_batch.cuda().float()
                result_batch = run_on_batch(input_cuda, net, opts)

                for i, (im_path, input_tensor) in enumerate(zip(image_paths, input_batch)):
                    result = tensor2im(result_batch[i])
                    if im_path not in batch_results:
                        # First column of the row is the input image itself.
                        input_im = log_image(input_tensor, opts)
                        batch_results[im_path] = np.array(input_im.resize(resize_amount))
                    batch_results[im_path] = np.concatenate(
                        [batch_results[im_path], np.array(result.resize(resize_amount))],
                        axis=1)

        for im_path, res in batch_results.items():
            image_name = os.path.basename(im_path)
            im_save_path = os.path.join(out_path_results, image_name)
            Image.fromarray(np.array(res)).save(im_save_path)
            global_i += 1
|
|
|
|
def run_on_batch(inputs, net, opts):
    """Run ``net`` on ``inputs`` and return its output.

    The network is called with ``randomize_noise=False`` and with ``resize``
    taken from ``opts.resize_outputs``.
    """
    return net(inputs, randomize_noise=False, resize=opts.resize_outputs)
|
|
|
|
# Script entry point: run inference only when executed directly, not on import.
if __name__ == "__main__":
    run()
|
|