import copy import functools import os import blobfile as bf import torch as th import torch.distributed as dist from torch.nn.parallel.distributed import DistributedDataParallel as DDP from torch.optim import AdamW import cv2 from . import dist_util, logger from .fp16_util import MixedPrecisionTrainer from .nn import update_ema from .resample import LossAwareSampler, UniformSampler import numpy as np import skimage from skimage.metrics import peak_signal_noise_ratio as psnr import math # For ImageNet experiments, this was a good default value. # We found that the lg_loss_scale quickly climbed to # 20-21 within the first ~1K steps of training. INITIAL_LOG_LOSS_SCALE = 20.0 import core.metrics as Metrics # from core.wandb_logger import WandbLogger import wandb class TrainLoop: def __init__( self, *, model, diffusion, data, val_dat, batch_size, microbatch, lr, ema_rate, log_interval, save_interval, resume_checkpoint, args, use_fp16=False, fp16_scale_growth=1e-3, schedule_sampler=None, weight_decay=0.0, lr_anneal_steps=0, ): self.model = model self.diffusion = diffusion self.data = data self.val_data=val_dat self.batch_size = batch_size self.microbatch = microbatch if microbatch > 0 else batch_size self.lr = lr self.ema_rate = ( [ema_rate] if isinstance(ema_rate, float) else [float(x) for x in ema_rate.split(",")] ) self.log_interval = log_interval self.save_interval = save_interval self.resume_checkpoint = resume_checkpoint self.args = args self.use_fp16 = use_fp16 self.fp16_scale_growth = fp16_scale_growth self.schedule_sampler = schedule_sampler or UniformSampler(diffusion) self.weight_decay = weight_decay self.lr_anneal_steps = lr_anneal_steps self.step = 0 self.resume_step = 0 self.global_batch = self.batch_size * dist.get_world_size() self.sync_cuda = th.cuda.is_available() self._load_and_sync_parameters() self.mp_trainer = MixedPrecisionTrainer( model=self.model, use_fp16=self.use_fp16, fp16_scale_growth=fp16_scale_growth, ) self.opt = AdamW( self.mp_trainer.master_params, lr=self.lr, weight_decay=self.weight_decay ) if self.resume_step: self._load_optimizer_state() # Model was resumed, either due to a restart or a checkpoint # being specified at the command line. self.ema_params = [ self._load_ema_parameters(rate) for rate in self.ema_rate ] else: self.ema_params = [ copy.deepcopy(self.mp_trainer.master_params) for _ in range(len(self.ema_rate)) ] if th.cuda.is_available(): print('cuda available') self.use_ddp = True self.ddp_model = DDP( self.model, device_ids=[dist_util.dev()], output_device=dist_util.dev(), broadcast_buffers=False, bucket_cap_mb=128, find_unused_parameters=True, ) else: if dist.get_world_size() > 1: logger.warn( "Distributed training requires CUDA. " "Gradients will not be synchronized properly!" ) self.use_ddp = False self.ddp_model = self.model def _load_and_sync_parameters(self): resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint if resume_checkpoint: self.resume_step = parse_resume_step_from_filename(resume_checkpoint) if dist.get_rank() == 0: logger.log(f"loading model from checkpoint: {resume_checkpoint}...") dict_load = dist_util.load_state_dict(resume_checkpoint, map_location=dist_util.dev()) self.model.load_state_dict(dict_load, strict=False) dist_util.sync_params(self.model.parameters()) def _load_ema_parameters(self, rate): ema_params = copy.deepcopy(self.mp_trainer.master_params) main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint ema_checkpoint = find_ema_checkpoint(main_checkpoint, self.resume_step, rate) if ema_checkpoint: if dist.get_rank() == 0: logger.log(f"loading EMA from checkpoint: {ema_checkpoint}...") state_dict = dist_util.load_state_dict( ema_checkpoint, map_location=dist_util.dev() ) ema_params = self.mp_trainer.state_dict_to_master_params(state_dict) dist_util.sync_params(ema_params) return ema_params def _load_optimizer_state(self): main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint opt_checkpoint = bf.join( bf.dirname(main_checkpoint), f"opt{self.resume_step:06}.pt" ) if bf.exists(opt_checkpoint): logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}") state_dict = dist_util.load_state_dict( opt_checkpoint, map_location=dist_util.dev() ) self.opt.load_state_dict(state_dict) def run_loop(self): val_idx=0 best_psnr = 0 # wandb.init(project = 'diffusion_small', config=self.args) while ( not self.lr_anneal_steps or self.step + self.resume_step < self.lr_anneal_steps ): # wandb_logger = WandbLogger() batch, cond = next(self.data) self.run_step(batch, cond) if (self.step+1) % self.save_interval == 0: number=0 all_images=[] number=0 print('validation') with th.no_grad(): val_idx=val_idx+1 psnr_val = 0 for batch_id1, data_var in enumerate(self.val_data): clean_batch, model_kwargs1 = data_var model_kwargs={} for k, v in model_kwargs1.items(): if('Index' in k): img_name=v else: model_kwargs[k]= v.to(dist_util.dev()) sample = self.diffusion.p_sample_loop( self.model, (clean_batch.shape[0], 3, 256,256), clip_denoised=True, model_kwargs=model_kwargs, ) sample = ((sample + 1) * 127.5) sample = sample.clamp(0, 255).to(th.uint8) sample = sample.permute(0, 2, 3, 1) sample = sample.contiguous().cpu().numpy() number=number+1 clean_image = ((model_kwargs['HR']+1)* 127.5).clamp(0, 255).to(th.uint8) clean_image= clean_image.permute(0, 2, 3, 1) clean_image= clean_image.contiguous().cpu().numpy() clean_image = clean_image[0][:,:,::-1] sample = sample[0][:,:,::-1] clean_image = cv2.cvtColor(clean_image, cv2.COLOR_BGR2GRAY) sample = cv2.cvtColor(sample, cv2.COLOR_BGR2GRAY) psnr_im = psnr(clean_image,sample) # print(img_name[0]) # print(psnr_im) psnr_val = psnr_val + psnr_im psnr_val = psnr_val/number print('psnr =') print(psnr_val) # wandb.log({"psnr": psnr_val}) if best_psnr < psnr_val: best_psnr = psnr_val self.save_val() # Run for a finite amount of time in integration tests. # if os.environ.get("DIFFUSION_TRAINING_TEST", "") and self.step > 0: # return self.step += 1 # Save the last checkpoint if it wasn't already saved. # if (self.step - 1) % self.save_interval != 0: # self.save() def run_step(self, batch, cond): self.forward_backward(batch, cond) took_step = self.mp_trainer.optimize(self.opt) if took_step: self._update_ema() self._anneal_lr() self.log_step() def forward_backward(self, batch, cond): self.mp_trainer.zero_grad() num_im = 0 loss_wandb = 0 for i in range(0, batch.shape[0], self.microbatch): num_im = num_im + 1 micro = batch[i : i + self.microbatch].to(dist_util.dev()) micro_cond = { k: v[i : i + self.microbatch].to(dist_util.dev()) for k, v in cond.items() } last_batch = (i + self.microbatch) >= batch.shape[0] t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev()) # print(t.shape) # print(t) compute_losses = functools.partial( self.diffusion.training_losses, self.ddp_model, micro, t, model_kwargs=micro_cond, ) if last_batch or not self.use_ddp: losses = compute_losses() else: with self.ddp_model.no_sync(): losses = compute_losses() if isinstance(self.schedule_sampler, LossAwareSampler): self.schedule_sampler.update_with_local_losses( t, losses["loss"].detach() ) loss = (losses["loss"] * weights).mean() loss_wandb = th.log10(loss) + loss_wandb log_loss_dict( self.diffusion, t, {k: v * weights for k, v in losses.items()} ) self.mp_trainer.backward(loss) loss_wandb_f = loss_wandb/num_im # wandb.log({"loss": loss_wandb_f}) def _update_ema(self): for rate, params in zip(self.ema_rate, self.ema_params): update_ema(params, self.mp_trainer.master_params, rate=rate) def _anneal_lr(self): if not self.lr_anneal_steps: return frac_done = (self.step + self.resume_step) / self.lr_anneal_steps lr = self.lr * (1 - frac_done) for param_group in self.opt.param_groups: param_group["lr"] = lr def log_step(self): logger.logkv("step", self.step + self.resume_step) logger.logkv("samples", (self.step + self.resume_step + 1) * self.global_batch) def save(self): def save_checkpoint(rate, params): state_dict = self.mp_trainer.master_params_to_state_dict(params) if dist.get_rank() == 0: logger.log(f"saving model {rate}...") if not rate: filename = f"model{(self.step+self.resume_step):06d}.pt" else: filename = f"ema_{rate}_{(self.step+self.resume_step):06d}.pt" with bf.BlobFile(bf.join("./weights", filename), "wb") as f: th.save(state_dict, f) save_checkpoint(0, self.mp_trainer.master_params) for rate, params in zip(self.ema_rate, self.ema_params): save_checkpoint(rate, params) if dist.get_rank() == 0: with bf.BlobFile( bf.join(get_blob_logdir(), f"opt{(self.step+self.resume_step):06d}.pt"), "wb", ) as f: th.save(self.opt.state_dict(), f) dist.barrier() def save_val(self): def save_checkpoint_val(rate, params): state_dict = self.mp_trainer.master_params_to_state_dict(params) if dist.get_rank() == 0: logger.log(f"saving model {rate}...") if not rate: filename = f"model{(self.step+self.resume_step):06d}.pt" else: filename = f"ema_{rate}_{(self.step+self.resume_step):06d}.pt" with bf.BlobFile(bf.join("./weights", filename), "wb") as f: th.save(state_dict, f) save_checkpoint_val(0, self.mp_trainer.master_params) for rate, params in zip(self.ema_rate, self.ema_params): save_checkpoint_val(rate, params) if dist.get_rank() == 0: with bf.BlobFile( bf.join(get_blob_logdir(), f"opt{(self.step+self.resume_step):06d}.pt"), "wb", ) as f: th.save(self.opt.state_dict(), f) dist.barrier() def parse_resume_step_from_filename(filename): """ Parse filenames of the form path/to/modelNNNNNN.pt, where NNNNNN is the checkpoint's number of steps. """ split = filename.split("model") if len(split) < 2: return 0 split1 = split[-1].split(".")[0] try: return int(split1) except ValueError: return 0 def get_blob_logdir(): # You can change this to be a separate path to save checkpoints to # a blobstore or some external drive. return logger.get_dir() def find_resume_checkpoint(): # On your infrastructure, you may want to override this to automatically # discover the latest checkpoint on your blob storage, etc. return None def find_ema_checkpoint(main_checkpoint, step, rate): if main_checkpoint is None: return None filename = f"ema_{rate}_{(step):06d}.pt" path = bf.join(bf.dirname(main_checkpoint), filename) if bf.exists(path): return path return None def log_loss_dict(diffusion, ts, losses): for key, values in losses.items(): logger.logkv_mean(key, values.mean().item()) # Log the quantiles (four quartiles, in particular). for sub_t, sub_loss in zip(ts.cpu().numpy(), values.detach().cpu().numpy()): quartile = int(4 * sub_t / diffusion.num_timesteps) logger.logkv_mean(f"{key}_q{quartile}", sub_loss)