| from inference_avwm import model_forward_wrapper_v |
| import torch |
| |
# Allow TensorFloat-32 math for CUDA matmuls and cuDNN ops: faster on GPUs
# that support it, at slightly reduced floating-point precision.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True


# Select the non-interactive Agg backend before matplotlib.pyplot is imported
# below, so evaluation figures can be written to disk without a display.
import matplotlib
matplotlib.use('Agg')
| from collections import OrderedDict |
| from copy import deepcopy |
| from time import time |
| import argparse |
| import logging |
| import os |
| import matplotlib.pyplot as plt |
| import yaml |
|
|
|
|
| import torch.distributed as dist |
| from torch.nn.parallel import DistributedDataParallel as DDP |
| from torch.utils.data import DataLoader, ConcatDataset |
| from torch.utils.data.distributed import DistributedSampler |
| from diffusers.models import AutoencoderKL |
|
|
| from distributed import init_distributed |
| from models import AVCDiT_models |
| from diffusion import create_diffusion |
| from datasets import TrainingDataset |
| from misc import transform |
|
|
| |
| |
| |
|
|
|
|
def _remap_ema_keys(state_dict, model_keys):
    """Rename base-layout checkpoint keys to the "v"-mode layout expected by the model.

    Keys that already exist in ``model_keys`` are kept unchanged. Without this
    guard, a checkpoint written by this script (already in the "v" layout)
    would have ``pos_embed_v`` renamed a second time to ``pos_embed_v_v``.

    Args:
        state_dict: Mapping of parameter name -> tensor, with any
            ``_orig_mod.`` prefixes already stripped.
        model_keys: Set of state-dict keys of the target model.

    Returns:
        A new dict holding the same values under remapped keys.
    """
    remapped = {}
    for key, value in state_dict.items():
        new_key = key
        if key not in model_keys:
            if new_key.startswith("pos_embed"):
                new_key = new_key.replace("pos_embed", "pos_embed_v", 1)
            if new_key.startswith("x_embedder."):
                new_key = new_key.replace("x_embedder.", "x_embedder_v.", 1)
            if new_key.startswith("blocks.") and ".mlp." in new_key:
                new_key = new_key.replace(".mlp.", ".mlp_v.", 1)
        remapped[new_key] = value
    return remapped


