Spaces:
Running
Running
| """ | |
| UNIStainNet: Pixel-Space UNI-Guided Virtual Staining Network. | |
| Architecture: | |
| Generator: SPADE-UNet conditioned on UNI pathology features + stain/class embedding | |
| Discriminator: Multi-scale PatchGAN (512 + 256) | |
| Losses: LPIPS@128 + adversarial + DAB intensity + DAB contrast | |
| References: | |
| - Park et al., "Semantic Image Synthesis with SPADE" (CVPR 2019) | |
| - Chen et al., "A general-purpose self-supervised model for pathology" (Nature Medicine 2024) | |
| - Isola et al., "Image-to-Image Translation with pix2pix" (CVPR 2017) | |
| """ | |
| import copy | |
| from collections import OrderedDict | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import lpips | |
| import pytorch_lightning as pl | |
| import torchvision | |
| import wandb | |
| from src.models.discriminator import ( | |
| PatchDiscriminator, MultiScaleDiscriminator, | |
| hinge_loss_d, hinge_loss_g, r1_gradient_penalty, feature_matching_loss, | |
| ) | |
| from src.models.generator import SPADEUNetGenerator | |
| from src.models.losses import VGGFeatureExtractor, gram_matrix, PatchNCELoss | |
| from src.utils.dab import DABExtractor | |
| # ====================================================================== | |
| # Training Module | |
| # ====================================================================== | |
| class UNIStainNetTrainer(pl.LightningModule): | |
| """PyTorch Lightning training module for UNIStainNet. | |
| Handles GAN training with manual optimization, CFG dropout, EMA, and | |
| all loss computations. | |
| """ | |
def __init__(
    self,
    # Architecture
    num_classes=5,
    null_class=4,
    class_dim=64,
    uni_dim=1024,
    ndf=64,
    disc_n_layers=3,
    input_skip=False,
    # Optimizer
    gen_lr=1e-4,
    disc_lr=4e-4,
    warmup_steps=1000,
    # Loss weights
    lpips_weight=1.0,
    lpips_256_weight=0.5,
    lpips_512_weight=0.0,
    adversarial_weight=1.0,
    dab_intensity_weight=0.1,
    dab_contrast_weight=0.05,
    dab_sharpness_weight=0.0,
    gram_style_weight=0.0,
    edge_weight=0.0,
    he_edge_weight=0.0,
    bg_white_weight=0.0,
    bg_threshold=0.85,
    l1_lowres_weight=0.0,
    edge_encoder=False,
    edge_base_ch=32,
    uni_spatial_size=4,
    uncond_disc_weight=0.0,
    crop_disc_weight=0.0,
    crop_size=128,
    feat_match_weight=0.0,
    patchnce_weight=0.0,
    patchnce_layers=(2, 3, 4),
    patchnce_n_patches=256,
    patchnce_temperature=0.07,
    # Ablation
    disable_uni=False,
    disable_class=False,
    # GAN training
    r1_weight=10.0,
    r1_every=16,
    adversarial_start_step=2000,
    # CFG
    cfg_drop_class_prob=0.10,
    cfg_drop_uni_prob=0.10,
    cfg_drop_both_prob=0.05,
    # EMA
    ema_decay=0.999,
    # On-the-fly UNI extraction (for crop-based training)
    extract_uni_on_the_fly=False,
    uni_spatial_pool_size=32,
    # Resolution
    image_size=512,
    # 1024 architecture: extend UNI SPADE to 512 level
    uni_spade_at_512=False,
    # Per-label names for multi-stain logging
    label_names=None,
):
    """Build the generator, its EMA copy, the discriminators and loss modules.

    Every argument is stored through ``save_hyperparameters()`` and read
    back later via ``self.hparams``. Optional sub-modules are created only
    when the weight that uses them is positive; otherwise the attribute is
    set to ``None``:

    * ``crop_discriminator``   -- ``crop_disc_weight > 0``
    * ``uncond_discriminator`` -- ``uncond_disc_weight > 0`` or
      ``feat_match_weight > 0``
    * ``patchnce_loss``        -- ``patchnce_weight > 0``
    * ``vgg_extractor``        -- ``gram_style_weight > 0``
    """
    super().__init__()
    self.save_hyperparameters()
    # Generator and discriminator are stepped by hand in training_step.
    self.automatic_optimization = False
    self.null_class = null_class

    # The UNI backbone is created lazily by _load_uni_model().
    self._uni_model = None
    self._uni_extract_on_the_fly = extract_uni_on_the_fly

    # --- Generator -------------------------------------------------------
    self.generator = SPADEUNetGenerator(
        num_classes=num_classes,
        class_dim=class_dim,
        uni_dim=uni_dim,
        input_skip=input_skip,
        edge_encoder=edge_encoder,
        edge_base_ch=edge_base_ch,
        uni_spatial_size=uni_spatial_size,
        image_size=image_size,
        uni_spade_at_512=uni_spade_at_512,
    )

    # --- Discriminators --------------------------------------------------
    # Global multi-scale discriminator over 6-channel inputs.
    self.discriminator = MultiScaleDiscriminator(
        in_channels=6, ndf=ndf, n_layers=disc_n_layers,
    )
    # Crop discriminator for local full-resolution detail.
    self.crop_discriminator = (
        PatchDiscriminator(in_channels=6, ndf=ndf, n_layers=disc_n_layers)
        if crop_disc_weight > 0
        else None
    )
    # Unconditional 3-channel discriminator; the feature-matching loss
    # reads its intermediate features, so it is also built for that case.
    wants_uncond = uncond_disc_weight > 0 or feat_match_weight > 0
    self.uncond_discriminator = (
        PatchDiscriminator(in_channels=3, ndf=ndf, n_layers=disc_n_layers)
        if wants_uncond
        else None
    )

    # --- PatchNCE (contrastive, H&E input vs generated) --------------------
    if patchnce_weight > 0:
        # Channel width of each encoder level that may be sampled.
        encoder_widths = {1: 64, 2: 128, 3: 256, 4: 512}
        self.patchnce_loss = PatchNCELoss(
            layer_channels={idx: encoder_widths[idx] for idx in patchnce_layers},
            num_patches=patchnce_n_patches,
            temperature=patchnce_temperature,
        )
    else:
        self.patchnce_loss = None

    # --- EMA generator (frozen copy, updated by _update_ema) -------------
    self.generator_ema = copy.deepcopy(self.generator)
    self.generator_ema.requires_grad_(False)

    # --- Fixed loss networks ---------------------------------------------
    self.lpips_fn = lpips.LPIPS(net='alex')
    self.lpips_fn.requires_grad_(False)
    self.lpips_fn.eval()
    self.dab_extractor = DABExtractor(device='cpu')

    # VGG features are needed only for the Gram-matrix style loss.
    if gram_style_weight > 0:
        self.vgg_extractor = VGGFeatureExtractor()
        self.vgg_extractor.requires_grad_(False)
        self.vgg_extractor.eval()
    else:
        self.vgg_extractor = None

    # --- Parameter-count summary -----------------------------------------
    def _param_count(module):
        return sum(p.numel() for p in module.parameters())

    n_gen = _param_count(self.generator)
    n_disc = _param_count(self.discriminator)
    n_crop = _param_count(self.crop_discriminator) if self.crop_discriminator else 0
    n_uncond = _param_count(self.uncond_discriminator) if self.uncond_discriminator else 0
    print(f"Generator: {n_gen:,} params")
    print(f"Discriminator: {n_disc:,} params (global) + {n_crop:,} (crop) + {n_uncond:,} (uncond)")
