| import torch |
| from omegaconf import OmegaConf |
| from sgm.util import instantiate_from_config |
| from sgm.modules.diffusionmodules.sampling import * |
|
|
|
|
def init_model(cfgs):
    """Instantiate the model described by ``cfgs`` and load its checkpoint.

    Attributes read from ``cfgs``:
        model_cfg_path: path of the config file loaded with ``OmegaConf.load``;
            its ``model`` section is passed to ``instantiate_from_config``.
        load_ckpt_path: checkpoint path handed to ``model.init_from_ckpt``.
        type: when equal to ``"train"`` the model is put in train mode;
            any other value selects the inference path.
        gpu: CUDA device index, used only on the inference path.

    Returns:
        The instantiated model. On the inference path it has been moved to
        ``cuda:<gpu>``, switched to eval mode and had ``freeze()`` called.
        On the training path it is NOT moved to any device here.
        NOTE(review): presumably device placement for training is handled by
        the caller / trainer. Confirm.
    """
    config = OmegaConf.load(cfgs.model_cfg_path)
    ckpt_path = cfgs.load_ckpt_path

    # Assumes the instantiated object exposes init_from_ckpt / train / eval /
    # freeze (e.g. a Lightning-style module). TODO confirm against the config.
    model = instantiate_from_config(config.model)
    model.init_from_ckpt(ckpt_path)

    if cfgs.type != "train":
        model.to(torch.device("cuda", index=cfgs.gpu))
        model.eval()
        # NOTE(review): freeze() is not defined here; presumably it disables
        # gradients for all parameters. Verify in the model class.
        model.freeze()
        return model

    model.train()
    return model
|
|
def init_sampling(cfgs):
    """Build the ``EulerEDMSampler`` used for inference.

    Attributes read from ``cfgs``:
        scale: sequence whose FIRST element is the classifier-free-guidance
            scale given to ``VanillaCFG``. Any further elements are ignored.
        steps: number of sampling steps.
        gpu: CUDA device index the sampler is created for.

    The discretization is ``LegacyDDPMDiscretization``. The churn/noise
    constants are fixed (s_churn=0.0, s_tmin=0.0, s_tmax=999.0,
    s_noise=1.0) and verbose output is enabled.

    Returns:
        The configured ``EulerEDMSampler`` instance.
    """
    discretization = {
        "target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization",
    }
    guider = {
        "target": "sgm.modules.diffusionmodules.guiders.VanillaCFG",
        "params": {"scale": cfgs.scale[0]},
    }

    # NOTE(review): with s_churn=0.0 the EDM "churn" step presumably adds no
    # extra noise, making sampling deterministic. Confirm in sgm's sampler.
    return EulerEDMSampler(
        num_steps=cfgs.steps,
        discretization_config=discretization,
        guider_config=guider,
        s_churn=0.0,
        s_tmin=0.0,
        s_tmax=999.0,
        s_noise=1.0,
        verbose=True,
        device=torch.device("cuda", index=cfgs.gpu),
    )
|
|
def deep_copy(batch):
    """Return a new dict whose tensor and list values are not shared with ``batch``.

    Each value is handled according to its type:

    * ``torch.Tensor``: replaced by ``torch.clone(value)`` (new memory).
    * ``list``: replaced by a shallow copy (new list, same elements).
    * ``tuple``: kept as-is. Tuples are immutable and have no ``.copy()``
      method, so the previous ``(tuple, list)`` branch raised
      ``AttributeError`` for any tuple value.
    * anything else: shared by reference.

    Despite the name this is NOT a recursive deep copy. Elements inside
    lists/tuples (including tensors) are not cloned.

    Args:
        batch: mapping to copy. Only iteration over keys and ``batch[key]``
            are used.

    Returns:
        A new ``dict`` with the same keys as ``batch``.
    """
    c_batch = {}
    for key in batch:
        value = batch[key]
        if isinstance(value, torch.Tensor):
            c_batch[key] = torch.clone(value)
        elif isinstance(value, list):
            c_batch[key] = value.copy()
        else:
            # Tuples (immutable, so safe to share) and all other values.
            c_batch[key] = value
    return c_batch
|
|
def prepare_batch(cfgs, batch):
    """Move every tensor in ``batch`` to ``cuda:<cfgs.gpu>`` in place.

    Non-tensor values are left untouched. ``cfgs.gpu`` is only read when
    at least one tensor value is present.

    Args:
        cfgs: object with a ``gpu`` attribute (CUDA device index).
        batch: dict that is modified in place.

    Returns:
        ``(batch, batch_uc)`` where BOTH entries are the same dict object as
        the input ``batch``. The "unconditional" batch is an alias, not a copy,
        so mutating one mutates the other.
    """
    for name, value in batch.items():
        if not isinstance(value, torch.Tensor):
            continue
        # Rebinding an existing key does not change the dict's size, so
        # updating while iterating is safe here.
        batch[name] = value.to(torch.device("cuda", index=cfgs.gpu))

    return batch, batch