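# NOTE: the following header was captured from the hosting page's file viewer
# and is not part of the script itself.
#   Author: 王方懿康
#   Commit: "Initial commit" (ab2369a)
#   Viewer actions: Raw / History / Blame / Contribute / Delete
#   File size: 4.53 kB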
import sys
import time
import torch
import os
import json
import argparse
sys.path.append(os.getcwd())
from diffusers import LDMPipeline, DDIMScheduler, PNDMScheduler, UniPCMultistepScheduler, DPMSolverMultistepScheduler
from scheduler.scheduling_dpmsolver_multistep_lm import DPMSolverMultistepLMScheduler
from scheduler.scheduling_ddim_lm import DDIMLMScheduler
# Values accepted by --dtype; each maps to a torch dtype in _resolve_dtype().
_DTYPE_CHOICES = ('fp32', 'fp64', 'fp16', 'bf16')

# Values accepted by --sampler_type (same set and order as before).
_SAMPLER_CHOICES = ['pndm', 'ddim_lm', 'ddim', 'dpm++', 'dpm', 'dpm_lm', 'unipc']


def _build_parser():
    """Return the argument parser describing every command-line option."""
    parser = argparse.ArgumentParser(description="sampling script for CelebA-HQ.")
    parser.add_argument('--test_num', type=int, default=1)
    parser.add_argument('--start_index', type=int, default=0)
    parser.add_argument('--batch_size', type=int, default=4)
    parser.add_argument('--num_inference_steps', type=int, default=20)
    # NOTE: the default 'lag' is not one of the choices, and argparse does not
    # validate defaults against ``choices``.  It is kept for backward
    # compatibility; _configure_scheduler() reports the resulting fall-through.
    parser.add_argument('--sampler_type', type=str, default='lag',
                        choices=_SAMPLER_CHOICES)
    parser.add_argument('--save_dir', type=str, default='/xxx/xxx')
    parser.add_argument('--model_id', type=str,
                        default='/xxx/xxx/ddpm_ema_cifar10')
    parser.add_argument('--lamb', type=float, default=1.0)
    parser.add_argument('--kappa', type=float, default=0.0)
    # Restricting the choices makes a typo (e.g. "float16") fail at parse time
    # instead of silently loading the model with ``torch_dtype=None``.
    parser.add_argument('--dtype', type=str, default='fp32',
                        choices=_DTYPE_CHOICES)
    parser.add_argument('--device', type=str, default='cuda')
    return parser


def _resolve_dtype(name):
    """Translate a --dtype flag value into the corresponding torch dtype."""
    return {
        'fp32': torch.float32,
        'fp64': torch.float64,
        'fp16': torch.float16,
        'bf16': torch.bfloat16,
    }[name]


def _configure_scheduler(pipe, sampler_type, lamb, kappa):
    """Replace ``pipe.scheduler`` according to ``sampler_type``.

    Every new scheduler is built from the config of the scheduler currently
    attached to ``pipe``.  ``lamb`` and ``kappa`` are only applied to the
    ``*_lm`` sampler variants.  When ``sampler_type`` matches no known sampler
    the pipeline's existing scheduler is left in place and a warning is
    written to stderr.
    """
    if sampler_type == 'pndm':
        pipe.scheduler = PNDMScheduler.from_config(pipe.scheduler.config)
    elif sampler_type == 'dpm++':
        pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
        # The overrides are written onto the already-built scheduler's config,
        # as in the original script.
        # NOTE(review): presumably the scheduler re-reads these at sampling
        # time rather than caching them at construction -- confirm.
        pipe.scheduler.config.solver_order = 3
        pipe.scheduler.config.algorithm_type = "dpmsolver++"
    elif sampler_type in ('dpm', 'dpm_lm'):
        # Both variants share the same scheduler class and solver settings;
        # they differ only in whether the ``lm`` switch is enabled.
        pipe.scheduler = DPMSolverMultistepLMScheduler.from_config(pipe.scheduler.config)
        pipe.scheduler.config.solver_order = 3
        pipe.scheduler.config.algorithm_type = "dpmsolver"
        if sampler_type == 'dpm_lm':
            pipe.scheduler.lamb = lamb
            pipe.scheduler.lm = True
            pipe.scheduler.kappa = kappa
        else:
            pipe.scheduler.lm = False
    elif sampler_type == 'ddim':
        pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
    elif sampler_type == 'ddim_lm':
        pipe.scheduler = DDIMLMScheduler.from_config(pipe.scheduler.config)
        pipe.scheduler.lamb = lamb
        pipe.scheduler.lm = True
        pipe.scheduler.kappa = kappa
    elif sampler_type == 'unipc':
        pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
    else:
        # Only reachable through the parser default, since explicit values are
        # validated by ``choices``.  Previously this case was silent.
        print(
            f"warning: sampler_type {sampler_type!r} matches no known sampler; "
            "keeping the scheduler that was loaded with the pipeline.",
            file=sys.stderr,
        )


def main(argv=None):
    """Sample image batches from a pretrained LDM pipeline and save them as PNGs.

    Args:
        argv: Optional list of command-line arguments.  When ``None`` (the
            default) the arguments are taken from ``sys.argv[1:]``, which is
            how the script behaves when run directly.

    One batch is generated per seed in
    ``range(start_index, start_index + test_num)``.  Each image is written to
    ``save_dir`` and the elapsed time of every batch (sampling plus saving) is
    printed.

    Raises:
        SystemExit: if the command-line arguments are invalid (raised by
            argparse, before any model is loaded or directory is created).
    """
    args = _build_parser().parse_args(argv)

    dtype = _resolve_dtype(args.dtype)
    start_index = args.start_index
    batch_size = args.batch_size
    sampler_type = args.sampler_type
    test_num = args.test_num
    num_inference_steps = args.num_inference_steps
    device = args.device
    save_dir = args.save_dir

    # exist_ok=True already covers the "directory is present" case, so no
    # separate existence check is needed.
    os.makedirs(save_dir, exist_ok=True)

    with torch.no_grad():
        # load pipeline
        pipe = LDMPipeline.from_pretrained(args.model_id, torch_dtype=dtype)
        pipe.unet.to(device)
        pipe.vqvae.to(device)

        # load scheduler
        _configure_scheduler(pipe, sampler_type, args.lamb, args.kappa)

        for seed in range(start_index, start_index + test_num):
            print('prepare to sample')
            # perf_counter is monotonic, so the reported duration is not
            # affected by system clock adjustments.
            start_time = time.perf_counter()
            torch.manual_seed(seed)

            # sampling process
            images = pipe(batch_size=batch_size,
                          num_inference_steps=num_inference_steps).images

            # store the generated images
            # NOTE(review): the file prefix says "cifar10" while the parser
            # description says CelebA-HQ; the names are kept unchanged because
            # downstream tooling may depend on them.
            for i, image in enumerate(images):
                image.save(
                    os.path.join(save_dir, f"cifar10_{sampler_type}_inference{num_inference_steps}_seed{seed}_{i}.png"))
            print(f"{sampler_type} batch##{seed},done")

            # output the sampling time-costs
            time_difference = time.perf_counter() - start_time
            print(f"The code took {time_difference} seconds to run.")
# Run the sampler only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()