""" UNIStainNet: Pixel-Space UNI-Guided Virtual Staining Network. Architecture: Generator: SPADE-UNet conditioned on UNI pathology features + stain/class embedding Discriminator: Multi-scale PatchGAN (512 + 256) Losses: LPIPS@128 + adversarial + DAB intensity + DAB contrast References: - Park et al., "Semantic Image Synthesis with SPADE" (CVPR 2019) - Chen et al., "A general-purpose self-supervised model for pathology" (Nature Medicine 2024) - Isola et al., "Image-to-Image Translation with pix2pix" (CVPR 2017) """ import copy from collections import OrderedDict import torch import torch.nn as nn import torch.nn.functional as F import lpips import pytorch_lightning as pl import torchvision import wandb from src.models.discriminator import ( PatchDiscriminator, MultiScaleDiscriminator, hinge_loss_d, hinge_loss_g, r1_gradient_penalty, feature_matching_loss, ) from src.models.generator import SPADEUNetGenerator from src.models.losses import VGGFeatureExtractor, gram_matrix, PatchNCELoss from src.utils.dab import DABExtractor # ====================================================================== # Training Module # ====================================================================== class UNIStainNetTrainer(pl.LightningModule): """PyTorch Lightning training module for UNIStainNet. Handles GAN training with manual optimization, CFG dropout, EMA, and all loss computations. """ def __init__( self, # Architecture num_classes=5, null_class=4, class_dim=64, uni_dim=1024, ndf=64, disc_n_layers=3, input_skip=False, # Optimizer gen_lr=1e-4, disc_lr=4e-4, warmup_steps=1000, # Loss weights lpips_weight=1.0, lpips_256_weight=0.5, lpips_512_weight=0.0, adversarial_weight=1.0, dab_intensity_weight=0.1, dab_contrast_weight=0.05, dab_sharpness_weight=0.0, gram_style_weight=0.0, edge_weight=0.0, he_edge_weight=0.0, bg_white_weight=0.0, bg_threshold=0.85, l1_lowres_weight=0.0, edge_encoder=False, edge_base_ch=32, uni_spatial_size=4, uncond_disc_weight=0.0, crop_disc_weight=0.0, crop_size=128, feat_match_weight=0.0, patchnce_weight=0.0, patchnce_layers=(2, 3, 4), patchnce_n_patches=256, patchnce_temperature=0.07, # Ablation disable_uni=False, disable_class=False, # GAN training r1_weight=10.0, r1_every=16, adversarial_start_step=2000, # CFG cfg_drop_class_prob=0.10, cfg_drop_uni_prob=0.10, cfg_drop_both_prob=0.05, # EMA ema_decay=0.999, # On-the-fly UNI extraction (for crop-based training) extract_uni_on_the_fly=False, uni_spatial_pool_size=32, # Resolution image_size=512, # 1024 architecture: extend UNI SPADE to 512 level uni_spade_at_512=False, # Per-label names for multi-stain logging label_names=None, ): super().__init__() self.save_hyperparameters() self.automatic_optimization = False self.null_class = null_class # On-the-fly UNI feature extraction (loaded lazily on first use) self._uni_model = None self._uni_extract_on_the_fly = extract_uni_on_the_fly # Generator self.generator = SPADEUNetGenerator( num_classes=num_classes, class_dim=class_dim, uni_dim=uni_dim, input_skip=input_skip, edge_encoder=edge_encoder, edge_base_ch=edge_base_ch, uni_spatial_size=uni_spatial_size, image_size=image_size, uni_spade_at_512=uni_spade_at_512, ) # Discriminator (global multi-scale) self.discriminator = MultiScaleDiscriminator( in_channels=6, ndf=ndf, n_layers=disc_n_layers, ) # Crop discriminator (local full-res detail) if crop_disc_weight > 0: self.crop_discriminator = PatchDiscriminator( in_channels=6, ndf=ndf, n_layers=disc_n_layers, ) else: self.crop_discriminator = None # Unconditional discriminator (HER2-only, alignment-free texture judge) # Also needed for feature matching loss (FM uses uncond disc features) if uncond_disc_weight > 0 or feat_match_weight > 0: self.uncond_discriminator = PatchDiscriminator( in_channels=3, ndf=ndf, n_layers=disc_n_layers, ) else: self.uncond_discriminator = None # PatchNCE loss (contrastive, alignment-free: H&E input vs generated) if patchnce_weight > 0: # Encoder channel dims: {1: 64, 2: 128, 3: 256, 4: 512} enc_channels = {1: 64, 2: 128, 3: 256, 4: 512} layer_channels = {l: enc_channels[l] for l in patchnce_layers} self.patchnce_loss = PatchNCELoss( layer_channels=layer_channels, num_patches=patchnce_n_patches, temperature=patchnce_temperature, ) else: self.patchnce_loss = None # EMA generator self.generator_ema = copy.deepcopy(self.generator) self.generator_ema.requires_grad_(False) # Losses self.lpips_fn = lpips.LPIPS(net='alex') self.lpips_fn.requires_grad_(False) self.lpips_fn.eval() self.dab_extractor = DABExtractor(device='cpu') # VGG feature extractor for Gram-matrix style loss if gram_style_weight > 0: self.vgg_extractor = VGGFeatureExtractor() self.vgg_extractor.requires_grad_(False) self.vgg_extractor.eval() else: self.vgg_extractor = None # Param counts n_gen = sum(p.numel() for p in self.generator.parameters()) n_disc = sum(p.numel() for p in self.discriminator.parameters()) n_crop = sum(p.numel() for p in self.crop_discriminator.parameters()) if self.crop_discriminator else 0 n_uncond = sum(p.numel() for p in self.uncond_discriminator.parameters()) if self.uncond_discriminator else 0 print(f"Generator: {n_gen:,} params") print(f"Discriminator: {n_disc:,} params (global) + {n_crop:,} (crop) + {n_uncond:,} (uncond)") def configure_optimizers(self): gen_params = list(self.generator.parameters()) if self.patchnce_loss is not None: gen_params += list(self.patchnce_loss.parameters()) opt_g = torch.optim.Adam( gen_params, lr=self.hparams.gen_lr, betas=(0.0, 0.999), ) # All discriminator params in one optimizer disc_params = list(self.discriminator.parameters()) if self.crop_discriminator is not None: disc_params += list(self.crop_discriminator.parameters()) if self.uncond_discriminator is not None: disc_params += list(self.uncond_discriminator.parameters()) opt_d = torch.optim.Adam( disc_params, lr=self.hparams.disc_lr, betas=(0.0, 0.999), ) return [opt_g, opt_d] def _get_lr_scale(self): """Linear warmup.""" if self.global_step < self.hparams.warmup_steps: return self.global_step / max(1, self.hparams.warmup_steps) return 1.0 @torch.no_grad() def _update_ema(self): """Update EMA generator weights.""" decay = self.hparams.ema_decay for p_ema, p in zip(self.generator_ema.parameters(), self.generator.parameters()): p_ema.data.mul_(decay).add_(p.data, alpha=1 - decay) def on_save_checkpoint(self, checkpoint): """Exclude frozen UNI model from checkpoint (it's reloaded on-the-fly).""" state_dict = checkpoint.get('state_dict', {}) keys_to_remove = [k for k in state_dict if k.startswith('_uni_model.')] for k in keys_to_remove: del state_dict[k] def on_load_checkpoint(self, checkpoint): """Filter out UNI model keys from old checkpoints that included them.""" state_dict = checkpoint.get('state_dict', {}) keys_to_remove = [k for k in state_dict if k.startswith('_uni_model.')] for k in keys_to_remove: del state_dict[k] def _load_uni_model(self): """Lazily load UNI ViT-L/16 for on-the-fly feature extraction.""" if self._uni_model is None: import timm self._uni_model = timm.create_model( "hf-hub:MahmoodLab/uni", pretrained=True, init_values=1e-5, dynamic_img_size=True, ) self._uni_model.eval() self._uni_model.requires_grad_(False) self._uni_model = self._uni_model.to(self.device) n_params = sum(p.numel() for p in self._uni_model.parameters()) print(f"UNI model loaded for on-the-fly extraction: {n_params:,} params") return self._uni_model @torch.no_grad() def _extract_uni_from_sub_crops(self, uni_sub_crops): """Extract UNI features from pre-prepared sub-crops on GPU. Args: uni_sub_crops: [B, 16, 3, 224, 224] — batch of 4x4 sub-crop grids, already normalized with ImageNet stats. Returns: uni_features: [B, S*S, 1024] where S = uni_spatial_pool_size (default 32) """ uni_model = self._load_uni_model() B = uni_sub_crops.shape[0] spatial_size = self.hparams.uni_spatial_pool_size num_crops = 4 # 4x4 grid patches_per_side = 14 # 224/16 # Batched UNI forward: [B, 16, 3, 224, 224] -> [B*16, 3, 224, 224] all_crops = uni_sub_crops.reshape(B * 16, 3, 224, 224).to(self.device) all_feats = uni_model.forward_features(all_crops) # [B*16, 197, 1024] patch_tokens = all_feats[:, 1:, :] # [B*16, 196, 1024] # Reshape back to per-sample grids: [B, 4, 4, 14, 14, 1024] patch_tokens = patch_tokens.reshape( B, num_crops, num_crops, patches_per_side, patches_per_side, 1024 ) # Interleave to spatial grid: [B, 56, 56, 1024] full_size = num_crops * patches_per_side # 56 full_grid = patch_tokens.permute(0, 1, 3, 2, 4, 5) full_grid = full_grid.reshape(B, full_size, full_size, 1024) # Pool to target spatial size (batched) if spatial_size < full_size: grid_bchw = full_grid.permute(0, 3, 1, 2) # [B, 1024, 56, 56] pooled = F.adaptive_avg_pool2d(grid_bchw, spatial_size) # [B, 1024, S, S] result = pooled.permute(0, 2, 3, 1) # [B, S, S, 1024] else: result = full_grid S = result.shape[1] return result.reshape(B, S * S, 1024) # [B, S*S, 1024] def _apply_cfg_dropout(self, labels, uni_features): """Apply classifier-free guidance dropout during training (vectorized).""" B = labels.shape[0] device = labels.device new_labels = labels.clone() new_uni = uni_features.clone() r = torch.rand(B, device=device) p_both = self.hparams.cfg_drop_both_prob p_class = p_both + self.hparams.cfg_drop_class_prob p_uni = p_class + self.hparams.cfg_drop_uni_prob drop_both = r < p_both drop_class = (r >= p_both) & (r < p_class) drop_uni = (r >= p_class) & (r < p_uni) new_labels[drop_both | drop_class] = self.null_class new_uni[drop_both | drop_uni] = 0.0 return new_labels, new_uni def compute_dab_intensity_loss(self, generated, target): """Top-10% percentile matching for DAB intensity.""" with torch.amp.autocast('cuda', enabled=False): gen = generated.float() tgt = target.float() dab_gen = self.dab_extractor.extract_dab_intensity(gen, normalize="none") dab_tgt = self.dab_extractor.extract_dab_intensity(tgt, normalize="none") def _batched_top10_mean(dab): """Compute mean of top-10% DAB intensity per sample (vectorized).""" B = dab.shape[0] flat = dab.reshape(B, -1) # [B, H*W] p99 = torch.quantile(flat, 0.99, dim=1, keepdim=True) flat = flat.clamp(max=p99) p90 = torch.quantile(flat, 0.9, dim=1, keepdim=True) mask = flat >= p90 # [B, H*W] # Use masked mean: sum(vals * mask) / sum(mask), fallback to flat mean masked_sum = (flat * mask).sum(dim=1) mask_count = mask.sum(dim=1).clamp(min=1) return masked_sum / mask_count # [B] gen_scores = _batched_top10_mean(dab_gen) tgt_scores = _batched_top10_mean(dab_tgt) return F.l1_loss(gen_scores, tgt_scores) def compute_dab_contrast_loss(self, generated, labels): """Class-ordering hinge loss: DAB(3+) > DAB(2+) > DAB(1+) > DAB(0).""" with torch.amp.autocast('cuda', enabled=False): gen = generated.float() # Only use non-null labels valid = labels < self.null_class if valid.sum() < 2: return torch.tensor(0.0, device=self.device, requires_grad=True) gen_valid = gen[valid] labels_valid = labels[valid] dab_gen = self.dab_extractor.extract_dab_intensity(gen_valid, normalize="none") B = dab_gen.shape[0] flat = dab_gen.reshape(B, -1) p99 = torch.quantile(flat, 0.99, dim=1, keepdim=True) flat = flat.clamp(max=p99) p90 = torch.quantile(flat, 0.9, dim=1, keepdim=True) mask = flat >= p90 masked_sum = (flat * mask).sum(dim=1) mask_count = mask.sum(dim=1).clamp(min=1) dab_scores = masked_sum / mask_count class_pairs = [ (3, 0, 0.20), (3, 1, 0.15), (2, 0, 0.08), (3, 2, 0.10), ] losses = [] for high_cls, low_cls, margin in class_pairs: high_mask = labels_valid == high_cls low_mask = labels_valid == low_cls if high_mask.sum() > 0 and low_mask.sum() > 0: high_score = dab_scores[high_mask].mean() low_score = dab_scores[low_mask].mean() losses.append(F.relu(margin - (high_score - low_score))) if losses: return torch.stack(losses).mean() return torch.tensor(0.0, device=self.device, requires_grad=True) def compute_edge_loss(self, generated, target): """Fourier spectral loss at 256x256 for boundary sharpness. Compares power spectrum magnitudes between generated and target. The Fourier magnitude is inherently translation-invariant — shifting an image doesn't change its frequency content — so this is robust to the ~30px misalignment in consecutive-cut BCI pairs. Focuses on high-frequency bands (outer 75% of spectrum) where blurriness manifests as reduced power. """ with torch.amp.autocast('cuda', enabled=False): gen = F.interpolate(generated.float(), size=256, mode='bilinear', align_corners=False) tgt = F.interpolate(target.float(), size=256, mode='bilinear', align_corners=False) # Grayscale gen_gray = gen.mean(dim=1, keepdim=True) tgt_gray = tgt.mean(dim=1, keepdim=True) # 2D FFT -> power spectrum (log-scale for stability) gen_fft = torch.fft.fft2(gen_gray) tgt_fft = torch.fft.fft2(tgt_gray) gen_mag = torch.log1p(gen_fft.abs()) tgt_mag = torch.log1p(tgt_fft.abs()) # High-frequency mask: keep outer 75% of spectrum H, W = gen_mag.shape[-2], gen_mag.shape[-1] cy, cx = H // 2, W // 2 y = torch.arange(H, device=gen.device).float() - cy x = torch.arange(W, device=gen.device).float() - cx dist = (y[:, None] ** 2 + x[None, :] ** 2).sqrt() max_dist = (cy ** 2 + cx ** 2) ** 0.5 hf_mask = (dist > 0.25 * max_dist).float() # L1 on high-frequency magnitudes return F.l1_loss(gen_mag * hf_mask, tgt_mag * hf_mask) def compute_dab_sharpness_loss(self, generated, target): """DAB spatial sharpness loss: penalizes diffuse brown, rewards membrane-localized DAB. Two components: 1. DAB gradient magnitude: mean Sobel gradient magnitude per image. 2. DAB local variance distribution: sorted-L1 (Wasserstein-1) on patch variance vectors. """ with torch.amp.autocast('cuda', enabled=False): gen = generated.float() tgt = target.float() dab_gen = self.dab_extractor.extract_dab_intensity(gen, normalize="none") dab_tgt = self.dab_extractor.extract_dab_intensity(tgt, normalize="none") # Ensure [B, 1, H, W] if dab_gen.dim() == 3: dab_gen = dab_gen.unsqueeze(1) if dab_tgt.dim() == 3: dab_tgt = dab_tgt.unsqueeze(1) B = dab_gen.shape[0] # --- Component 1: Gradient magnitude (batched) --- sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32, device=gen.device).view(1, 1, 3, 3) sobel_y = sobel_x.transpose(-1, -2) gx_gen = F.conv2d(dab_gen, sobel_x, padding=1) gy_gen = F.conv2d(dab_gen, sobel_y, padding=1) grad_gen = (gx_gen**2 + gy_gen**2 + 1e-8).sqrt() gx_tgt = F.conv2d(dab_tgt, sobel_x, padding=1) gy_tgt = F.conv2d(dab_tgt, sobel_y, padding=1) grad_tgt = (gx_tgt**2 + gy_tgt**2 + 1e-8).sqrt() # Match mean gradient magnitude per image grad_loss = F.l1_loss(grad_gen.mean(dim=[1, 2, 3]), grad_tgt.mean(dim=[1, 2, 3])) # --- Component 2: Local variance distribution (sorted-L1) --- ps = 16 # patch size var_losses = [] for i in range(B): g = dab_gen[i, 0] # [H, W] t = dab_tgt[i, 0] H, W = g.shape nH, nW = H // ps, W // ps g_patches = g[:nH*ps, :nW*ps].reshape(nH, ps, nW, ps).permute(0, 2, 1, 3).reshape(-1, ps*ps) t_patches = t[:nH*ps, :nW*ps].reshape(nH, ps, nW, ps).permute(0, 2, 1, 3).reshape(-1, ps*ps) g_var = g_patches.var(dim=1) t_var = t_patches.var(dim=1) g_sorted, _ = g_var.sort() t_sorted, _ = t_var.sort() var_losses.append(F.l1_loss(g_sorted, t_sorted.detach())) var_loss = torch.stack(var_losses).mean() return grad_loss + var_loss def compute_he_edge_loss(self, generated, he_input): """H&E edge structure preservation loss. Extracts Sobel edges from H&E input and generated output, then computes L1 loss between edge maps at multiple scales. """ with torch.amp.autocast('cuda', enabled=False): gen = generated.float() he = he_input.float() gen_gray = ((gen + 1) / 2).mean(dim=1, keepdim=True) he_gray = ((he + 1) / 2).mean(dim=1, keepdim=True) sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32, device=gen.device).view(1, 1, 3, 3) sobel_y = sobel_x.transpose(-1, -2) loss = 0.0 full_size = gen_gray.shape[-1] scales = [full_size, full_size // 2] for size in scales: if size < full_size: g = F.interpolate(gen_gray, size=size, mode='bilinear', align_corners=False) h = F.interpolate(he_gray, size=size, mode='bilinear', align_corners=False) else: g, h = gen_gray, he_gray gx_gen = F.conv2d(g, sobel_x, padding=1) gy_gen = F.conv2d(g, sobel_y, padding=1) edge_gen = (gx_gen**2 + gy_gen**2 + 1e-8).sqrt() gx_he = F.conv2d(h, sobel_x, padding=1) gy_he = F.conv2d(h, sobel_y, padding=1) edge_he = (gx_he**2 + gy_he**2 + 1e-8).sqrt() loss = loss + F.l1_loss(edge_gen, edge_he.detach()) return loss / 2.0 def compute_background_loss(self, generated, he_input): """Background white loss: push background regions toward white.""" with torch.amp.autocast('cuda', enabled=False): gen = generated.float() he = he_input.float() he_bright = ((he + 1) / 2).mean(dim=1, keepdim=True) threshold = self.hparams.bg_threshold mask = torch.sigmoid((he_bright - threshold) * 20.0) white_target = torch.ones_like(gen) diff = (gen - white_target).abs() weighted_diff = diff * mask mask_sum = mask.sum() * 3 if mask_sum > 0: return weighted_diff.sum() / mask_sum return torch.tensor(0.0, device=gen.device, requires_grad=True) def compute_gram_style_loss(self, generated, target): """Gram-matrix style loss: match texture statistics via VGG feature correlations.""" with torch.amp.autocast('cuda', enabled=False): gen = generated.float() tgt = target.float() gen_256 = F.interpolate(gen, size=256, mode='bilinear', align_corners=False) tgt_256 = F.interpolate(tgt, size=256, mode='bilinear', align_corners=False) gen_feats = self.vgg_extractor(gen_256) tgt_feats = self.vgg_extractor(tgt_256) loss = 0.0 for gf, tf in zip(gen_feats, tgt_feats): gram_gen = gram_matrix(gf) gram_tgt = gram_matrix(tf) loss = loss + F.l1_loss(gram_gen, gram_tgt.detach()) return loss / len(gen_feats) def training_step(self, batch, batch_idx): he, her2, uni_or_crops, labels, fnames = batch opt_g, opt_d = self.optimizers() # On-the-fly UNI extraction: dataset returns [B, 16, 3, 224, 224] sub-crops if self._uni_extract_on_the_fly: uni = self._extract_uni_from_sub_crops(uni_or_crops) else: uni = uni_or_crops # Apply CFG dropout labels_dropped, uni_dropped = self._apply_cfg_dropout(labels, uni) # Ablation: zero out UNI features if self.hparams.disable_uni: uni_dropped = torch.zeros_like(uni_dropped) # Ablation: force all labels to null class if self.hparams.disable_class: labels_dropped = torch.full_like(labels_dropped, self.hparams.null_class) # ---------------------------------------------------------------- # Generator step # ---------------------------------------------------------------- generated = self.generator(he, uni_dropped, labels_dropped) # LPIPS main: 4x downsample (128 for 512 input, 256 for 1024) lpips_main_size = self.hparams.image_size // 4 gen_lpips = F.interpolate(generated, size=lpips_main_size, mode='bilinear', align_corners=False) her2_lpips = F.interpolate(her2, size=lpips_main_size, mode='bilinear', align_corners=False) loss_lpips = self.lpips_fn(gen_lpips, her2_lpips).mean() loss_g = self.hparams.lpips_weight * loss_lpips # LPIPS fine: 2x downsample (256 for 512 input, 512 for 1024) if self.hparams.lpips_256_weight > 0: lpips_fine_size = self.hparams.image_size // 2 gen_fine = F.interpolate(generated, size=lpips_fine_size, mode='bilinear', align_corners=False) her2_fine = F.interpolate(her2, size=lpips_fine_size, mode='bilinear', align_corners=False) loss_lpips_256 = self.lpips_fn(gen_fine, her2_fine).mean() loss_g = loss_g + self.hparams.lpips_256_weight * loss_lpips_256 self.log('train/lpips_fine', loss_lpips_256, prog_bar=False) # LPIPS at full resolution (expensive) if self.hparams.lpips_512_weight > 0: loss_lpips_512 = self.lpips_fn(generated, her2).mean() loss_g = loss_g + self.hparams.lpips_512_weight * loss_lpips_512 self.log('train/lpips_fullres', loss_lpips_512, prog_bar=False) # Low-resolution L1 (color fidelity, misalignment-robust at 64x64) if self.hparams.l1_lowres_weight > 0: gen_64 = F.interpolate(generated, size=64, mode='bilinear', align_corners=False) her2_64 = F.interpolate(her2, size=64, mode='bilinear', align_corners=False) loss_l1_lowres = F.l1_loss(gen_64, her2_64) loss_g = loss_g + self.hparams.l1_lowres_weight * loss_l1_lowres self.log('train/l1_lowres', loss_l1_lowres, prog_bar=False) # DAB losses (use original labels, not dropped) if self.hparams.dab_intensity_weight > 0: loss_dab = self.compute_dab_intensity_loss(generated, her2) loss_g = loss_g + self.hparams.dab_intensity_weight * loss_dab self.log('train/dab_intensity', loss_dab, prog_bar=False) if self.hparams.dab_contrast_weight > 0: loss_dab_contrast = self.compute_dab_contrast_loss(generated, labels) loss_g = loss_g + self.hparams.dab_contrast_weight * loss_dab_contrast self.log('train/dab_contrast', loss_dab_contrast, prog_bar=False) # Edge loss (boundary sharpness) if self.hparams.edge_weight > 0: loss_edge = self.compute_edge_loss(generated, her2) loss_g = loss_g + self.hparams.edge_weight * loss_edge self.log('train/edge_loss', loss_edge, prog_bar=False) # DAB sharpness loss (membrane-localized vs diffuse brown) if self.hparams.dab_sharpness_weight > 0: loss_dab_sharp = self.compute_dab_sharpness_loss(generated, her2) loss_g = loss_g + self.hparams.dab_sharpness_weight * loss_dab_sharp self.log('train/dab_sharpness', loss_dab_sharp, prog_bar=False) # Gram-matrix style loss if self.hparams.gram_style_weight > 0 and self.vgg_extractor is not None: loss_gram = self.compute_gram_style_loss(generated, her2) loss_g = loss_g + self.hparams.gram_style_weight * loss_gram self.log('train/gram_style', loss_gram, prog_bar=False) # H&E edge structure preservation (pixel-aligned) if self.hparams.he_edge_weight > 0: loss_he_edge = self.compute_he_edge_loss(generated, he) loss_g = loss_g + self.hparams.he_edge_weight * loss_he_edge self.log('train/he_edge', loss_he_edge, prog_bar=False) # Background white loss if self.hparams.bg_white_weight > 0: loss_bg = self.compute_background_loss(generated, he) loss_g = loss_g + self.hparams.bg_white_weight * loss_bg self.log('train/bg_white', loss_bg, prog_bar=False) # PatchNCE loss (contrastive: H&E input vs generated, never sees GT) if self.hparams.patchnce_weight > 0 and self.patchnce_loss is not None: feats_he = self.generator.encode(he) feats_gen = self.generator.encode(generated) loss_nce = self.patchnce_loss(feats_he, feats_gen) loss_g = loss_g + self.hparams.patchnce_weight * loss_nce self.log('train/patchnce', loss_nce, prog_bar=False) # Adversarial losses (after warmup) loss_adv = torch.tensor(0.0, device=self.device) loss_feat_match = torch.tensor(0.0, device=self.device) loss_crop_adv = torch.tensor(0.0, device=self.device) loss_uncond_adv = torch.tensor(0.0, device=self.device) any_adv = (self.hparams.adversarial_weight > 0 or self.hparams.uncond_disc_weight > 0 or self.hparams.crop_disc_weight > 0 or self.hparams.feat_match_weight > 0) img_sz = self.hparams.image_size # Pre-compute disc-resolution tensors (512 for 1024 input, identity for 512) if img_sz == 1024: he_for_disc = F.interpolate(he, size=512, mode='bilinear', align_corners=False) her2_for_disc = F.interpolate(her2, size=512, mode='bilinear', align_corners=False) else: he_for_disc = he her2_for_disc = her2 if self.global_step >= self.hparams.adversarial_start_step and any_adv: if img_sz == 1024: gen_for_disc = F.interpolate(generated, size=512, mode='bilinear', align_corners=False) else: gen_for_disc = generated # Conditional discriminator (paired: generated+HE vs real_HER2+HE) if self.hparams.adversarial_weight > 0: fake_input = torch.cat([gen_for_disc, he_for_disc], dim=1) disc_outputs = self.discriminator(fake_input) loss_adv = sum(hinge_loss_g(out) for out in disc_outputs) / len(disc_outputs) loss_g = loss_g + self.hparams.adversarial_weight * loss_adv # Feature matching from unconditional disc if (self.hparams.feat_match_weight > 0 and self.uncond_discriminator is not None): _, fake_feats = self.uncond_discriminator(gen_for_disc, return_features=True) with torch.no_grad(): _, real_feats = self.uncond_discriminator(her2_for_disc, return_features=True) loss_feat_match = feature_matching_loss(fake_feats, real_feats) loss_g = loss_g + self.hparams.feat_match_weight * loss_feat_match # Crop discriminator: random crops at full resolution if self.crop_discriminator is not None and self.hparams.crop_disc_weight > 0: fake_input_crop = torch.cat([generated, he], dim=1) cs = self.hparams.crop_size top = torch.randint(0, img_sz - cs, (1,)).item() left = torch.randint(0, img_sz - cs, (1,)).item() fake_crop = fake_input_crop[:, :, top:top+cs, left:left+cs] loss_crop_adv = hinge_loss_g(self.crop_discriminator(fake_crop)) loss_g = loss_g + self.hparams.crop_disc_weight * loss_crop_adv # Unconditional discriminator: HER2-only adversarial if self.uncond_discriminator is not None and self.hparams.uncond_disc_weight > 0: loss_uncond_adv = hinge_loss_g(self.uncond_discriminator(gen_for_disc)) loss_g = loss_g + self.hparams.uncond_disc_weight * loss_uncond_adv # Generator backward + step lr_scale = self._get_lr_scale() for pg in opt_g.param_groups: pg['lr'] = self.hparams.gen_lr * lr_scale opt_g.zero_grad() self.manual_backward(loss_g) torch.nn.utils.clip_grad_norm_(self.generator.parameters(), 1.0) opt_g.step() # Update EMA self._update_ema() # ---------------------------------------------------------------- # Discriminator step # ---------------------------------------------------------------- loss_d = torch.tensor(0.0, device=self.device) loss_crop_d = torch.tensor(0.0, device=self.device) loss_uncond_d = torch.tensor(0.0, device=self.device) if self.global_step >= self.hparams.adversarial_start_step and any_adv: with torch.no_grad(): fake_detached = self.generator(he, uni_dropped, labels_dropped) # For 1024, downsample for disc if img_sz == 1024: fake_det_disc = F.interpolate(fake_detached, size=512, mode='bilinear', align_corners=False) else: fake_det_disc = fake_detached # Conditional discriminator if self.hparams.adversarial_weight > 0: real_input = torch.cat([her2_for_disc, he_for_disc], dim=1) fake_input = torch.cat([fake_det_disc, he_for_disc], dim=1) disc_real = self.discriminator(real_input) disc_fake = self.discriminator(fake_input) loss_d = sum( hinge_loss_d(dr, df) for dr, df in zip(disc_real, disc_fake) ) / len(disc_real) # Crop discriminator if self.crop_discriminator is not None and self.hparams.crop_disc_weight > 0: real_input_c = torch.cat([her2, he], dim=1) fake_input_c = torch.cat([fake_detached, he], dim=1) cs = self.hparams.crop_size top = torch.randint(0, img_sz - cs, (1,)).item() left = torch.randint(0, img_sz - cs, (1,)).item() real_crop = real_input_c[:, :, top:top+cs, left:left+cs] fake_crop = fake_input_c[:, :, top:top+cs, left:left+cs] loss_crop_d = hinge_loss_d( self.crop_discriminator(real_crop), self.crop_discriminator(fake_crop), ) loss_d = loss_d + self.hparams.crop_disc_weight * loss_crop_d # Unconditional discriminator if self.uncond_discriminator is not None and ( self.hparams.uncond_disc_weight > 0 or self.hparams.feat_match_weight > 0): uncond_real_out = self.uncond_discriminator(her2_for_disc) uncond_fake_out = self.uncond_discriminator(fake_det_disc) loss_uncond_d = hinge_loss_d(uncond_real_out, uncond_fake_out) loss_d = loss_d + max(self.hparams.uncond_disc_weight, 1.0) * loss_uncond_d # R1 gradient penalty loss_r1 = torch.tensor(0.0, device=self.device) if self.global_step % self.hparams.r1_every == 0: with torch.amp.autocast('cuda', enabled=False): if self.hparams.adversarial_weight > 0: real_input_r1 = torch.cat([her2_for_disc, he_for_disc], dim=1).float().detach().requires_grad_(True) for disc in [self.discriminator.disc_512]: d_real = disc(real_input_r1) grad_real = torch.autograd.grad( outputs=d_real.sum(), inputs=real_input_r1, create_graph=True, )[0] loss_r1 = loss_r1 + self.hparams.r1_weight * grad_real.pow(2).mean() if self.uncond_discriminator is not None and ( self.hparams.uncond_disc_weight > 0 or self.hparams.feat_match_weight > 0): her2_r1 = her2_for_disc.float().detach().requires_grad_(True) d_real_uncond = self.uncond_discriminator(her2_r1) grad_uncond = torch.autograd.grad( outputs=d_real_uncond.sum(), inputs=her2_r1, create_graph=True, )[0] loss_r1 = loss_r1 + self.hparams.r1_weight * grad_uncond.pow(2).mean() loss_d = loss_d + loss_r1 self.log('train/r1_penalty', loss_r1, prog_bar=False) opt_d.zero_grad() self.manual_backward(loss_d) if self.hparams.adversarial_weight > 0: torch.nn.utils.clip_grad_norm_(self.discriminator.parameters(), 1.0) if self.crop_discriminator is not None: torch.nn.utils.clip_grad_norm_(self.crop_discriminator.parameters(), 1.0) if self.uncond_discriminator is not None: torch.nn.utils.clip_grad_norm_(self.uncond_discriminator.parameters(), 1.0) opt_d.step() # Logging self.log('train/loss_g', loss_g, prog_bar=True) self.log('train/loss_d', loss_d, prog_bar=True) self.log('train/lpips', loss_lpips, prog_bar=True) self.log('train/adversarial', loss_adv, prog_bar=False) self.log('train/lr_scale', lr_scale, prog_bar=False) if self.crop_discriminator is not None: self.log('train/crop_adv_g', loss_crop_adv, prog_bar=False) self.log('train/crop_adv_d', loss_crop_d, prog_bar=False) if self.uncond_discriminator is not None: self.log('train/uncond_adv_g', loss_uncond_adv, prog_bar=False) self.log('train/uncond_adv_d', loss_uncond_d, prog_bar=False) if self.hparams.feat_match_weight > 0: self.log('train/feat_match', loss_feat_match, prog_bar=False) def on_validation_epoch_start(self): # Pick a random batch index for the second sample grid n_val_batches = max(1, len(self.trainer.val_dataloaders)) self._random_val_batch_idx = torch.randint(1, max(2, n_val_batches), (1,)).item() # Per-label sample collectors (for multi-stain visual grids) self._val_per_label_samples = {} def _log_sample_grid(self, he, her2_01, gen_01, key): """Log H&E | Real | Gen grid to wandb.""" n = min(4, len(he)) he_01 = ((he[:n].cpu() + 1) / 2).clamp(0, 1) grid_images = [] for i in range(n): grid_images.extend([ he_01[i], her2_01[i].cpu(), gen_01[i].cpu(), ]) grid = torchvision.utils.make_grid(grid_images, nrow=3, padding=2) if self.logger: self.logger.experiment.log({ key: [wandb.Image(grid, caption='H&E | Real | Gen')], 'global_step': self.global_step, }) def validation_step(self, batch, batch_idx): he, her2, uni_or_crops, labels, fnames = batch # On-the-fly UNI extraction if self._uni_extract_on_the_fly: uni = self._extract_uni_from_sub_crops(uni_or_crops) else: uni = uni_or_crops if self.hparams.disable_uni: uni = torch.zeros_like(uni) if self.hparams.disable_class: labels = torch.full_like(labels, self.hparams.null_class) # Use EMA generator with torch.no_grad(): generated = self.generator_ema(he, uni, labels) # LPIPS (4x downsample: 128 for 512, 256 for 1024) lpips_size = self.hparams.image_size // 4 gen_lpips = F.interpolate(generated, size=lpips_size, mode='bilinear', align_corners=False) her2_lpips = F.interpolate(her2, size=lpips_size, mode='bilinear', align_corners=False) lpips_val = self.lpips_fn(gen_lpips, her2_lpips).mean() # SSIM gen_01 = ((generated + 1) / 2).clamp(0, 1) her2_01 = ((her2 + 1) / 2).clamp(0, 1) from torchmetrics.functional.image import structural_similarity_index_measure ssim_val = structural_similarity_index_measure(gen_01, her2_01, data_range=1.0) # DAB MAE (canonical: mean of top-10%) dab_gen = self.dab_extractor.extract_dab_intensity(generated.float().cpu(), normalize="none") dab_real = self.dab_extractor.extract_dab_intensity(her2.float().cpu(), normalize="none") def p90_score(dab): flat = dab.flatten() p90 = torch.quantile(flat, 0.9) mask = flat >= p90 return flat[mask].mean().item() if mask.sum() > 0 else flat.mean().item() dab_mae = sum( abs(p90_score(dab_gen[i]) - p90_score(dab_real[i])) for i in range(len(dab_gen)) ) / len(dab_gen) self.log('val/lpips', lpips_val, prog_bar=True, sync_dist=True) self.log('val/ssim', ssim_val, prog_bar=True, sync_dist=True) self.log('val/dab_mae', dab_mae, prog_bar=True, sync_dist=True) # Collect per-label samples for visual grids (multi-stain only) if hasattr(self, '_val_per_label_samples'): for i in range(len(labels)): lbl = labels[i].item() if lbl == self.hparams.null_class: continue if lbl not in self._val_per_label_samples: self._val_per_label_samples[lbl] = {'he': [], 'real': [], 'gen': []} bucket = self._val_per_label_samples[lbl] if len(bucket['he']) < 4: bucket['he'].append(he[i].cpu()) bucket['real'].append(her2_01[i].cpu()) bucket['gen'].append(gen_01[i].cpu()) # Log sample grids: first batch (fixed) + one random batch if batch_idx == 0: self._log_sample_grid(he, her2_01, gen_01, 'val/samples_fixed') elif batch_idx == self._random_val_batch_idx: self._log_sample_grid(he, her2_01, gen_01, 'val/samples_random') def on_validation_epoch_end(self): """Log per-label sample grids if multiple labels are present.""" if not hasattr(self, '_val_per_label_samples') or len(self._val_per_label_samples) <= 1: return label_names = getattr(self.hparams, 'label_names', None) for lbl, bucket in sorted(self._val_per_label_samples.items()): if not bucket['he'] or not self.logger: continue name = label_names[lbl] if label_names and lbl < len(label_names) else str(lbl) self._log_sample_grid( torch.stack(bucket['he']), torch.stack(bucket['real']), torch.stack(bucket['gen']), f'val/samples_{name}', ) self._val_per_label_samples = {} @torch.no_grad() def generate(self, he_images, uni_features, labels, num_inference_steps=None, guidance_scale=1.0, seed=None): """Generate IHC images from H&E input. Args: he_images: [B, 3, H, H] where H=512 or H=1024 uni_features: [B, N, 1024] where N=16 (4x4 CLS) or N=1024 (32x32 patch) labels: [B] class/stain labels num_inference_steps: ignored (single forward pass) guidance_scale: CFG scale (1.0 = no guidance) seed: random seed (for reproducibility, though model is deterministic) """ if seed is not None: torch.manual_seed(seed) gen = self.generator_ema if hasattr(self, 'generator_ema') else self.generator if guidance_scale <= 1.0: return gen(he_images, uni_features, labels) # Classifier-free guidance null_labels = torch.full_like(labels, self.null_class) output_cond = gen(he_images, uni_features, labels) output_uncond = gen(he_images, uni_features, null_labels) output = output_uncond + guidance_scale * (output_cond - output_uncond) return output.clamp(-1, 1)