def configure_optimizers(self):
    """Create the two Adam optimizers used by the manual training loop.

    The generator optimizer also receives the PatchNCE parameters when
    that loss exists. All discriminators that exist share one optimizer.

    Returns:
        ``[opt_g, opt_d]`` -- generator optimizer first, discriminator second.
    """
    adam_betas = (0.0, 0.999)

    g_params = [*self.generator.parameters()]
    if self.patchnce_loss is not None:
        g_params.extend(self.patchnce_loss.parameters())

    d_params = [*self.discriminator.parameters()]
    for extra_disc in (self.crop_discriminator, self.uncond_discriminator):
        if extra_disc is not None:
            d_params.extend(extra_disc.parameters())

    opt_g = torch.optim.Adam(g_params, lr=self.hparams.gen_lr, betas=adam_betas)
    opt_d = torch.optim.Adam(d_params, lr=self.hparams.disc_lr, betas=adam_betas)
    return [opt_g, opt_d]
def _get_lr_scale(self):
    """Return the linear-warmup multiplier for the current global step.

    The value grows as ``global_step / warmup_steps`` and is ``1.0`` once
    ``global_step`` has reached ``warmup_steps``.
    """
    warmup = self.hparams.warmup_steps
    if self.global_step >= warmup:
        return 1.0
    # max(1, ...) keeps the division well-defined for tiny warmup values.
    return self.global_step / max(1, warmup)
def _update_ema(self):
    """Move the EMA generator one step towards the live generator.

    Parameters are blended as ``ema = decay * ema + (1 - decay) * live``
    with ``decay = hparams.ema_decay``. Buffers (for example running
    statistics of normalization layers) are not trainable and have no
    meaningful average, so they are copied verbatim; without this the EMA
    copy would keep the buffer values it had when it was deep-copied.
    """
    decay = self.hparams.ema_decay
    with torch.no_grad():
        for p_ema, p_live in zip(self.generator_ema.parameters(),
                                 self.generator.parameters()):
            p_ema.mul_(decay).add_(p_live.detach(), alpha=1 - decay)
        # Keep non-trainable state in sync with the live generator.
        for b_ema, b_live in zip(self.generator_ema.buffers(),
                                 self.generator.buffers()):
            b_ema.copy_(b_live)
def on_save_checkpoint(self, checkpoint):
    """Drop the UNI backbone weights from a checkpoint before it is written.

    Every ``state_dict`` entry whose key starts with ``'_uni_model.'`` is
    removed in place; a checkpoint without a ``state_dict`` is left as is.
    """
    weights = checkpoint.get('state_dict', {})
    # Snapshot the keys first: entries are deleted while walking them.
    for key in list(weights):
        if key.startswith('_uni_model.'):
            del weights[key]
def on_load_checkpoint(self, checkpoint):
    """Strip UNI backbone weights from a checkpoint that still contains them.

    Every ``state_dict`` entry whose key starts with ``'_uni_model.'`` is
    removed in place so older checkpoints that stored those weights can
    still be loaded.
    """
    weights = checkpoint.get('state_dict', {})
    # Snapshot the keys first: entries are deleted while walking them.
    for key in list(weights):
        if key.startswith('_uni_model.'):
            del weights[key]
def _load_uni_model(self):
    """Return the UNI ViT-L/16 backbone, creating it on first use.

    On the first call the model is fetched through ``timm`` from
    ``hf-hub:MahmoodLab/uni``, switched to eval mode, frozen, moved to
    ``self.device`` and cached in ``self._uni_model``. Later calls return
    the cached instance.
    """
    if self._uni_model is not None:
        return self._uni_model

    # Imported here so timm is only required when this path is used.
    import timm

    backbone = timm.create_model(
        "hf-hub:MahmoodLab/uni",
        pretrained=True,
        init_values=1e-5,
        dynamic_img_size=True,
    )
    backbone.eval()
    backbone.requires_grad_(False)
    self._uni_model = backbone.to(self.device)

    n_params = sum(p.numel() for p in self._uni_model.parameters())
    print(f"UNI model loaded for on-the-fly extraction: {n_params:,} params")
    return self._uni_model
def _extract_uni_from_sub_crops(self, uni_sub_crops):
    """Extract a spatial grid of UNI patch features from pre-cut sub-crops.

    Args:
        uni_sub_crops: ``[B, G*G, 3, H, W]`` batch of ``G x G`` sub-crop
            grids in row-major order (``[B, 16, 3, 224, 224]`` in the
            default setup), already normalized with ImageNet stats.

    Returns:
        ``[B, S*S, D]`` features, where ``D`` is the backbone's token width
        (1024 for UNI) and ``S`` is ``hparams.uni_spatial_pool_size`` when
        that is smaller than the stitched grid side, otherwise the full
        stitched side ``G * patches_per_side`` (56 by default).

    The grid size, patches per side and feature width are read from the
    tensors rather than hard-coded, so the method follows the backbone and
    crop layout actually in use.
    """
    import math

    uni_model = self._load_uni_model()
    B, n_crops = uni_sub_crops.shape[0], uni_sub_crops.shape[1]
    crops_per_side = math.isqrt(n_crops)  # 4 for the 4x4 default grid
    spatial_size = self.hparams.uni_spatial_pool_size

    # One batched forward over every sub-crop: [B*G*G, 3, H, W].
    all_crops = uni_sub_crops.reshape(B * n_crops, *uni_sub_crops.shape[2:]).to(self.device)
    all_feats = uni_model.forward_features(all_crops)
    # Drop the first token and keep the patch tokens
    # (presumably the class token -- 197 -> 196 tokens for a 224 px crop).
    patch_tokens = all_feats[:, 1:, :]
    n_tokens, feat_dim = patch_tokens.shape[1], patch_tokens.shape[2]
    patches_per_side = math.isqrt(n_tokens)  # 14 = 224 / 16 by default

    # Back to per-sample grids: [B, G, G, P, P, D]. A non-square crop or
    # token count makes this reshape fail, as before.
    patch_tokens = patch_tokens.reshape(
        B, crops_per_side, crops_per_side,
        patches_per_side, patches_per_side, feat_dim,
    )

    # Interleave crop rows with patch rows (and columns likewise) so the
    # result is one contiguous spatial map: [B, G*P, G*P, D].
    full_size = crops_per_side * patches_per_side
    full_grid = patch_tokens.permute(0, 1, 3, 2, 4, 5)
    full_grid = full_grid.reshape(B, full_size, full_size, feat_dim)

    # Average-pool down to the requested grid when it is smaller.
    if spatial_size < full_size:
        grid_bchw = full_grid.permute(0, 3, 1, 2)                # [B, D, G*P, G*P]
        pooled = F.adaptive_avg_pool2d(grid_bchw, spatial_size)  # [B, D, S, S]
        result = pooled.permute(0, 2, 3, 1)                      # [B, S, S, D]
    else:
        result = full_grid

    S = result.shape[1]
    return result.reshape(B, S * S, feat_dim)
