import os
import math
import re
import torch
import numpy as np
import random
import gc
from datetime import datetime
from pathlib import Path

import torchvision.transforms as transforms
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.optim.lr_scheduler import LambdaLR
from diffusers import AutoencoderKL, AsymmetricAutoencoderKL
from diffusers import AutoencoderKLQwenImage
from diffusers import AutoencoderKLWan

from accelerate import Accelerator
from PIL import Image, UnidentifiedImageError
from tqdm import tqdm
import bitsandbytes as bnb
import wandb
import lpips
from collections import deque

ds_path = "/workspace/d23"
project = "vae3"
batch_size = 5
base_learning_rate = 5e-5
min_learning_rate = 1e-5
num_epochs = 50
sample_interval_share = 2
use_wandb = True
save_model = True
use_decay = True
optimizer_type = "adam8bit"
dtype = torch.float32

model_resolution = 256
high_resolution = 512
limit = 0
save_barrier = 1.3
warmup_percent = 0.001
percentile_clipping = 99
beta2 = 0.997
eps = 1e-8
clip_grad_norm = 1.0
mixed_precision = "no"
gradient_accumulation_steps = 2
generated_folder = "samples"
save_as = "vae3"
num_workers = 0
device = None

train_decoder_only = True
full_training = False
kl_ratio = 0.00

loss_ratios = {
    "lpips": 0.75,
    "edge": 0.05,
    "mse": 0.10,
    "mae": 0.10,
    "kl": 0.00,
}
median_coeff_steps = 1000
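
# Note: these ratios are desired *shares* of the total loss, not raw weights.
# MedianLossNormalizer (defined below) rescales each term by the running median
# of its last `median_coeff_steps` values, so the weighted terms contribute
# roughly in these proportions regardless of their absolute magnitudes.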

resize_long_side = 1280

vae_kind = "kl"  # one of: "kl", "wan", "qwen"

Path(generated_folder).mkdir(parents=True, exist_ok=True)

accelerator = Accelerator(
    mixed_precision=mixed_precision,
    gradient_accumulation_steps=gradient_accumulation_steps,
)
device = accelerator.device

# Seed derived from the current date: runs started on the same day are reproducible.
seed = int(datetime.now().strftime("%Y%m%d"))
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
torch.backends.cudnn.benchmark = False

if use_wandb and accelerator.is_main_process:
    wandb.init(project=project, config={
        "batch_size": batch_size,
        "base_learning_rate": base_learning_rate,
        "num_epochs": num_epochs,
        "optimizer_type": optimizer_type,
        "model_resolution": model_resolution,
        "high_resolution": high_resolution,
        "gradient_accumulation_steps": gradient_accumulation_steps,
        "train_decoder_only": train_decoder_only,
        "full_training": full_training,
        "kl_ratio": kl_ratio,
        "vae_kind": vae_kind,
    })

def get_core_model(model):
    """Unwrap a torch.compile()-wrapped module to reach the underlying model."""
    m = model
    if hasattr(m, "_orig_mod"):
        m = m._orig_mod
    return m

def is_video_vae(model) -> bool:
    if vae_kind in ("wan", "qwen"):
        return True
    try:
        core = get_core_model(model)
        enc = getattr(core, "encoder", None)
        conv_in = getattr(enc, "conv_in", None)
        w = getattr(conv_in, "weight", None)
        if isinstance(w, torch.nn.Parameter):
            return w.ndim == 5
    except Exception:
        pass
    return False
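
# A 5-D conv_in weight (out, in, t, h, w) indicates a 3-D/video VAE, which
# expects inputs of shape (B, C, T, H, W); the forward passes below add and
# remove the temporal axis accordingly via unsqueeze(2)/squeeze(2).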

if vae_kind == "qwen":
    vae = AutoencoderKLQwenImage.from_pretrained("Qwen/Qwen-Image", subfolder="vae")
elif vae_kind == "wan":
    vae = AutoencoderKLWan.from_pretrained(project)
elif model_resolution == high_resolution:
    vae = AutoencoderKL.from_pretrained(project)
else:
    vae = AsymmetricAutoencoderKL.from_pretrained(project)

vae = vae.to(dtype)

if hasattr(torch, "compile"):
    try:
        vae = torch.compile(vae)
    except Exception as e:
        print(f"[WARN] torch.compile failed: {e}")

core = get_core_model(vae)

# Freeze everything, then selectively unfreeze the parts we train.
for p in core.parameters():
    p.requires_grad = False

unfrozen_param_names = []

if full_training and not train_decoder_only:
    for name, p in core.named_parameters():
        p.requires_grad = True
        unfrozen_param_names.append(name)
    loss_ratios["kl"] = float(kl_ratio)
    trainable_module = core
else:
    if hasattr(core, "decoder"):
        if hasattr(core.decoder, "up_blocks") and len(core.decoder.up_blocks) > 0:
            for name, p in core.decoder.up_blocks[0].named_parameters():
                p.requires_grad = True
                unfrozen_param_names.append(f"decoder.up_blocks.0.{name}")
        else:
            print("[WARN] Decoder has no up_blocks; falling back to full decoder")
            for name, p in core.decoder.named_parameters():
                p.requires_grad = True
                unfrozen_param_names.append(f"decoder.{name}")
    if hasattr(core, "post_quant_conv"):
        for name, p in core.post_quant_conv.named_parameters():
            p.requires_grad = True
            unfrozen_param_names.append(f"post_quant_conv.{name}")
    trainable_module = core.decoder if hasattr(core, "decoder") else core

print(f"[INFO] Unfrozen parameters: {len(unfrozen_param_names)}. First 200 names:")
for nm in unfrozen_param_names[:200]:
    print(" ", nm)

class PngFolderDataset(Dataset):
    def __init__(self, root_dir, min_exts=('.png',), resolution=1024, limit=0):
        self.root_dir = root_dir
        self.resolution = resolution
        self.paths = []
        for root, _, files in os.walk(root_dir):
            for fname in files:
                if fname.lower().endswith(tuple(ext.lower() for ext in min_exts)):
                    self.paths.append(os.path.join(root, fname))
        if limit:
            self.paths = self.paths[:limit]
        valid = []
        for p in self.paths:
            try:
                with Image.open(p) as im:
                    im.verify()
                valid.append(p)
            except (OSError, UnidentifiedImageError):
                continue
        self.paths = valid
        if len(self.paths) == 0:
            raise RuntimeError(f"No valid PNG images found under {root_dir}")
        random.shuffle(self.paths)

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        p = self.paths[idx % len(self.paths)]
        with Image.open(p) as img:
            img = img.convert("RGB")
        # Optionally downscale so the long side is at most resize_long_side.
        if not resize_long_side or resize_long_side <= 0:
            return img
        w, h = img.size
        long = max(w, h)
        if long <= resize_long_side:
            return img
        scale = resize_long_side / float(long)
        new_w = int(round(w * scale))
        new_h = int(round(h * scale))
        return img.resize((new_w, new_h), Image.BICUBIC)

def random_crop(img, sz):
    w, h = img.size
    if w < sz or h < sz:
        img = img.resize((max(sz, w), max(sz, h)), Image.BICUBIC)
    x = random.randint(0, max(0, img.width - sz))
    y = random.randint(0, max(0, img.height - sz))
    return img.crop((x, y, x + sz, y + sz))

tfm = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),  # [0, 1] -> [-1, 1]
])

dataset = PngFolderDataset(ds_path, min_exts=('.png',), resolution=high_resolution, limit=limit)
if len(dataset) < batch_size:
    raise RuntimeError(f"Not enough valid images ({len(dataset)}) to form a batch of size {batch_size}")

def collate_fn(batch):
    imgs = []
    for img in batch:
        img = random_crop(img, high_resolution)
        imgs.append(tfm(img))
    return torch.stack(imgs)
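
# The dataset yields PIL images; cropping to high_resolution and tensor
# conversion happen here in the collate step, so every epoch sees fresh
# random crops of the same source files.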

dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=num_workers,
    pin_memory=True,
    drop_last=True,
)

def get_param_groups(module, weight_decay=0.001):
    no_decay_tokens = ("bias", "norm", "rms", "layernorm")
    decay_params, no_decay_params = [], []
    for n, p in module.named_parameters():
        if not p.requires_grad:
            continue
        n_l = n.lower()
        if any(t in n_l for t in no_decay_tokens):
            no_decay_params.append(p)
        else:
            decay_params.append(p)
    return [
        {"params": decay_params, "weight_decay": weight_decay},
        {"params": no_decay_params, "weight_decay": 0.0},
    ]
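
# Standard AdamW practice: biases and normalization parameters get no weight
# decay; matching is done by substring on the lowercased parameter name.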

def create_optimizer(name, param_groups):
    if name == "adam8bit":
        return bnb.optim.AdamW8bit(param_groups, lr=base_learning_rate, betas=(0.9, beta2), eps=eps)
    raise ValueError(name)

param_groups = get_param_groups(get_core_model(vae), weight_decay=0.001)
optimizer = create_optimizer(optimizer_type, param_groups)

batches_per_epoch = len(dataloader)
steps_per_epoch = int(math.ceil(batches_per_epoch / float(gradient_accumulation_steps)))
total_steps = steps_per_epoch * num_epochs

def lr_lambda(step):
    if not use_decay:
        return 1.0
    x = float(step) / float(max(1, total_steps))
    warmup = float(warmup_percent)
    min_ratio = float(min_learning_rate) / float(base_learning_rate)
    if x < warmup:
        return min_ratio + (1.0 - min_ratio) * (x / warmup)
    decay_ratio = (x - warmup) / (1.0 - warmup)
    return min_ratio + 0.5 * (1.0 - min_ratio) * (1.0 + math.cos(math.pi * decay_ratio))
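
# Schedule shape: linear warmup from min_ratio to 1.0 over the first
# warmup_percent of training, then cosine decay back to min_ratio, so
# lr(0) = min_learning_rate, lr(warmup end) = base_learning_rate, and
# lr(total_steps) = min_learning_rate.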

scheduler = LambdaLR(optimizer, lr_lambda)

dataloader, vae, optimizer, scheduler = accelerator.prepare(dataloader, vae, optimizer, scheduler)
trainable_params = [p for p in vae.parameters() if p.requires_grad]

# LPIPS network is loaded lazily on first use.
_lpips_net = None
def _get_lpips():
    global _lpips_net
    if _lpips_net is None:
        _lpips_net = lpips.LPIPS(net='vgg', verbose=False).eval().to(accelerator.device)
    return _lpips_net

_sobel_kx = torch.tensor([[[[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]]], dtype=torch.float32)
_sobel_ky = torch.tensor([[[[-1, -2, -1], [0, 0, 0], [1, 2, 1]]]], dtype=torch.float32)

def sobel_edges(x: torch.Tensor) -> torch.Tensor:
    C = x.shape[1]
    kx = _sobel_kx.to(x.device, x.dtype).repeat(C, 1, 1, 1)
    ky = _sobel_ky.to(x.device, x.dtype).repeat(C, 1, 1, 1)
    gx = F.conv2d(x, kx, padding=1, groups=C)
    gy = F.conv2d(x, ky, padding=1, groups=C)
    return torch.sqrt(gx * gx + gy * gy + 1e-12)
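
# groups=C makes this a depthwise convolution: each channel is filtered
# independently with the same 3x3 Sobel kernels, and the gradient magnitude
# (with a small epsilon for numerical stability) feeds the edge loss.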

class MedianLossNormalizer:
    def __init__(self, desired_ratios: dict, window_steps: int):
        s = sum(desired_ratios.values())
        self.ratios = {k: (v / s) if s > 0 else 0.0 for k, v in desired_ratios.items()}
        self.buffers = {k: deque(maxlen=window_steps) for k in self.ratios.keys()}
        self.window = window_steps

    def update_and_total(self, abs_losses: dict):
        for k, v in abs_losses.items():
            if k in self.buffers:
                self.buffers[k].append(float(v.detach().abs().cpu()))
        meds = {k: (np.median(self.buffers[k]) if len(self.buffers[k]) > 0 else 1.0) for k in self.buffers}
        coeffs = {k: (self.ratios[k] / max(meds[k], 1e-12)) for k in self.ratios}
        total = sum(coeffs[k] * abs_losses[k] for k in abs_losses if k in coeffs)
        return total, coeffs, meds
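
# Each coefficient is ratio_k / median_k over a sliding window, so
# coeff_k * loss_k ~= ratio_k whenever loss_k sits near its median. At
# equilibrium the total is ~1.0, with each term contributing its target share.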

if full_training and not train_decoder_only:
    loss_ratios["kl"] = float(kl_ratio)
normalizer = MedianLossNormalizer(loss_ratios, median_coeff_steps)

@torch.no_grad()
def get_fixed_samples(n=3):
    idx = random.sample(range(len(dataset)), min(n, len(dataset)))
    pil_imgs = [dataset[i] for i in idx]
    tensors = []
    for img in pil_imgs:
        img = random_crop(img, high_resolution)
        tensors.append(tfm(img))
    return torch.stack(tensors).to(accelerator.device, dtype)

fixed_samples = get_fixed_samples()

@torch.no_grad()
def _to_pil_uint8(img_tensor: torch.Tensor) -> Image.Image:
    # [-1, 1] float CHW -> [0, 255] uint8 HWC
    arr = ((img_tensor.float().clamp(-1, 1) + 1.0) * 127.5).clamp(0, 255).byte().cpu().numpy().transpose(1, 2, 0)
    return Image.fromarray(arr)

@torch.no_grad()
def generate_and_save_samples(step=None):
    try:
        temp_vae = accelerator.unwrap_model(vae).eval()
        lpips_net = _get_lpips()
        orig_high = fixed_samples
        orig_low = F.interpolate(
            orig_high,
            size=(model_resolution, model_resolution),
            mode="bilinear",
            align_corners=False,
        )
        model_dtype = next(temp_vae.parameters()).dtype
        orig_low = orig_low.to(dtype=model_dtype)

        # Encode/decode; video VAEs expect a temporal axis.
        if is_video_vae(temp_vae):
            x_in = orig_low.unsqueeze(2)
            enc = temp_vae.encode(x_in)
            latents_mean = enc.latent_dist.mean
            dec = temp_vae.decode(latents_mean).sample
            rec = dec.squeeze(2)
        else:
            enc = temp_vae.encode(orig_low)
            latents_mean = enc.latent_dist.mean
            rec = temp_vae.decode(latents_mean).sample

        if rec.shape[-2:] != orig_high.shape[-2:]:
            rec = F.interpolate(rec, size=orig_high.shape[-2:], mode="bilinear", align_corners=False)

        for i in range(rec.shape[0]):
            real_img = _to_pil_uint8(orig_high[i])
            dec_img = _to_pil_uint8(rec[i])
            real_img.save(f"{generated_folder}/sample_real_{i}.jpg", quality=95)
            dec_img.save(f"{generated_folder}/sample_decoded_{i}.jpg", quality=95)

        lpips_scores = []
        for i in range(rec.shape[0]):
            orig_full = orig_high[i:i+1].to(torch.float32)
            rec_full = rec[i:i+1].to(torch.float32)
            if rec_full.shape[-2:] != orig_full.shape[-2:]:
                rec_full = F.interpolate(rec_full, size=orig_full.shape[-2:], mode="bilinear", align_corners=False)
            lpips_val = lpips_net(orig_full, rec_full).item()
            lpips_scores.append(lpips_val)
        avg_lpips = float(np.mean(lpips_scores))

        if use_wandb and accelerator.is_main_process:
            log_data = {"lpips_mean": avg_lpips}
            for i in range(rec.shape[0]):
                log_data[f"sample/real_{i}"] = wandb.Image(f"{generated_folder}/sample_real_{i}.jpg", caption=f"real_{i}")
                log_data[f"sample/decoded_{i}"] = wandb.Image(f"{generated_folder}/sample_decoded_{i}.jpg", caption=f"decoded_{i}")
            wandb.log(log_data, step=step)

    finally:
        gc.collect()
        torch.cuda.empty_cache()

# Alternate sampling routine; kept for reference, not called by the training loop.
@torch.no_grad()
def generate_and_save_samples2(step=None):
    try:
        temp_vae = accelerator.unwrap_model(vae).eval()
        lpips_net = _get_lpips()
        orig_high = fixed_samples
        orig_low = F.interpolate(orig_high, size=(model_resolution, model_resolution), mode="bilinear", align_corners=False)
        model_dtype = next(temp_vae.parameters()).dtype
        orig_low = orig_low.to(dtype=model_dtype)

        if is_video_vae(temp_vae):
            x_in = orig_low.unsqueeze(2)
            enc = temp_vae.encode(x_in)
            latents_mean = enc.latent_dist.mean
            dec = temp_vae.decode(latents_mean).sample
            rec = dec.squeeze(2)
        else:
            enc = temp_vae.encode(orig_low)
            latents_mean = enc.latent_dist.mean
            rec = temp_vae.decode(latents_mean).sample

        if rec.shape[-2:] != orig_high.shape[-2:]:
            rec = F.interpolate(rec, size=orig_high.shape[-2:], mode="bilinear", align_corners=False)

        first_real = _to_pil_uint8(orig_high[0])
        first_dec = _to_pil_uint8(rec[0])
        first_real.save(f"{generated_folder}/sample_real.jpg", quality=95)
        first_dec.save(f"{generated_folder}/sample_decoded.jpg", quality=95)

        for i in range(rec.shape[0]):
            _to_pil_uint8(rec[i]).save(f"{generated_folder}/sample_{i}.jpg", quality=95)

        lpips_scores = []
        for i in range(rec.shape[0]):
            orig_full = orig_high[i:i+1].to(torch.float32)
            rec_full = rec[i:i+1].to(torch.float32)
            if rec_full.shape[-2:] != orig_full.shape[-2:]:
                rec_full = F.interpolate(rec_full, size=orig_full.shape[-2:], mode="bilinear", align_corners=False)
            lpips_val = lpips_net(orig_full, rec_full).item()
            lpips_scores.append(lpips_val)
        avg_lpips = float(np.mean(lpips_scores))

        if use_wandb and accelerator.is_main_process:
            wandb.log({"lpips_mean": avg_lpips}, step=step)
            wandb.log({
                "sample/real": wandb.Image(first_real, caption="real"),
                "sample/decoded": wandb.Image(first_dec, caption="decoded"),
            }, step=step)
    finally:
        gc.collect()
        torch.cuda.empty_cache()

if accelerator.is_main_process and save_model:
    print("Generating samples before training starts...")
    generate_and_save_samples(0)

accelerator.wait_for_everyone()

progress = tqdm(total=total_steps, disable=not accelerator.is_local_main_process)
global_step = 0
min_loss = float("inf")
sample_interval = max(1, total_steps // max(1, sample_interval_share * num_epochs))
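
# With sample_interval_share = 2 this triggers sampling/checkpointing roughly
# twice per epoch: total_steps / (share * num_epochs) = steps_per_epoch / share.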

for epoch in range(num_epochs):
    vae.train()
    batch_losses, batch_grads = [], []
    track_losses = {k: [] for k in loss_ratios.keys()}

    for imgs in dataloader:
        with accelerator.accumulate(vae):
            imgs = imgs.to(accelerator.device)

            # The model sees a downscaled input; losses compare the
            # reconstruction against the high-resolution target.
            if high_resolution != model_resolution:
                imgs_low = F.interpolate(imgs, size=(model_resolution, model_resolution), mode="bilinear", align_corners=False)
            else:
                imgs_low = imgs

            model_dtype = next(vae.parameters()).dtype
            imgs_low_model = imgs_low.to(dtype=model_dtype) if imgs_low.dtype != model_dtype else imgs_low

            # Use the posterior mean when only the decoder is trained;
            # sample from the posterior when training end-to-end.
            if is_video_vae(vae):
                x_in = imgs_low_model.unsqueeze(2)
                enc = vae.encode(x_in)
                latents = enc.latent_dist.mean if train_decoder_only else enc.latent_dist.sample()
                dec = vae.decode(latents).sample
                rec = dec.squeeze(2)
            else:
                enc = vae.encode(imgs_low_model)
                latents = enc.latent_dist.mean if train_decoder_only else enc.latent_dist.sample()
                rec = vae.decode(latents).sample

            if rec.shape[-2:] != imgs.shape[-2:]:
                rec = F.interpolate(rec, size=imgs.shape[-2:], mode="bilinear", align_corners=False)

            rec_f32 = rec.to(torch.float32)
            imgs_f32 = imgs.to(torch.float32)

            abs_losses = {
                "mae": F.l1_loss(rec_f32, imgs_f32),
                "mse": F.mse_loss(rec_f32, imgs_f32),
                "lpips": _get_lpips()(rec_f32, imgs_f32).mean(),
                "edge": F.l1_loss(sobel_edges(rec_f32), sobel_edges(imgs_f32)),
            }

            if full_training and not train_decoder_only:
                mean = enc.latent_dist.mean
                logvar = enc.latent_dist.logvar
                kl = -0.5 * torch.mean(1 + logvar - mean.pow(2) - logvar.exp())
                abs_losses["kl"] = kl
            else:
                abs_losses["kl"] = torch.tensor(0.0, device=accelerator.device, dtype=torch.float32)

            total_loss, coeffs, meds = normalizer.update_and_total(abs_losses)

            if torch.isnan(total_loss) or torch.isinf(total_loss):
                raise RuntimeError("NaN/Inf loss")

            accelerator.backward(total_loss)

            grad_norm = torch.tensor(0.0, device=accelerator.device)
            if accelerator.sync_gradients:
                grad_norm = accelerator.clip_grad_norm_(trainable_params, clip_grad_norm)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad(set_to_none=True)
                global_step += 1
                progress.update(1)

            if accelerator.is_main_process:
                try:
                    current_lr = optimizer.param_groups[0]["lr"]
                except Exception:
                    current_lr = scheduler.get_last_lr()[0]

                batch_losses.append(total_loss.detach().item())
                batch_grads.append(float(grad_norm.detach().cpu().item()) if isinstance(grad_norm, torch.Tensor) else float(grad_norm))
                for k, v in abs_losses.items():
                    track_losses[k].append(float(v.detach().item()))

                if use_wandb and accelerator.sync_gradients:
                    log_dict = {
                        "total_loss": float(total_loss.detach().item()),
                        "learning_rate": current_lr,
                        "epoch": epoch,
                        "grad_norm": batch_grads[-1],
                    }
                    for k, v in abs_losses.items():
                        log_dict[f"loss_{k}"] = float(v.detach().item())
                    for k in coeffs:
                        log_dict[f"coeff_{k}"] = float(coeffs[k])
                        log_dict[f"median_{k}"] = float(meds[k])
                    wandb.log(log_dict, step=global_step)

            if global_step > 0 and global_step % sample_interval == 0:
                if accelerator.is_main_process:
                    generate_and_save_samples(global_step)
                accelerator.wait_for_everyone()

                # Rolling averages over the last sample interval (in micro-batches).
                n_micro = sample_interval * gradient_accumulation_steps
                avg_loss = float(np.mean(batch_losses[-n_micro:])) if batch_losses else float("nan")
                avg_grad = float(np.mean(batch_grads[-n_micro:])) if batch_grads else 0.0

                if accelerator.is_main_process:
                    print(f"Epoch {epoch} step {global_step} loss: {avg_loss:.6f}, grad_norm: {avg_grad:.6f}, lr: {current_lr:.9f}")
                    if save_model and avg_loss < min_loss * save_barrier:
                        min_loss = avg_loss
                        accelerator.unwrap_model(vae).save_pretrained(save_as)
                    if use_wandb:
                        wandb.log({"interm_loss": avg_loss, "interm_grad": avg_grad}, step=global_step)

    if accelerator.is_main_process:
        epoch_avg = float(np.mean(batch_losses)) if batch_losses else float("nan")
        print(f"Epoch {epoch} done, avg loss {epoch_avg:.6f}")
        if use_wandb:
            wandb.log({"epoch_loss": epoch_avg, "epoch": epoch + 1}, step=global_step)

if accelerator.is_main_process:
    print("Training finished, saving final model")
    if save_model:
        accelerator.unwrap_model(vae).save_pretrained(save_as)

accelerator.free_memory()
if torch.distributed.is_initialized():
    torch.distributed.destroy_process_group()
print("Done!")