def load_checkpoint_if_available(model, ema, opt, scaler, config, device, logger, args):
    """Restore weights (and optionally training state) from a checkpoint, if one exists.

    The checkpoint used is ``<results_dir>/<run_name>/checkpoints/latest.pth.tar``
    when that file exists, otherwise ``config['from_checkpoint']`` if it is set.
    Both ``model`` and ``ema`` are loaded (strictly) from the checkpoint's
    ``"ema"`` entry, after stripping ``_orig_mod.`` prefixes and remapping
    base-layout keys to the ``_v`` layout. Unless ``args.restart_from_checkpoint``
    is truthy, optimizer state, GradScaler state, epoch and step counters are
    restored too.

    Args:
        model: Uncompiled, un-wrapped model to load weights into.
        ema: EMA copy of ``model``.
        opt: Optimizer whose state may be restored.
        scaler: GradScaler, or ``None`` when mixed precision is disabled.
        config: Config dict; reads ``results_dir``, ``run_name`` and optional
            ``from_checkpoint``.
        device: Device index used to build ``map_location="cuda:<device>"``.
        logger: Logger for resume/restart messages.
        args: Parsed CLI args; reads ``restart_from_checkpoint``.

    Returns:
        ``(start_epoch, train_steps)``. This is ``(0, 0)`` when no checkpoint is
        found or when restarting from one.
    """
    start_epoch = 0
    train_steps = 0
    latest_path = os.path.join(config['results_dir'], config['run_name'], "checkpoints", "latest.pth.tar")
    if os.path.isfile(latest_path):
        ckpt_path = latest_path
    elif config.get('from_checkpoint', 0):
        ckpt_path = config['from_checkpoint']
    else:
        return start_epoch, train_steps

    print("Loading model from ", ckpt_path)
    # weights_only=False unpickles arbitrary Python objects (e.g. the saved
    # argparse Namespace); only load checkpoints from trusted locations.
    checkpoint = torch.load(ckpt_path, map_location=f"cuda:{device}", weights_only=False)

    ema_ckp = {k.replace('_orig_mod.', ''): v for k, v in checkpoint["ema"].items()}
    model_keys = {k.replace('_orig_mod.', '') for k in model.state_dict()}
    ema_ckp = _remap_ema_keys(ema_ckp, model_keys)
    # NOTE(review): the online model is initialised from the EMA weights as
    # well, even when the optimizer state below is resumed -- confirm intended.
    model.load_state_dict(ema_ckp, strict=True)
    print("Model weights loaded.")
    ema.load_state_dict(ema_ckp, strict=True)
    print("EMA weights loaded.")

    if args.restart_from_checkpoint:
        logger.info("Restarting training: epoch and step counters set to 0.")
        return start_epoch, train_steps

    if "opt" in checkpoint:
        opt_ckp = {k.replace('_orig_mod.', ''): v for k, v in checkpoint["opt"].items()}
        opt.load_state_dict(opt_ckp)
        print("Optimizer state loaded.")
    if "scaler" in checkpoint and scaler is not None:
        scaler.load_state_dict(checkpoint["scaler"])
        print("GradScaler state loaded.")
    if "epoch" in checkpoint:
        # The saved epoch was in progress; resume at the next one.
        start_epoch = checkpoint["epoch"] + 1
    if "train_steps" in checkpoint:
        train_steps = checkpoint["train_steps"]
    logger.info(f"Resuming from epoch {start_epoch}, step {train_steps}")
    return start_epoch, train_steps
|
|
|
|
@torch.no_grad()
def update_ema(ema_model, model, decay=0.9999):
    """
    Move every EMA parameter one step toward its live counterpart.

    For each parameter of ``model`` this performs, in place,
    ``ema = ema * decay + param * (1 - decay)``. Parameter names are matched
    after removing any ``_orig_mod.`` segment (present when the live model
    was wrapped by ``torch.compile``), so a compiled ``model`` can be paired
    with an uncompiled ``ema_model``.
    """
    targets = dict(ema_model.named_parameters())
    weight = 1 - decay
    for raw_name, source in model.named_parameters():
        target = targets[raw_name.replace('_orig_mod.', '')]
        target.mul_(decay)
        target.add_(source.data, alpha=weight)
|
|
|
|
def requires_grad(model, flag=True):
    """
    Turn gradient tracking on or off for every parameter of ``model``.

    Args:
        model: Module whose parameters are modified in place.
        flag: Value assigned to each parameter's ``requires_grad`` attribute.
    """
    for tensor in model.parameters():
        tensor.requires_grad = flag