def _apply_cfg_dropout(self, labels, uni_features):
    """Randomly blank out conditioning for classifier-free guidance.

    One uniform draw per sample selects at most one of three outcomes,
    laid out as consecutive intervals starting at 0:

    * ``cfg_drop_both_prob``  -- label set to ``null_class`` and UNI zeroed
    * ``cfg_drop_class_prob`` -- label set to ``null_class`` only
    * ``cfg_drop_uni_prob``   -- UNI features zeroed only

    Returns:
        ``(labels, uni_features)`` as new tensors; the inputs are untouched.
    """
    hp = self.hparams
    n_samples = labels.shape[0]
    draw = torch.rand(n_samples, device=labels.device)

    # Cumulative interval edges on [0, 1).
    edge_both = hp.cfg_drop_both_prob
    edge_class = edge_both + hp.cfg_drop_class_prob
    edge_uni = edge_class + hp.cfg_drop_uni_prob

    hit_both = draw < edge_both
    hit_class_only = (draw >= edge_both) & (draw < edge_class)
    hit_uni_only = (draw >= edge_class) & (draw < edge_uni)

    labels_out = labels.clone()
    labels_out[hit_both | hit_class_only] = self.null_class

    uni_out = uni_features.clone()
    uni_out[hit_both | hit_uni_only] = 0.0
    return labels_out, uni_out
def compute_dab_intensity_loss(self, generated, target):
    """L1 distance between per-image top-10% DAB intensity scores.

    For each image the DAB map is flattened, clipped at its own 99th
    percentile, and the values at or above the 90th percentile of the
    clipped map are averaged. The loss is the L1 distance between the
    resulting score vectors of ``generated`` and ``target``.
    """
    def _top_decile_score(dab_map):
        """Per-sample mean of the clipped values at or above the 90th percentile."""
        rows = dab_map.reshape(dab_map.shape[0], -1)                 # [B, H*W]
        ceiling = torch.quantile(rows, 0.99, dim=1, keepdim=True)
        rows = rows.clamp(max=ceiling)                               # tame outliers
        cutoff = torch.quantile(rows, 0.9, dim=1, keepdim=True)
        keep = rows >= cutoff
        # Masked mean; the clamp guards the division against an empty mask.
        return (rows * keep).sum(dim=1) / keep.sum(dim=1).clamp(min=1)

    # Run in full precision even when the caller is under CUDA autocast.
    with torch.amp.autocast('cuda', enabled=False):
        extract = self.dab_extractor.extract_dab_intensity
        dab_generated = extract(generated.float(), normalize="none")
        dab_target = extract(target.float(), normalize="none")
        score_generated = _top_decile_score(dab_generated)
        score_target = _top_decile_score(dab_target)
        return F.l1_loss(score_generated, score_target)
def compute_dab_contrast_loss(self, generated, labels):
    """Hinge loss that orders mean DAB scores between label classes.

    Samples labelled ``null_class`` or above are ignored. Each remaining
    image gets a top-10% DAB score (clipped at its 99th percentile). For
    every ``(high, low, margin)`` pair below whose two classes both occur
    in the batch, the term ``relu(margin - (mean_high - mean_low))`` is
    added; the loss is the mean over those terms.

    Returns a zero scalar (``requires_grad=True``) when fewer than two
    valid samples exist or no class pair is present in the batch.
    """
    # (higher class, lower class, required score gap)
    ordered_pairs = (
        (3, 0, 0.20), (3, 1, 0.15),
        (2, 0, 0.08), (3, 2, 0.10),
    )

    with torch.amp.autocast('cuda', enabled=False):
        images = generated.float()

        keep = labels < self.null_class
        if keep.sum() < 2:
            return torch.tensor(0.0, device=self.device, requires_grad=True)
        images = images[keep]
        kept_labels = labels[keep]

        dab = self.dab_extractor.extract_dab_intensity(images, normalize="none")
        rows = dab.reshape(dab.shape[0], -1)
        ceiling = torch.quantile(rows, 0.99, dim=1, keepdim=True)
        rows = rows.clamp(max=ceiling)
        cutoff = torch.quantile(rows, 0.9, dim=1, keepdim=True)
        top = rows >= cutoff
        scores = (rows * top).sum(dim=1) / top.sum(dim=1).clamp(min=1)

        hinge_terms = []
        for strong_cls, weak_cls, margin in ordered_pairs:
            is_strong = kept_labels == strong_cls
            is_weak = kept_labels == weak_cls
            if not (is_strong.sum() > 0 and is_weak.sum() > 0):
                continue  # this pair is not represented in the batch
            gap = scores[is_strong].mean() - scores[is_weak].mean()
            hinge_terms.append(F.relu(margin - gap))

        if not hinge_terms:
            return torch.tensor(0.0, device=self.device, requires_grad=True)
        return torch.stack(hinge_terms).mean()