|
|
|
|
def cleanup():
    """
    End DDP training by destroying the default process group.

    Safe to call when no process group is active (for example on an error
    path, or if called twice): it then does nothing instead of raising.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()
|
|
|
|
def create_logger(logging_dir):
    """
    Return this module's logger, set up according to the process rank.

    On rank 0, the root logger is configured via ``logging.basicConfig`` at
    INFO level, with a stream handler and a file handler writing to
    ``<logging_dir>/log.txt``. On every other rank, only a ``NullHandler`` is
    attached and no stream or file output is configured; ``logging_dir`` is
    not used there.
    """
    if dist.get_rank() != 0:
        quiet_logger = logging.getLogger(__name__)
        quiet_logger.addHandler(logging.NullHandler())
        return quiet_logger

    handlers = [logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")]
    logging.basicConfig(
        level=logging.INFO,
        format='[\033[34m%(asctime)s\033[0m] %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        handlers=handlers,
    )
    return logging.getLogger(__name__)
|
|
| |
| |
| |
|
|
def _build_datasets(config):
    """Build the concatenated train and test datasets described by ``config["datasets"]``.

    Per-dataset ``distance`` / ``len_traj_pred`` entries override the top-level
    ones. Test splits always use ``goals_per_obs=4`` and ``evaluate=True``.

    Returns:
        ``(train_dataset, test_dataset)`` as ``ConcatDataset`` instances.
    """
    train_parts = []
    test_parts = []
    for dataset_name, data_config in config["datasets"].items():
        for split in ("train", "test"):
            if split not in data_config:
                continue
            is_test = split == "test"
            goals_per_obs = 4 if is_test else int(data_config["goals_per_obs"])
            distance_cfg = data_config["distance"] if "distance" in data_config else config["distance"]
            len_traj_pred = data_config["len_traj_pred"] if "len_traj_pred" in data_config else config["len_traj_pred"]

            dataset = TrainingDataset(
                data_folder=data_config["data_folder"],
                data_split_folder=data_config[split],
                dataset_name=dataset_name,
                image_size=config["image_size"],
                min_dist_cat=distance_cfg["min_dist_cat"],
                max_dist_cat=distance_cfg["max_dist_cat"],
                len_traj_pred=len_traj_pred,
                context_size=config["context_size"],
                normalize=config["normalize"],
                goals_per_obs=goals_per_obs,
                transform=transform,
                predefined_index=None,
                traj_stride=1,
                evaluate=is_test,
            )
            (test_parts if is_test else train_parts).append(dataset)
            print(f"Dataset: {dataset_name} ({split}), size: {len(dataset)}")

    print(f"Combining {len(train_parts)} datasets.")
    return ConcatDataset(train_parts), ConcatDataset(test_parts)


def _optimizer_step(loss, model, opt, scaler, grad_clip_val):
    """Backpropagate ``loss`` and apply one optimizer update, optionally clipping gradients.

    When ``scaler`` is not ``None`` the loss is scaled and the step goes through
    the GradScaler; otherwise a plain backward/step is used.
    """
    # PyTorch accumulates into .grad across backward() calls, so clear the
    # previous step's gradients first (this was missing on the scaler path).
    opt.zero_grad(set_to_none=True)
    if scaler is None:
        loss.backward()
        if grad_clip_val > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip_val)
        opt.step()
    else:
        scaler.scale(loss).backward()
        if grad_clip_val > 0:
            # Unscale first so the clipping threshold applies to the real gradients.
            scaler.unscale_(opt)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip_val)
        scaler.step(opt)
        scaler.update()


def main(args):
    """
    Trains a new AVCDiT model.

    Initialises distributed training and loads the configuration: the YAML at
    ``args.config`` overrides ``config/eval_config.yaml`` top-level key by key.
    It then builds the VAE tokenizer, the AVCDiT model (``mode="v"``) and its
    EMA copy, and optionally restores a checkpoint. Finally it runs the
    diffusion training loop with periodic logging, checkpointing and evaluation.
    """
    assert torch.cuda.is_available(), "Training currently requires at least one GPU."

    _, rank, device, _ = init_distributed()
    world_size = dist.get_world_size()
    seed = args.global_seed * world_size + rank
    torch.manual_seed(seed)
    print(f"Starting rank={rank}, seed={seed}, world_size={world_size}.")

    # Shallow merge: top-level keys in the user config replace the defaults.
    with open("config/eval_config.yaml", "r") as f:
        config = yaml.safe_load(f) or {}
    with open(args.config, "r") as f:
        config.update(yaml.safe_load(f) or {})

    os.makedirs(config['results_dir'], exist_ok=True)
    experiment_dir = f"{config['results_dir']}/{config['run_name']}"
    checkpoint_dir = f"{experiment_dir}/checkpoints"
    if rank == 0:
        os.makedirs(checkpoint_dir, exist_ok=True)
        logger = create_logger(experiment_dir)
        logger.info(f"Experiment directory created at {experiment_dir}")
    else:
        logger = create_logger(None)

    assert config['image_size'] % 8 == 0, "Image size must be divisible by 8 (for the VAE encoder)."
    tokenizer = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema").to(device)
    latent_size = config['image_size'] // 8
    num_cond = config['context_size']
    model = AVCDiT_models[config['model']](context_size=num_cond, input_size=latent_size, in_channels=4, mode="v").to(device)
    ema = deepcopy(model).to(device)
    requires_grad(ema, False)

    lr = float(config.get('lr', 1e-4))
    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0)

    bfloat_enable = bool(getattr(args, 'bfloat16', False))
    scaler = torch.amp.GradScaler() if bfloat_enable else None

    # Load before compiling/wrapping so state-dict keys carry no
    # "_orig_mod." / "module." prefixes.
    start_epoch, train_steps = load_checkpoint_if_available(
        model, ema, opt, scaler, config, device, logger, args
    )

    if args.torch_compile:
        model = torch.compile(model)
    model = DDP(model, device_ids=[device])
    diffusion = create_diffusion(timestep_respacing="")
    logger.info(f"AVCDiT Parameters: {sum(p.numel() for p in model.parameters()):,}")

    train_dataset, test_dataset = _build_datasets(config)

    sampler = DistributedSampler(
        train_dataset,
        num_replicas=world_size,
        rank=rank,
        shuffle=True,
        seed=args.global_seed,
    )
    loader = DataLoader(
        train_dataset,
        batch_size=config['batch_size'],
        shuffle=False,
        sampler=sampler,
        num_workers=config['num_workers'],
        pin_memory=True,
        drop_last=True,
        # persistent_workers=True raises ValueError when num_workers == 0.
        persistent_workers=config['num_workers'] > 0,
    )
    logger.info(f"Dataset contains {len(train_dataset):,} images")

    model.train()
    ema.eval()

    grad_clip_val = float(config.get('grad_clip_val', 0))
    log_steps = 0
    running_loss = 0
    start_time = time()

    logger.info(f"Training for {args.epochs} epochs...")
    for epoch in range(start_epoch, args.epochs):
        sampler.set_epoch(epoch)
        steps_per_epoch = len(loader)
        if rank == 0:
            logger.info(f"Epoch {epoch} contains {steps_per_epoch} steps.")
        logger.info(f"Beginning epoch {epoch}...")

        for x, _, y, _, rel_t in loader:
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)
            rel_t = rel_t.to(device, non_blocking=True)

            with torch.amp.autocast('cuda', enabled=bfloat_enable, dtype=torch.bfloat16):
                with torch.no_grad():
                    # Encode every frame with the frozen VAE; x is (B, T, ...) and
                    # frames are flattened into the batch dimension for encoding.
                    B, T = x.shape[:2]
                    x = tokenizer.encode(x.flatten(0, 1)).latent_dist.sample().mul_(0.18215)
                    x = x.unflatten(0, (B, T))

                # The first num_cond frames are context; each remaining frame is a
                # separate prediction target paired with a copy of that context.
                num_goals = T - num_cond
                x_start = x[:, num_cond:].flatten(0, 1)
                x_cond = x[:, :num_cond].unsqueeze(1).expand(B, num_goals, num_cond, x.shape[2], x.shape[3], x.shape[4]).flatten(0, 1)
                y = y.flatten(0, 1)
                rel_t = rel_t.flatten(0, 1)

                t = torch.randint(0, diffusion.num_timesteps, (x_start.shape[0],), device=device)
                model_kwargs = dict(y=y, x_cond=x_cond, rel_t=rel_t)
                loss_dict = diffusion.training_losses(model, x_start, t, model_kwargs)
                loss = loss_dict["loss"].mean()

            _optimizer_step(loss, model, opt, scaler, grad_clip_val)
            update_ema(ema, model.module)

            running_loss += loss.detach().item()
            log_steps += 1
            train_steps += 1
            if train_steps % args.log_every == 0:
                # Wait for queued GPU work so the timing below is accurate.
                torch.cuda.synchronize()
                end_time = time()
                steps_per_sec = log_steps / (end_time - start_time)
                samples_per_sec = world_size * x_cond.shape[0] * steps_per_sec

                # Average the per-rank running loss across all processes.
                avg_loss = torch.tensor(running_loss / log_steps, device=device)
                dist.all_reduce(avg_loss, op=dist.ReduceOp.SUM)
                avg_loss = avg_loss.item() / world_size

                total_steps = len(loader) * args.epochs
                progress_pct = train_steps / total_steps * 100
                remaining_steps = total_steps - train_steps
                eta_seconds = remaining_steps / steps_per_sec if steps_per_sec > 0 else 0
                eta_hours = eta_seconds / 3600

                logger.info(f"(step={train_steps:07d}) Train Loss: {avg_loss:.4f}, Train Steps/Sec: {steps_per_sec:.2f}, Samples/Sec: {samples_per_sec:.2f}")
                logger.info(f"Progress: {progress_pct:.2f}% | ETA: {eta_hours:.1f}h")

                running_loss = 0
                log_steps = 0
                start_time = time()

            if train_steps % args.ckpt_every == 0 and train_steps > 0:
                if rank == 0:
                    checkpoint = {
                        "model": model.module.state_dict(),
                        "ema": ema.state_dict(),
                        "opt": opt.state_dict(),
                        "args": args,
                        "epoch": epoch,
                        "train_steps": train_steps,
                    }
                    if scaler is not None:
                        checkpoint["scaler"] = scaler.state_dict()
                    checkpoint_path = f"{checkpoint_dir}/latest.pth.tar"
                    torch.save(checkpoint, checkpoint_path)
                    # Every 10th checkpoint is also kept under a step-numbered name.
                    if train_steps % (10 * args.ckpt_every) == 0:
                        checkpoint_path = f"{checkpoint_dir}/{train_steps:07d}.pth.tar"
                        torch.save(checkpoint, checkpoint_path)
                    logger.info(f"Saved checkpoint to {checkpoint_path}")

            if train_steps % args.eval_every == 0 and train_steps > 0:
                eval_start_time = time()
                save_dir = os.path.join(experiment_dir, str(train_steps))
                sim_score_val = evaluate(ema, tokenizer, diffusion, test_dataset, rank, config["batch_size"], config["num_workers"], latent_size, device, save_dir, args.global_seed, bfloat_enable, num_cond)
                dist.barrier()
                eval_time = time() - eval_start_time
                logger.info(f"(step={train_steps:07d}) Val Perceptual Loss: {sim_score_val:.4f}, Eval Time: {eval_time:.2f}")

    model.eval()
    logger.info("Done!")
    cleanup()
|
|
|
|
def _to_uint8_image(image):
    """Convert a channel-first image tensor with values in [0, 1] to an HWC uint8 array.

    Values are clamped to [0, 1] first so out-of-range pixels saturate instead
    of wrapping around when cast to uint8.
    """
    return (image.float().clamp(0, 1).permute(1, 2, 0).cpu().numpy() * 255).astype('uint8')


def _save_eval_figures(x_cond_pixels, x_start_pixels, samples, save_dir, max_images=10):
    """Save one PNG per sample (last context frame | target frame | generated sample) to ``save_dir``."""
    os.makedirs(save_dir, exist_ok=True)
    for i in range(min(samples.shape[0], max_images)):
        _, ax = plt.subplots(1, 3, dpi=256)
        ax[0].imshow(_to_uint8_image(x_cond_pixels[i, -1]))
        ax[1].imshow(_to_uint8_image(x_start_pixels[i]))
        ax[2].imshow(_to_uint8_image(samples[i]))
        plt.savefig(f'{save_dir}/{i}.png')
        plt.close()


@torch.no_grad()
def evaluate(model, vae, diffusion, test_dataloaders, rank, batch_size, num_workers, latent_size, device, save_dir, seed, bfloat_enable, num_cond):
    """
    Score ``model`` with DreamSim on the first test batch of every rank.

    Despite its name, ``test_dataloaders`` is a dataset. It is sharded across
    ranks with a seeded ``DistributedSampler``, and only the first batch per
    rank is evaluated. Rank 0 also writes up to 10 comparison figures to
    ``save_dir``. Autocast to bfloat16 follows ``bfloat_enable``, matching
    training.

    Returns:
        0-dim tensor: DreamSim score summed over all evaluated samples on all
        ranks divided by their count (NaN if no rank produced a full batch).
    """
    sampler = DistributedSampler(
        test_dataloaders,
        num_replicas=dist.get_world_size(),
        rank=rank,
        shuffle=True,
        seed=seed,
    )
    loader = DataLoader(
        test_dataloaders,
        batch_size=batch_size,
        shuffle=False,
        sampler=sampler,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=True,
    )
    # Local import, so dreamsim is only required when evaluation runs.
    from dreamsim import dreamsim
    # NOTE(review): dreamsim() is called without a device argument; confirm it
    # lands on this rank's GPU (e.g. via torch.cuda.set_device in init_distributed).
    eval_model, _ = dreamsim(pretrained=True)
    score = torch.tensor(0.).to(device)
    n_samples = torch.tensor(0).to(device)
    samples = x_start_pixels = x_cond_pixels = None

    for x, _, y, _, rel_t, _ in loader:
        x = x.to(device)
        y = y.to(device)
        rel_t = rel_t.to(device).flatten(0, 1)
        with torch.amp.autocast('cuda', enabled=bfloat_enable, dtype=torch.bfloat16):
            B, T = x.shape[:2]
            num_goals = T - num_cond
            samples = model_forward_wrapper_v((model, diffusion, vae), x, y, num_timesteps=None, latent_size=latent_size, device=device, num_cond=num_cond, num_goals=num_goals, rel_t=rel_t)
            x_start_pixels = x[:, num_cond:].flatten(0, 1)
            x_cond_pixels = x[:, :num_cond].unsqueeze(1).expand(B, num_goals, num_cond, x.shape[2], x.shape[3], x.shape[4]).flatten(0, 1)
            # Rescale with x * 0.5 + 0.5 (presumably from [-1, 1] to [0, 1]).
            samples = samples * 0.5 + 0.5
            x_start_pixels = x_start_pixels * 0.5 + 0.5
            x_cond_pixels = x_cond_pixels * 0.5 + 0.5
            res = eval_model(x_start_pixels, samples)
            score += res.sum()
            n_samples += len(res)
        # Only the first batch is scored (presumably to keep periodic eval cheap).
        break

    # With drop_last=True the loader can be empty; skip plotting instead of
    # failing on undefined tensors.
    if rank == 0 and samples is not None:
        _save_eval_figures(x_cond_pixels, x_start_pixels, samples, save_dir)

    # Sum scores and sample counts over all ranks before averaging.
    dist.all_reduce(score)
    dist.all_reduce(n_samples)
    return score / n_samples
|
|
|
|
def get_args_parser():
    """
    Build the command-line parser for this training script.

    Every option except ``--config`` (required path to the user YAML) is an
    integer; the ``--bfloat16``, ``--torch-compile`` and
    ``--restart-from-checkpoint`` integers act as 0/1 switches.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    int_options = (
        ("--epochs", 300),
        ("--global-seed", 0),
        ("--log-every", 100),
        ("--ckpt-every", 2000),
        ("--eval-every", 5000),
        ("--bfloat16", 1),
        ("--torch-compile", 1),
    )
    for flag, default_value in int_options:
        parser.add_argument(flag, type=int, default=default_value)
    parser.add_argument(
        "--restart-from-checkpoint",
        type=int,
        default=0,
        help="If 1, only load model weights and reset epoch/step to zero (cold start)",
    )
    return parser
|
|
if __name__ == "__main__":
    # Script entry point: parse CLI arguments and launch training.
    main(get_args_parser().parse_args())
|
|