def compute_edge_loss(self, generated, target):
    """Fourier spectral loss at 256x256 for boundary sharpness.

    Compares log-magnitude spectra of the grayscale generated and target
    images. The Fourier magnitude does not change under (circular)
    translation, so the comparison tolerates misalignment between the
    consecutive-cut pairs.

    Only the high-frequency band is compared: bins whose radial distance
    from the zero-frequency bin exceeds 25% of the maximum radius, which is
    where blurriness shows up as reduced power.

    ``torch.fft.fft2`` places the zero-frequency bin at index ``(0, 0)``,
    so both spectra are passed through ``fftshift`` before the radial mask
    (which is centered on ``(H // 2, W // 2)``) is applied. Without the
    shift the mask would keep the DC/low-frequency corners and discard the
    highest frequencies.

    Args:
        generated: ``[B, C, H, W]`` generated images.
        target: ``[B, C, H, W]`` reference images.

    Returns:
        Scalar L1 distance between the masked log-magnitude spectra.
    """
    with torch.amp.autocast('cuda', enabled=False):
        gen = F.interpolate(generated.float(), size=256, mode='bilinear', align_corners=False)
        tgt = F.interpolate(target.float(), size=256, mode='bilinear', align_corners=False)

        # Grayscale
        gen_gray = gen.mean(dim=1, keepdim=True)
        tgt_gray = tgt.mean(dim=1, keepdim=True)

        # 2D FFT, centered so that the zero-frequency bin sits at
        # (H // 2, W // 2); log-scale magnitudes for stability.
        spatial_dims = (-2, -1)
        gen_fft = torch.fft.fftshift(torch.fft.fft2(gen_gray), dim=spatial_dims)
        tgt_fft = torch.fft.fftshift(torch.fft.fft2(tgt_gray), dim=spatial_dims)
        gen_mag = torch.log1p(gen_fft.abs())
        tgt_mag = torch.log1p(tgt_fft.abs())

        # High-frequency mask: radial distance from the centered DC bin.
        H, W = gen_mag.shape[-2], gen_mag.shape[-1]
        cy, cx = H // 2, W // 2
        y = torch.arange(H, device=gen.device).float() - cy
        x = torch.arange(W, device=gen.device).float() - cx
        dist = (y[:, None] ** 2 + x[None, :] ** 2).sqrt()
        max_dist = (cy ** 2 + cx ** 2) ** 0.5
        hf_mask = (dist > 0.25 * max_dist).float()

        # L1 on high-frequency magnitudes
        return F.l1_loss(gen_mag * hf_mask, tgt_mag * hf_mask)
def compute_dab_sharpness_loss(self, generated, target):
    """DAB spatial sharpness loss (gradient term + local-variance term).

    1. Gradient term: L1 distance between the per-image mean Sobel
       gradient magnitude of the generated and target DAB maps.
    2. Variance term: each DAB map is cut into non-overlapping 16x16
       patches; the per-patch variances are sorted and compared with L1
       (a sorted-L1 / Wasserstein-1 style comparison), averaged over the
       batch. The target side of this term is detached.

    Returns:
        Scalar ``gradient term + variance term``.
    """
    patch = 16  # side of the square patches used for the variance term

    def _as_bchw(dab):
        # The extractor may return [B, H, W]; add the channel axis if so.
        return dab.unsqueeze(1) if dab.dim() == 3 else dab

    def _sobel_magnitude(dab, kernel_x, kernel_y):
        gx = F.conv2d(dab, kernel_x, padding=1)
        gy = F.conv2d(dab, kernel_y, padding=1)
        return (gx**2 + gy**2 + 1e-8).sqrt()

    def _sorted_patch_variance(plane, n_rows, n_cols):
        # [H, W] -> [n_rows * n_cols, patch * patch] -> sorted variances
        tiles = (
            plane[:n_rows * patch, :n_cols * patch]
            .reshape(n_rows, patch, n_cols, patch)
            .permute(0, 2, 1, 3)
            .reshape(-1, patch * patch)
        )
        ordered, _ = tiles.var(dim=1).sort()
        return ordered

    with torch.amp.autocast('cuda', enabled=False):
        gen = generated.float()
        tgt = target.float()
        extract = self.dab_extractor.extract_dab_intensity
        dab_gen = _as_bchw(extract(gen, normalize="none"))
        dab_tgt = _as_bchw(extract(tgt, normalize="none"))

        # --- Term 1: mean gradient magnitude per image ---
        kernel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]],
                                dtype=torch.float32, device=gen.device).view(1, 1, 3, 3)
        kernel_y = kernel_x.transpose(-1, -2)
        mag_gen = _sobel_magnitude(dab_gen, kernel_x, kernel_y)
        mag_tgt = _sobel_magnitude(dab_tgt, kernel_x, kernel_y)
        grad_loss = F.l1_loss(mag_gen.mean(dim=[1, 2, 3]), mag_tgt.mean(dim=[1, 2, 3]))

        # --- Term 2: sorted local-variance distribution per image ---
        per_image = []
        for gen_map, tgt_map in zip(dab_gen, dab_tgt):
            gen_plane, tgt_plane = gen_map[0], tgt_map[0]
            height, width = gen_plane.shape
            n_rows, n_cols = height // patch, width // patch
            gen_sorted = _sorted_patch_variance(gen_plane, n_rows, n_cols)
            tgt_sorted = _sorted_patch_variance(tgt_plane, n_rows, n_cols)
            per_image.append(F.l1_loss(gen_sorted, tgt_sorted.detach()))
        var_loss = torch.stack(per_image).mean()

        return grad_loss + var_loss
def compute_he_edge_loss(self, generated, he_input):
    """H&E edge structure preservation loss.

    Both images are mapped from [-1, 1] to [0, 1], averaged to grayscale
    and filtered with Sobel kernels. The L1 distance between the two edge
    magnitude maps is taken at full and at half resolution and the two
    values are averaged. The H&E edge maps are detached.
    """
    def _gray01(img):
        return ((img.float() + 1) / 2).mean(dim=1, keepdim=True)

    def _edge_map(gray, kernel_x, kernel_y):
        gx = F.conv2d(gray, kernel_x, padding=1)
        gy = F.conv2d(gray, kernel_y, padding=1)
        return (gx**2 + gy**2 + 1e-8).sqrt()

    with torch.amp.autocast('cuda', enabled=False):
        gen_gray = _gray01(generated)
        he_gray = _gray01(he_input)

        kernel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]],
                                dtype=torch.float32, device=gen_gray.device).view(1, 1, 3, 3)
        kernel_y = kernel_x.transpose(-1, -2)

        native = gen_gray.shape[-1]
        total = 0.0
        for side in (native, native // 2):
            gen_s, he_s = gen_gray, he_gray
            if side < native:
                gen_s = F.interpolate(gen_gray, size=side, mode='bilinear', align_corners=False)
                he_s = F.interpolate(he_gray, size=side, mode='bilinear', align_corners=False)
            edges_gen = _edge_map(gen_s, kernel_x, kernel_y)
            edges_he = _edge_map(he_s, kernel_x, kernel_y)
            total = total + F.l1_loss(edges_gen, edges_he.detach())

        # Average over the two scales.
        return total / 2.0
def compute_background_loss(self, generated, he_input):
    """Background white loss: pull the output towards 1.0 where H&E is bright.

    A soft mask ``sigmoid((brightness - bg_threshold) * 20)`` is built from
    the H&E brightness (image mapped to [0, 1], channel mean). The loss is
    the mask-weighted mean absolute distance of ``generated`` from an
    all-ones image, normalized by ``3 * mask.sum()``.
    """
    with torch.amp.autocast('cuda', enabled=False):
        out = generated.float()
        brightness = ((he_input.float() + 1) / 2).mean(dim=1, keepdim=True)

        # Soft background indicator; 20.0 sets how sharp the transition is.
        bg_weight = torch.sigmoid((brightness - self.hparams.bg_threshold) * 20.0)
        dist_to_white = (out - torch.ones_like(out)).abs()

        denom = bg_weight.sum() * 3
        if not denom > 0:
            return torch.tensor(0.0, device=out.device, requires_grad=True)
        return (dist_to_white * bg_weight).sum() / denom
def compute_gram_style_loss(self, generated, target):
    """Gram-matrix style loss on VGG features of 256x256 resized images.

    Both images are resized to 256, passed through ``self.vgg_extractor``,
    and the L1 distance between the Gram matrices of matching feature maps
    is averaged over the feature levels. Target Gram matrices are detached.
    """
    with torch.amp.autocast('cuda', enabled=False):
        def _features_at_256(img):
            resized = F.interpolate(img.float(), size=256, mode='bilinear', align_corners=False)
            return self.vgg_extractor(resized)

        feats_generated = _features_at_256(generated)
        feats_target = _features_at_256(target)

        total = 0.0
        for feat_g, feat_t in zip(feats_generated, feats_target):
            total = total + F.l1_loss(gram_matrix(feat_g), gram_matrix(feat_t).detach())
        return total / len(feats_generated)
def training_step(self, batch, batch_idx):
    """Run one manually-optimized GAN iteration.

    Order of operations:

    1. Obtain UNI features (extracted on the fly from sub-crops, or taken
       from the batch), apply CFG dropout and the ablation switches.
    2. Generator step: LPIPS plus every auxiliary loss whose weight is
       positive, plus adversarial terms once ``global_step`` has reached
       ``adversarial_start_step``; gradient-norm clipping at 1.0; optimizer
       step with the warmup-scaled learning rate; EMA update.
    3. Discriminator step (same gating): hinge losses on a freshly
       generated, detached fake, plus an R1 penalty on real inputs whenever
       ``global_step % r1_every == 0``.
    4. Logging of the scalar losses.

    Args:
        batch: ``(he, her2, uni_or_crops, labels, fnames)``.
        batch_idx: index of the batch within the epoch (unused).

    Returns:
        None. Both optimizers are stepped inside this method.
    """
    he, her2, uni_or_crops, labels, fnames = batch
    opt_g, opt_d = self.optimizers()

    # On-the-fly UNI extraction: dataset returns [B, 16, 3, 224, 224] sub-crops
    if self._uni_extract_on_the_fly:
        uni = self._extract_uni_from_sub_crops(uni_or_crops)
    else:
        uni = uni_or_crops

    # Apply CFG dropout
    labels_dropped, uni_dropped = self._apply_cfg_dropout(labels, uni)

    # Ablation: zero out UNI features
    if self.hparams.disable_uni:
        uni_dropped = torch.zeros_like(uni_dropped)
    # Ablation: force all labels to null class
    if self.hparams.disable_class:
        labels_dropped = torch.full_like(labels_dropped, self.hparams.null_class)

    # ----------------------------------------------------------------
    # Generator step
    # ----------------------------------------------------------------
    generated = self.generator(he, uni_dropped, labels_dropped)

    # LPIPS main: 4x downsample (128 for 512 input, 256 for 1024)
    lpips_main_size = self.hparams.image_size // 4
    gen_lpips = F.interpolate(generated, size=lpips_main_size, mode='bilinear', align_corners=False)
    her2_lpips = F.interpolate(her2, size=lpips_main_size, mode='bilinear', align_corners=False)
    loss_lpips = self.lpips_fn(gen_lpips, her2_lpips).mean()
    loss_g = self.hparams.lpips_weight * loss_lpips

    # LPIPS fine: 2x downsample (256 for 512 input, 512 for 1024)
    if self.hparams.lpips_256_weight > 0:
        lpips_fine_size = self.hparams.image_size // 2
        gen_fine = F.interpolate(generated, size=lpips_fine_size, mode='bilinear', align_corners=False)
        her2_fine = F.interpolate(her2, size=lpips_fine_size, mode='bilinear', align_corners=False)
        loss_lpips_256 = self.lpips_fn(gen_fine, her2_fine).mean()
        loss_g = loss_g + self.hparams.lpips_256_weight * loss_lpips_256
        self.log('train/lpips_fine', loss_lpips_256, prog_bar=False)

    # LPIPS at full resolution (expensive)
    if self.hparams.lpips_512_weight > 0:
        loss_lpips_512 = self.lpips_fn(generated, her2).mean()
        loss_g = loss_g + self.hparams.lpips_512_weight * loss_lpips_512
        self.log('train/lpips_fullres', loss_lpips_512, prog_bar=False)

    # Low-resolution L1 (color fidelity, misalignment-robust at 64x64)
    if self.hparams.l1_lowres_weight > 0:
        gen_64 = F.interpolate(generated, size=64, mode='bilinear', align_corners=False)
        her2_64 = F.interpolate(her2, size=64, mode='bilinear', align_corners=False)
        loss_l1_lowres = F.l1_loss(gen_64, her2_64)
        loss_g = loss_g + self.hparams.l1_lowres_weight * loss_l1_lowres
        self.log('train/l1_lowres', loss_l1_lowres, prog_bar=False)

    # DAB losses (use original labels, not dropped)
    if self.hparams.dab_intensity_weight > 0:
        loss_dab = self.compute_dab_intensity_loss(generated, her2)
        loss_g = loss_g + self.hparams.dab_intensity_weight * loss_dab
        self.log('train/dab_intensity', loss_dab, prog_bar=False)

    if self.hparams.dab_contrast_weight > 0:
        loss_dab_contrast = self.compute_dab_contrast_loss(generated, labels)
        loss_g = loss_g + self.hparams.dab_contrast_weight * loss_dab_contrast
        self.log('train/dab_contrast', loss_dab_contrast, prog_bar=False)

    # Edge loss (boundary sharpness)
    if self.hparams.edge_weight > 0:
        loss_edge = self.compute_edge_loss(generated, her2)
        loss_g = loss_g + self.hparams.edge_weight * loss_edge
        self.log('train/edge_loss', loss_edge, prog_bar=False)

    # DAB sharpness loss (membrane-localized vs diffuse brown)
    if self.hparams.dab_sharpness_weight > 0:
        loss_dab_sharp = self.compute_dab_sharpness_loss(generated, her2)
        loss_g = loss_g + self.hparams.dab_sharpness_weight * loss_dab_sharp
        self.log('train/dab_sharpness', loss_dab_sharp, prog_bar=False)

    # Gram-matrix style loss
    if self.hparams.gram_style_weight > 0 and self.vgg_extractor is not None:
        loss_gram = self.compute_gram_style_loss(generated, her2)
        loss_g = loss_g + self.hparams.gram_style_weight * loss_gram
        self.log('train/gram_style', loss_gram, prog_bar=False)

    # H&E edge structure preservation (pixel-aligned)
    if self.hparams.he_edge_weight > 0:
        loss_he_edge = self.compute_he_edge_loss(generated, he)
        loss_g = loss_g + self.hparams.he_edge_weight * loss_he_edge
        self.log('train/he_edge', loss_he_edge, prog_bar=False)

    # Background white loss
    if self.hparams.bg_white_weight > 0:
        loss_bg = self.compute_background_loss(generated, he)
        loss_g = loss_g + self.hparams.bg_white_weight * loss_bg
        self.log('train/bg_white', loss_bg, prog_bar=False)

    # PatchNCE loss (contrastive: H&E input vs generated, never sees GT)
    if self.hparams.patchnce_weight > 0 and self.patchnce_loss is not None:
        feats_he = self.generator.encode(he)
        feats_gen = self.generator.encode(generated)
        loss_nce = self.patchnce_loss(feats_he, feats_gen)
        loss_g = loss_g + self.hparams.patchnce_weight * loss_nce
        self.log('train/patchnce', loss_nce, prog_bar=False)

    # Adversarial losses (after warmup)
    loss_adv = torch.tensor(0.0, device=self.device)
    loss_feat_match = torch.tensor(0.0, device=self.device)
    loss_crop_adv = torch.tensor(0.0, device=self.device)
    loss_uncond_adv = torch.tensor(0.0, device=self.device)

    any_adv = (self.hparams.adversarial_weight > 0 or
               self.hparams.uncond_disc_weight > 0 or
               self.hparams.crop_disc_weight > 0 or
               self.hparams.feat_match_weight > 0)

    img_sz = self.hparams.image_size
    # Pre-compute disc-resolution tensors (512 for 1024 input, identity for 512)
    if img_sz == 1024:
        he_for_disc = F.interpolate(he, size=512, mode='bilinear', align_corners=False)
        her2_for_disc = F.interpolate(her2, size=512, mode='bilinear', align_corners=False)
    else:
        he_for_disc = he
        her2_for_disc = her2

    if self.global_step >= self.hparams.adversarial_start_step and any_adv:
        if img_sz == 1024:
            gen_for_disc = F.interpolate(generated, size=512, mode='bilinear', align_corners=False)
        else:
            gen_for_disc = generated

        # Conditional discriminator (paired: generated+HE vs real_HER2+HE)
        if self.hparams.adversarial_weight > 0:
            fake_input = torch.cat([gen_for_disc, he_for_disc], dim=1)
            disc_outputs = self.discriminator(fake_input)
            loss_adv = sum(hinge_loss_g(out) for out in disc_outputs) / len(disc_outputs)
            loss_g = loss_g + self.hparams.adversarial_weight * loss_adv

        # Feature matching from unconditional disc
        if (self.hparams.feat_match_weight > 0 and
                self.uncond_discriminator is not None):
            _, fake_feats = self.uncond_discriminator(gen_for_disc, return_features=True)
            with torch.no_grad():
                _, real_feats = self.uncond_discriminator(her2_for_disc, return_features=True)
            loss_feat_match = feature_matching_loss(fake_feats, real_feats)
            loss_g = loss_g + self.hparams.feat_match_weight * loss_feat_match

        # Crop discriminator: random crops at full resolution
        if self.crop_discriminator is not None and self.hparams.crop_disc_weight > 0:
            fake_input_crop = torch.cat([generated, he], dim=1)
            cs = self.hparams.crop_size
            # randint's upper bound is exclusive: "+ 1" makes the last valid
            # offset reachable and keeps crop_size == image_size working.
            top = torch.randint(0, img_sz - cs + 1, (1,)).item()
            left = torch.randint(0, img_sz - cs + 1, (1,)).item()
            fake_crop = fake_input_crop[:, :, top:top+cs, left:left+cs]
            loss_crop_adv = hinge_loss_g(self.crop_discriminator(fake_crop))
            loss_g = loss_g + self.hparams.crop_disc_weight * loss_crop_adv

        # Unconditional discriminator: HER2-only adversarial
        if self.uncond_discriminator is not None and self.hparams.uncond_disc_weight > 0:
            loss_uncond_adv = hinge_loss_g(self.uncond_discriminator(gen_for_disc))
            loss_g = loss_g + self.hparams.uncond_disc_weight * loss_uncond_adv

    # Generator backward + step (learning rate follows the linear warmup)
    lr_scale = self._get_lr_scale()
    for pg in opt_g.param_groups:
        pg['lr'] = self.hparams.gen_lr * lr_scale

    opt_g.zero_grad()
    self.manual_backward(loss_g)
    torch.nn.utils.clip_grad_norm_(self.generator.parameters(), 1.0)
    opt_g.step()

    # Update EMA
    self._update_ema()

    # ----------------------------------------------------------------
    # Discriminator step
    # ----------------------------------------------------------------
    loss_d = torch.tensor(0.0, device=self.device)
    loss_crop_d = torch.tensor(0.0, device=self.device)
    loss_uncond_d = torch.tensor(0.0, device=self.device)

    # NOTE(review): global_step is re-read here after opt_g.step(); under
    # manual optimization Lightning may already have advanced it, which
    # would shift this gate and the R1 schedule below -- confirm.
    if self.global_step >= self.hparams.adversarial_start_step and any_adv:
        # Fresh fake from the just-updated generator, without a graph.
        with torch.no_grad():
            fake_detached = self.generator(he, uni_dropped, labels_dropped)

        # For 1024, downsample for disc
        if img_sz == 1024:
            fake_det_disc = F.interpolate(fake_detached, size=512, mode='bilinear', align_corners=False)
        else:
            fake_det_disc = fake_detached

        # Conditional discriminator
        if self.hparams.adversarial_weight > 0:
            real_input = torch.cat([her2_for_disc, he_for_disc], dim=1)
            fake_input = torch.cat([fake_det_disc, he_for_disc], dim=1)
            disc_real = self.discriminator(real_input)
            disc_fake = self.discriminator(fake_input)
            loss_d = sum(
                hinge_loss_d(dr, df)
                for dr, df in zip(disc_real, disc_fake)
            ) / len(disc_real)

        # Crop discriminator (real and fake share the same crop window)
        if self.crop_discriminator is not None and self.hparams.crop_disc_weight > 0:
            real_input_c = torch.cat([her2, he], dim=1)
            fake_input_c = torch.cat([fake_detached, he], dim=1)
            cs = self.hparams.crop_size
            # Same inclusive-offset sampling as in the generator step.
            top = torch.randint(0, img_sz - cs + 1, (1,)).item()
            left = torch.randint(0, img_sz - cs + 1, (1,)).item()
            real_crop = real_input_c[:, :, top:top+cs, left:left+cs]
            fake_crop = fake_input_c[:, :, top:top+cs, left:left+cs]
            loss_crop_d = hinge_loss_d(
                self.crop_discriminator(real_crop),
                self.crop_discriminator(fake_crop),
            )
            loss_d = loss_d + self.hparams.crop_disc_weight * loss_crop_d

        # Unconditional discriminator (also trained when only the
        # feature-matching loss uses it, hence the weight floor of 1.0)
        if self.uncond_discriminator is not None and (
                self.hparams.uncond_disc_weight > 0 or self.hparams.feat_match_weight > 0):
            uncond_real_out = self.uncond_discriminator(her2_for_disc)
            uncond_fake_out = self.uncond_discriminator(fake_det_disc)
            loss_uncond_d = hinge_loss_d(uncond_real_out, uncond_fake_out)
            loss_d = loss_d + max(self.hparams.uncond_disc_weight, 1.0) * loss_uncond_d

        # R1 gradient penalty on real inputs, in full precision
        loss_r1 = torch.tensor(0.0, device=self.device)
        if self.global_step % self.hparams.r1_every == 0:
            with torch.amp.autocast('cuda', enabled=False):
                if self.hparams.adversarial_weight > 0:
                    real_input_r1 = torch.cat([her2_for_disc, he_for_disc], dim=1).float().detach().requires_grad_(True)
                    # Only the 512-scale branch of the global discriminator is penalized.
                    for disc in [self.discriminator.disc_512]:
                        d_real = disc(real_input_r1)
                        grad_real = torch.autograd.grad(
                            outputs=d_real.sum(), inputs=real_input_r1,
                            create_graph=True,
                        )[0]
                        loss_r1 = loss_r1 + self.hparams.r1_weight * grad_real.pow(2).mean()
                if self.uncond_discriminator is not None and (
                        self.hparams.uncond_disc_weight > 0 or self.hparams.feat_match_weight > 0):
                    her2_r1 = her2_for_disc.float().detach().requires_grad_(True)
                    d_real_uncond = self.uncond_discriminator(her2_r1)
                    grad_uncond = torch.autograd.grad(
                        outputs=d_real_uncond.sum(), inputs=her2_r1,
                        create_graph=True,
                    )[0]
                    loss_r1 = loss_r1 + self.hparams.r1_weight * grad_uncond.pow(2).mean()
            loss_d = loss_d + loss_r1
            self.log('train/r1_penalty', loss_r1, prog_bar=False)

        opt_d.zero_grad()
        self.manual_backward(loss_d)
        if self.hparams.adversarial_weight > 0:
            torch.nn.utils.clip_grad_norm_(self.discriminator.parameters(), 1.0)
        if self.crop_discriminator is not None:
            torch.nn.utils.clip_grad_norm_(self.crop_discriminator.parameters(), 1.0)
        if self.uncond_discriminator is not None:
            torch.nn.utils.clip_grad_norm_(self.uncond_discriminator.parameters(), 1.0)
        opt_d.step()

    # Logging
    self.log('train/loss_g', loss_g, prog_bar=True)
    self.log('train/loss_d', loss_d, prog_bar=True)
    self.log('train/lpips', loss_lpips, prog_bar=True)
    self.log('train/adversarial', loss_adv, prog_bar=False)
    self.log('train/lr_scale', lr_scale, prog_bar=False)
    if self.crop_discriminator is not None:
        self.log('train/crop_adv_g', loss_crop_adv, prog_bar=False)
        self.log('train/crop_adv_d', loss_crop_d, prog_bar=False)
    if self.uncond_discriminator is not None:
        self.log('train/uncond_adv_g', loss_uncond_adv, prog_bar=False)
        self.log('train/uncond_adv_d', loss_uncond_d, prog_bar=False)
    if self.hparams.feat_match_weight > 0:
        self.log('train/feat_match', loss_feat_match, prog_bar=False)
def on_validation_epoch_start(self):
    """Reset the per-epoch validation bookkeeping.

    Sets two attributes that ``validation_step`` reads:

    * ``_random_val_batch_idx`` -- batch index whose samples are logged as
      the "random" grid. It is drawn from ``[1, max(2, n))`` so it never
      equals 0, the index ``validation_step`` uses for the fixed grid.
    * ``_val_per_label_samples`` -- emptied dict of per-label collectors.
    """
    # NOTE(review): ``n`` is ``len(self.trainer.val_dataloaders)``; whether
    # that counts batches or dataloaders depends on the Lightning version
    # -- confirm.
    num_val = len(self.trainer.val_dataloaders)
    if num_val < 1:
        num_val = 1
    # randint's upper bound is exclusive and must exceed the lower bound.
    upper_bound = num_val if num_val > 2 else 2
    drawn = torch.randint(1, upper_bound, (1,))
    self._random_val_batch_idx = drawn.item()

    # Start every validation epoch with empty per-label sample buckets.
    self._val_per_label_samples = {}
def _log_sample_grid(self, he, her2_01, gen_01, key):
    """Log an ``H&E | Real | Gen`` image grid to the experiment logger.

    At most the first four samples are shown, one sample per grid row
    (three tiles per row: input, real target, generated).

    Args:
        he: Batch of H&E inputs. Rescaled here with ``(x + 1) / 2`` and
            clamped to ``[0, 1]``.
        her2_01: Batch of real target images, used as given (no rescaling).
        gen_01: Batch of generated images, used as given (no rescaling).
        key: Name under which the grid is logged.

    Returns:
        None. Does nothing when no logger is attached or the batch is empty.
    """
    # Bail out before any device-to-host copies when there is nowhere to log.
    if not self.logger:
        return

    n = min(4, len(he))
    if n == 0:
        # make_grid cannot stack an empty list of tiles.
        return

    he_01 = ((he[:n].cpu() + 1) / 2).clamp(0, 1)

    tiles = []
    for i in range(n):
        tiles.extend([he_01[i], her2_01[i].cpu(), gen_01[i].cpu()])

    # nrow=3 puts one (H&E, real, generated) triplet on each row.
    grid = torchvision.utils.make_grid(tiles, nrow=3, padding=2)

    # The experiment object is expected to expose a wandb-style ``log(dict)``.
    self.logger.experiment.log({
        key: [wandb.Image(grid, caption='H&E | Real | Gen')],
        'global_step': self.global_step,
    })
def validation_step(self, batch, batch_idx):
    """Run the EMA generator on one validation batch and log metrics.

    Logs ``val/lpips``, ``val/ssim`` and ``val/dab_mae``, collects up to four
    samples per non-null label for the end-of-epoch grids, and logs a sample
    grid for batch 0 and for the randomly chosen batch index.

    Args:
        batch: 5-tuple ``(he, her2, uni_or_crops, labels, fnames)``. The third
            element is either UNI features or sub-crops to extract them from,
            depending on ``self._uni_extract_on_the_fly``.
        batch_idx: Index of this batch within the validation epoch.
    """
    he, her2, uni_or_crops, labels, fnames = batch

    # On-the-fly UNI extraction
    if self._uni_extract_on_the_fly:
        uni = self._extract_uni_from_sub_crops(uni_or_crops)
    else:
        uni = uni_or_crops

    # Ablation switches: blank out the UNI features / the class condition.
    if self.hparams.disable_uni:
        uni = torch.zeros_like(uni)
    if self.hparams.disable_class:
        labels = torch.full_like(labels, self.hparams.null_class)

    # Use EMA generator
    with torch.no_grad():
        generated = self.generator_ema(he, uni, labels)

    # LPIPS at a quarter of the configured image size
    # (e.g. 128 for 512, 256 for 1024).
    lpips_size = self.hparams.image_size // 4
    gen_lpips = F.interpolate(
        generated, size=lpips_size, mode='bilinear', align_corners=False)
    her2_lpips = F.interpolate(
        her2, size=lpips_size, mode='bilinear', align_corners=False)
    lpips_val = self.lpips_fn(gen_lpips, her2_lpips).mean()

    # SSIM on images mapped from [-1, 1] to [0, 1]
    gen_01 = ((generated + 1) / 2).clamp(0, 1)
    her2_01 = ((her2 + 1) / 2).clamp(0, 1)
    # Local import: torchmetrics is not imported at module level.
    from torchmetrics.functional.image import structural_similarity_index_measure
    ssim_val = structural_similarity_index_measure(gen_01, her2_01, data_range=1.0)

    # DAB MAE (canonical: mean of top-10%)
    dab_gen = self.dab_extractor.extract_dab_intensity(
        generated.float().cpu(), normalize="none")
    dab_real = self.dab_extractor.extract_dab_intensity(
        her2.float().cpu(), normalize="none")

    def p90_score(dab):
        """Mean of the values at or above the 90th percentile of ``dab``."""
        flat = dab.flatten()
        p90 = torch.quantile(flat, 0.9)
        mask = flat >= p90
        return flat[mask].mean().item() if mask.sum() > 0 else flat.mean().item()

    dab_mae = sum(
        abs(p90_score(gen_map) - p90_score(real_map))
        for gen_map, real_map in zip(dab_gen, dab_real)
    ) / len(dab_gen)

    self.log('val/lpips', lpips_val, prog_bar=True, sync_dist=True)
    self.log('val/ssim', ssim_val, prog_bar=True, sync_dist=True)
    self.log('val/dab_mae', dab_mae, prog_bar=True, sync_dist=True)

    # Collect per-label samples for visual grids (multi-stain only).
    # The collector only exists once on_validation_epoch_start has run.
    per_label = getattr(self, '_val_per_label_samples', None)
    if per_label is not None:
        null_class = self.hparams.null_class
        # Copy the labels to host once instead of one .item() sync per sample.
        label_values = [value.item() for value in labels.detach().cpu()]
        for i, lbl in enumerate(label_values):
            if lbl == null_class:
                continue
            bucket = per_label.setdefault(lbl, {'he': [], 'real': [], 'gen': []})
            if len(bucket['he']) < 4:
                # H&E is stored unscaled; _log_sample_grid rescales it.
                bucket['he'].append(he[i].cpu())
                bucket['real'].append(her2_01[i].cpu())
                bucket['gen'].append(gen_01[i].cpu())

    # Log sample grids: first batch (fixed) + one random batch.
    # The random index is guarded like the per-label collector above, so a
    # missing attribute skips the random grid instead of raising.
    if batch_idx == 0:
        self._log_sample_grid(he, her2_01, gen_01, 'val/samples_fixed')
    elif batch_idx == getattr(self, '_random_val_batch_idx', None):
        self._log_sample_grid(he, her2_01, gen_01, 'val/samples_random')
def on_validation_epoch_end(self):
    """Emit one sample grid per collected label, then clear the collectors.

    Nothing is logged (and nothing is cleared) when the collector attribute
    is missing or holds at most one label. Labels are processed in sorted
    order; a label is skipped when its bucket is empty or no logger is
    attached. The grid key is ``val/samples_<name>``, where ``<name>`` comes
    from ``hparams.label_names`` when that list covers the label and is the
    stringified label otherwise.
    """
    if not hasattr(self, '_val_per_label_samples'):
        return
    collected = self._val_per_label_samples
    if len(collected) <= 1:
        return

    names = getattr(self.hparams, 'label_names', None)
    for label in sorted(collected):
        samples = collected[label]
        if samples['he'] and self.logger:
            if names and label < len(names):
                tag = names[label]
            else:
                tag = str(label)
            he_batch = torch.stack(samples['he'])
            real_batch = torch.stack(samples['real'])
            gen_batch = torch.stack(samples['gen'])
            self._log_sample_grid(
                he_batch, real_batch, gen_batch, f'val/samples_{tag}')

    # Drop the collected tensors once the grids have been emitted.
    self._val_per_label_samples = {}
def generate(self, he_images, uni_features, labels,
             num_inference_steps=None, guidance_scale=1.0, seed=None):
    """Translate H&E images to IHC with a single generator forward pass.

    Args:
        he_images: [B, 3, H, H] where H=512 or H=1024
        uni_features: [B, N, 1024] where N=16 (4x4 CLS) or N=1024 (32x32 patch)
        labels: [B] class/stain labels
        num_inference_steps: accepted for API compatibility; not used.
        guidance_scale: classifier-free guidance scale. Values <= 1.0 return
            the plain conditional output; larger values extrapolate from the
            null-label output towards the conditional one.
        seed: if given, passed to ``torch.manual_seed`` before generating.

    Returns:
        The generator output. Only the guided path (guidance_scale > 1.0)
        clamps the result to [-1, 1].

    NOTE(review): the null label is read from ``self.null_class`` here, while
    ``validation_step`` reads ``self.hparams.null_class`` -- confirm both are
    set to the same value.
    """
    if seed is not None:
        torch.manual_seed(seed)

    # Prefer the EMA copy of the generator when the module has one.
    if hasattr(self, 'generator_ema'):
        net = self.generator_ema
    else:
        net = self.generator

    if guidance_scale <= 1.0:
        return net(he_images, uni_features, labels)

    # Classifier-free guidance: one pass with the real labels, one with the
    # null label, then extrapolate along their difference.
    unconditional_labels = torch.full_like(labels, self.null_class)
    cond = net(he_images, uni_features, labels)
    uncond = net(he_images, uni_features, unconditional_labels)
    guided = uncond + guidance_scale * (cond - uncond)
    return torch.clamp(guided, -1, 1)