Spaces:
Running
Running
File size: 43,305 Bytes
4db9215 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 
525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 | """
UNIStainNet: Pixel-Space UNI-Guided Virtual Staining Network.
Architecture:
Generator: SPADE-UNet conditioned on UNI pathology features + stain/class embedding
Discriminator: Multi-scale PatchGAN (512 + 256)
Losses: LPIPS@128 + adversarial + DAB intensity + DAB contrast
References:
- Park et al., "Semantic Image Synthesis with SPADE" (CVPR 2019)
- Chen et al., "A general-purpose self-supervised model for pathology" (Nature Medicine 2024)
- Isola et al., "Image-to-Image Translation with pix2pix" (CVPR 2017)
"""
import copy
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
import lpips
import pytorch_lightning as pl
import torchvision
import wandb
from src.models.discriminator import (
PatchDiscriminator, MultiScaleDiscriminator,
hinge_loss_d, hinge_loss_g, r1_gradient_penalty, feature_matching_loss,
)
from src.models.generator import SPADEUNetGenerator
from src.models.losses import VGGFeatureExtractor, gram_matrix, PatchNCELoss
from src.utils.dab import DABExtractor
# ======================================================================
# Training Module
# ======================================================================
class UNIStainNetTrainer(pl.LightningModule):
"""PyTorch Lightning training module for UNIStainNet.
Handles GAN training with manual optimization, CFG dropout, EMA, and
all loss computations.
"""
def __init__(
    self,
    # Architecture
    num_classes=5,
    null_class=4,
    class_dim=64,
    uni_dim=1024,
    ndf=64,
    disc_n_layers=3,
    input_skip=False,
    # Optimizer
    gen_lr=1e-4,
    disc_lr=4e-4,
    warmup_steps=1000,
    # Loss weights
    lpips_weight=1.0,
    lpips_256_weight=0.5,
    lpips_512_weight=0.0,
    adversarial_weight=1.0,
    dab_intensity_weight=0.1,
    dab_contrast_weight=0.05,
    dab_sharpness_weight=0.0,
    gram_style_weight=0.0,
    edge_weight=0.0,
    he_edge_weight=0.0,
    bg_white_weight=0.0,
    bg_threshold=0.85,
    l1_lowres_weight=0.0,
    edge_encoder=False,
    edge_base_ch=32,
    uni_spatial_size=4,
    uncond_disc_weight=0.0,
    crop_disc_weight=0.0,
    crop_size=128,
    feat_match_weight=0.0,
    patchnce_weight=0.0,
    patchnce_layers=(2, 3, 4),
    patchnce_n_patches=256,
    patchnce_temperature=0.07,
    # Ablation
    disable_uni=False,
    disable_class=False,
    # GAN training
    r1_weight=10.0,
    r1_every=16,
    adversarial_start_step=2000,
    # CFG
    cfg_drop_class_prob=0.10,
    cfg_drop_uni_prob=0.10,
    cfg_drop_both_prob=0.05,
    # EMA
    ema_decay=0.999,
    # On-the-fly UNI extraction (for crop-based training)
    extract_uni_on_the_fly=False,
    uni_spatial_pool_size=32,
    # Resolution
    image_size=512,
    # 1024 architecture: extend UNI SPADE to 512 level
    uni_spade_at_512=False,
    # Per-label names for multi-stain logging
    label_names=None,
):
    """Build the generator, discriminators, auxiliary loss modules and EMA copy.

    Every argument is captured by ``save_hyperparameters()`` and read back
    later through ``self.hparams`` (for example in ``training_step``).
    Optimization is manual (``automatic_optimization = False``).

    Args:
        num_classes, class_dim, uni_dim, input_skip, edge_encoder,
        edge_base_ch, uni_spatial_size, image_size, uni_spade_at_512:
            Forwarded unchanged to ``SPADEUNetGenerator``.
        null_class: Label written in place of the real class when the class
            condition is dropped (CFG dropout or the ``disable_class``
            ablation). Labels ``>= null_class`` are ignored by the DAB
            contrast loss.
        ndf, disc_n_layers: Width / depth shared by every discriminator.
        gen_lr, disc_lr: Adam learning rates. ``gen_lr`` is multiplied by a
            linear warmup over ``warmup_steps`` steps in ``training_step``.
        *_weight: Coefficient of the matching loss term; 0 disables it.
            ``crop_disc_weight``, ``uncond_disc_weight`` /
            ``feat_match_weight``, ``patchnce_weight`` and
            ``gram_style_weight`` also decide whether the corresponding
            module is constructed at all.
        bg_threshold: Soft brightness threshold (H&E mapped to [0, 1]) used
            to detect background for the background-white loss.
        crop_size: Side length of the random crops fed to the crop
            discriminator.
        patchnce_layers, patchnce_n_patches, patchnce_temperature:
            PatchNCE configuration; layers index generator encoder levels
            1-4 (see ``enc_channels`` below).
        disable_uni, disable_class: Ablations that zero the UNI features /
            force the null class.
        r1_weight, r1_every: R1 penalty strength and period in steps.
        adversarial_start_step: ``global_step`` from which adversarial terms
            and discriminator updates are active.
        cfg_drop_class_prob, cfg_drop_uni_prob, cfg_drop_both_prob:
            Per-sample classifier-free-guidance dropout probabilities.
        ema_decay: Decay of the generator exponential moving average.
        extract_uni_on_the_fly, uni_spatial_pool_size: Compute UNI features
            from image sub-crops during training instead of receiving them
            precomputed, pooled to an ``S x S`` grid.
        label_names: Only stored as a hyperparameter here (presumably used
            by validation logging - confirm).
    """
    super().__init__()
    self.save_hyperparameters()
    self.automatic_optimization = False
    self.null_class = null_class
    # On-the-fly UNI feature extraction (loaded lazily on first use)
    # Assigning a module here later registers it as a submodule, which is why
    # the checkpoint hooks strip the "_uni_model." keys.
    self._uni_model = None
    self._uni_extract_on_the_fly = extract_uni_on_the_fly
    # NOTE(review): keep the order of the module constructions below stable -
    # it fixes parameter order (and hence the optimizer-state layout stored in
    # checkpoints) and presumably the RNG draws used for weight init.
    # Generator
    self.generator = SPADEUNetGenerator(
        num_classes=num_classes,
        class_dim=class_dim,
        uni_dim=uni_dim,
        input_skip=input_skip,
        edge_encoder=edge_encoder,
        edge_base_ch=edge_base_ch,
        uni_spatial_size=uni_spatial_size,
        image_size=image_size,
        uni_spade_at_512=uni_spade_at_512,
    )
    # Discriminator (global multi-scale)
    # 6 input channels = HER2 image (real or generated) concatenated with H&E.
    self.discriminator = MultiScaleDiscriminator(
        in_channels=6, ndf=ndf, n_layers=disc_n_layers,
    )
    # Crop discriminator (local full-res detail)
    if crop_disc_weight > 0:
        self.crop_discriminator = PatchDiscriminator(
            in_channels=6, ndf=ndf, n_layers=disc_n_layers,
        )
    else:
        self.crop_discriminator = None
    # Unconditional discriminator (HER2-only, alignment-free texture judge)
    # Also needed for feature matching loss (FM uses uncond disc features)
    if uncond_disc_weight > 0 or feat_match_weight > 0:
        self.uncond_discriminator = PatchDiscriminator(
            in_channels=3, ndf=ndf, n_layers=disc_n_layers,
        )
    else:
        self.uncond_discriminator = None
    # PatchNCE loss (contrastive, alignment-free: H&E input vs generated)
    if patchnce_weight > 0:
        # Encoder channel dims: {1: 64, 2: 128, 3: 256, 4: 512}
        enc_channels = {1: 64, 2: 128, 3: 256, 4: 512}
        # An unknown layer index in patchnce_layers raises KeyError here.
        layer_channels = {l: enc_channels[l] for l in patchnce_layers}
        self.patchnce_loss = PatchNCELoss(
            layer_channels=layer_channels,
            num_patches=patchnce_n_patches,
            temperature=patchnce_temperature,
        )
    else:
        self.patchnce_loss = None
    # EMA generator (frozen copy; updated manually after each generator step)
    self.generator_ema = copy.deepcopy(self.generator)
    self.generator_ema.requires_grad_(False)
    # Losses (frozen perceptual network)
    self.lpips_fn = lpips.LPIPS(net='alex')
    self.lpips_fn.requires_grad_(False)
    self.lpips_fn.eval()
    self.dab_extractor = DABExtractor(device='cpu')
    # VGG feature extractor for Gram-matrix style loss
    if gram_style_weight > 0:
        self.vgg_extractor = VGGFeatureExtractor()
        self.vgg_extractor.requires_grad_(False)
        self.vgg_extractor.eval()
    else:
        self.vgg_extractor = None
    # Param counts
    n_gen = sum(p.numel() for p in self.generator.parameters())
    n_disc = sum(p.numel() for p in self.discriminator.parameters())
    n_crop = sum(p.numel() for p in self.crop_discriminator.parameters()) if self.crop_discriminator else 0
    n_uncond = sum(p.numel() for p in self.uncond_discriminator.parameters()) if self.uncond_discriminator else 0
    print(f"Generator: {n_gen:,} params")
    print(f"Discriminator: {n_disc:,} params (global) + {n_crop:,} (crop) + {n_uncond:,} (uncond)")
def configure_optimizers(self):
    """Create the generator and discriminator Adam optimizers.

    Returns:
        ``[opt_g, opt_d]``. ``opt_g`` covers the generator parameters followed
        by the PatchNCE loss module's parameters (when that module exists);
        ``opt_d`` covers the global discriminator followed by the crop and
        unconditional discriminators that were built. Both use
        ``betas=(0.0, 0.999)``.
    """
    adam_betas = (0.0, 0.999)

    # Parameter order is significant for optimizer-state checkpoints.
    generator_params = [*self.generator.parameters()]
    if self.patchnce_loss is not None:
        generator_params.extend(self.patchnce_loss.parameters())

    discriminator_params = [*self.discriminator.parameters()]
    for extra_disc in (self.crop_discriminator, self.uncond_discriminator):
        if extra_disc is not None:
            discriminator_params.extend(extra_disc.parameters())

    gen_optimizer = torch.optim.Adam(
        generator_params, lr=self.hparams.gen_lr, betas=adam_betas,
    )
    disc_optimizer = torch.optim.Adam(
        discriminator_params, lr=self.hparams.disc_lr, betas=adam_betas,
    )
    return [gen_optimizer, disc_optimizer]
def _get_lr_scale(self):
"""Linear warmup."""
if self.global_step < self.hparams.warmup_steps:
return self.global_step / max(1, self.hparams.warmup_steps)
return 1.0
@torch.no_grad()
def _update_ema(self):
"""Update EMA generator weights."""
decay = self.hparams.ema_decay
for p_ema, p in zip(self.generator_ema.parameters(), self.generator.parameters()):
p_ema.data.mul_(decay).add_(p.data, alpha=1 - decay)
def on_save_checkpoint(self, checkpoint):
    """Remove the frozen UNI backbone's weights before the checkpoint is written.

    The backbone is rebuilt lazily by ``_load_uni_model`` when on-the-fly
    extraction needs it, so its ``_uni_model.*`` entries are dropped from
    ``checkpoint['state_dict']`` in place. A checkpoint without a
    ``state_dict`` entry is left untouched.
    """
    weights = checkpoint.get('state_dict', {})
    uni_keys = [name for name in weights if name.startswith('_uni_model.')]
    for name in uni_keys:
        weights.pop(name)
def on_load_checkpoint(self, checkpoint):
    """Drop ``_uni_model.*`` weights that older checkpoints still contain.

    ``_uni_model`` is ``None`` after construction (it is created lazily), so
    those keys presumably would be reported as unexpected under strict
    loading. They are removed from ``checkpoint['state_dict']`` in place.
    """
    state = checkpoint.get('state_dict', {})
    for stale in [key for key in state if key.startswith('_uni_model.')]:
        del state[stale]
def _load_uni_model(self):
"""Lazily load UNI ViT-L/16 for on-the-fly feature extraction."""
if self._uni_model is None:
import timm
self._uni_model = timm.create_model(
"hf-hub:MahmoodLab/uni",
pretrained=True,
init_values=1e-5,
dynamic_img_size=True,
)
self._uni_model.eval()
self._uni_model.requires_grad_(False)
self._uni_model = self._uni_model.to(self.device)
n_params = sum(p.numel() for p in self._uni_model.parameters())
print(f"UNI model loaded for on-the-fly extraction: {n_params:,} params")
return self._uni_model
@torch.no_grad()
def _extract_uni_from_sub_crops(self, uni_sub_crops):
"""Extract UNI features from pre-prepared sub-crops on GPU.
Args:
uni_sub_crops: [B, 16, 3, 224, 224] — batch of 4x4 sub-crop grids,
already normalized with ImageNet stats.
Returns:
uni_features: [B, S*S, 1024] where S = uni_spatial_pool_size (default 32)
"""
uni_model = self._load_uni_model()
B = uni_sub_crops.shape[0]
spatial_size = self.hparams.uni_spatial_pool_size
num_crops = 4 # 4x4 grid
patches_per_side = 14 # 224/16
# Batched UNI forward: [B, 16, 3, 224, 224] -> [B*16, 3, 224, 224]
all_crops = uni_sub_crops.reshape(B * 16, 3, 224, 224).to(self.device)
all_feats = uni_model.forward_features(all_crops) # [B*16, 197, 1024]
patch_tokens = all_feats[:, 1:, :] # [B*16, 196, 1024]
# Reshape back to per-sample grids: [B, 4, 4, 14, 14, 1024]
patch_tokens = patch_tokens.reshape(
B, num_crops, num_crops,
patches_per_side, patches_per_side, 1024
)
# Interleave to spatial grid: [B, 56, 56, 1024]
full_size = num_crops * patches_per_side # 56
full_grid = patch_tokens.permute(0, 1, 3, 2, 4, 5)
full_grid = full_grid.reshape(B, full_size, full_size, 1024)
# Pool to target spatial size (batched)
if spatial_size < full_size:
grid_bchw = full_grid.permute(0, 3, 1, 2) # [B, 1024, 56, 56]
pooled = F.adaptive_avg_pool2d(grid_bchw, spatial_size) # [B, 1024, S, S]
result = pooled.permute(0, 2, 3, 1) # [B, S, S, 1024]
else:
result = full_grid
S = result.shape[1]
return result.reshape(B, S * S, 1024) # [B, S*S, 1024]
def _apply_cfg_dropout(self, labels, uni_features):
"""Apply classifier-free guidance dropout during training (vectorized)."""
B = labels.shape[0]
device = labels.device
new_labels = labels.clone()
new_uni = uni_features.clone()
r = torch.rand(B, device=device)
p_both = self.hparams.cfg_drop_both_prob
p_class = p_both + self.hparams.cfg_drop_class_prob
p_uni = p_class + self.hparams.cfg_drop_uni_prob
drop_both = r < p_both
drop_class = (r >= p_both) & (r < p_class)
drop_uni = (r >= p_class) & (r < p_uni)
new_labels[drop_both | drop_class] = self.null_class
new_uni[drop_both | drop_uni] = 0.0
return new_labels, new_uni
def compute_dab_intensity_loss(self, generated, target):
"""Top-10% percentile matching for DAB intensity."""
with torch.amp.autocast('cuda', enabled=False):
gen = generated.float()
tgt = target.float()
dab_gen = self.dab_extractor.extract_dab_intensity(gen, normalize="none")
dab_tgt = self.dab_extractor.extract_dab_intensity(tgt, normalize="none")
def _batched_top10_mean(dab):
"""Compute mean of top-10% DAB intensity per sample (vectorized)."""
B = dab.shape[0]
flat = dab.reshape(B, -1) # [B, H*W]
p99 = torch.quantile(flat, 0.99, dim=1, keepdim=True)
flat = flat.clamp(max=p99)
p90 = torch.quantile(flat, 0.9, dim=1, keepdim=True)
mask = flat >= p90 # [B, H*W]
# Use masked mean: sum(vals * mask) / sum(mask), fallback to flat mean
masked_sum = (flat * mask).sum(dim=1)
mask_count = mask.sum(dim=1).clamp(min=1)
return masked_sum / mask_count # [B]
gen_scores = _batched_top10_mean(dab_gen)
tgt_scores = _batched_top10_mean(dab_tgt)
return F.l1_loss(gen_scores, tgt_scores)
def compute_dab_contrast_loss(self, generated, labels):
"""Class-ordering hinge loss: DAB(3+) > DAB(2+) > DAB(1+) > DAB(0)."""
with torch.amp.autocast('cuda', enabled=False):
gen = generated.float()
# Only use non-null labels
valid = labels < self.null_class
if valid.sum() < 2:
return torch.tensor(0.0, device=self.device, requires_grad=True)
gen_valid = gen[valid]
labels_valid = labels[valid]
dab_gen = self.dab_extractor.extract_dab_intensity(gen_valid, normalize="none")
B = dab_gen.shape[0]
flat = dab_gen.reshape(B, -1)
p99 = torch.quantile(flat, 0.99, dim=1, keepdim=True)
flat = flat.clamp(max=p99)
p90 = torch.quantile(flat, 0.9, dim=1, keepdim=True)
mask = flat >= p90
masked_sum = (flat * mask).sum(dim=1)
mask_count = mask.sum(dim=1).clamp(min=1)
dab_scores = masked_sum / mask_count
class_pairs = [
(3, 0, 0.20), (3, 1, 0.15),
(2, 0, 0.08), (3, 2, 0.10),
]
losses = []
for high_cls, low_cls, margin in class_pairs:
high_mask = labels_valid == high_cls
low_mask = labels_valid == low_cls
if high_mask.sum() > 0 and low_mask.sum() > 0:
high_score = dab_scores[high_mask].mean()
low_score = dab_scores[low_mask].mean()
losses.append(F.relu(margin - (high_score - low_score)))
if losses:
return torch.stack(losses).mean()
return torch.tensor(0.0, device=self.device, requires_grad=True)
def compute_edge_loss(self, generated, target):
    """Fourier-magnitude loss on high spatial frequencies at 256x256.

    How it works:
      - Both images are resized to 256x256 and averaged to grayscale.
      - Their log-magnitude spectra are compared with L1 outside a central
        low-frequency disc of radius 25% of the centre-to-corner distance.
      - Magnitudes discard phase, so a translation leaves them unchanged
        (exactly for circular shifts, approximately otherwise). This makes
        the term robust to the misalignment in consecutive-cut pairs.
      - Blurry outputs show up as reduced high-frequency power.

    Why the shift matters: ``torch.fft.fft2`` places the zero frequency at
    index ``[0, 0]``. The spectrum is therefore ``fftshift``-ed so that DC
    lands at ``(H // 2, W // 2)``, where the radial mask is centred. Without
    the shift the mask discarded the highest frequencies (the centre of the
    unshifted array) and kept the lowest ones, including DC.
    """
    with torch.amp.autocast('cuda', enabled=False):
        gen = F.interpolate(generated.float(), size=256, mode='bilinear', align_corners=False)
        tgt = F.interpolate(target.float(), size=256, mode='bilinear', align_corners=False)
        # Grayscale
        gen_gray = gen.mean(dim=1, keepdim=True)
        tgt_gray = tgt.mean(dim=1, keepdim=True)
        # 2D FFT, DC moved to the centre, log-magnitude for stability
        gen_fft = torch.fft.fftshift(torch.fft.fft2(gen_gray), dim=(-2, -1))
        tgt_fft = torch.fft.fftshift(torch.fft.fft2(tgt_gray), dim=(-2, -1))
        gen_mag = torch.log1p(gen_fft.abs())
        tgt_mag = torch.log1p(tgt_fft.abs())
        # High-frequency mask: radius > 25% of the centre-to-corner distance
        H, W = gen_mag.shape[-2], gen_mag.shape[-1]
        cy, cx = H // 2, W // 2
        y = torch.arange(H, device=gen.device).float() - cy
        x = torch.arange(W, device=gen.device).float() - cx
        dist = (y[:, None] ** 2 + x[None, :] ** 2).sqrt()
        max_dist = (cy ** 2 + cx ** 2) ** 0.5
        hf_mask = (dist > 0.25 * max_dist).float()
        # L1 on high-frequency magnitudes
        return F.l1_loss(gen_mag * hf_mask, tgt_mag * hf_mask)
def compute_dab_sharpness_loss(self, generated, target):
    """Match how spatially concentrated the DAB staining is in output vs target.

    The result is the sum of two terms, both computed on DAB intensity maps:

    1. Gradient term: L1 between per-image mean Sobel gradient magnitudes.
    2. Patch-variance term: for each image, the variances of non-overlapping
       16x16 patches are sorted and compared with L1. This is the 1-D
       Wasserstein-1 distance between the two patch-variance distributions.
       Per-image values are averaged over the batch.

    Sorting makes term 2 independent of where patches sit, so no pixel
    alignment is required.
    """
    with torch.amp.autocast('cuda', enabled=False):
        maps = []
        for img in (generated, target):
            dab = self.dab_extractor.extract_dab_intensity(img.float(), normalize="none")
            # Normalize to [B, 1, H, W]
            maps.append(dab.unsqueeze(1) if dab.dim() == 3 else dab)
        dab_gen, dab_tgt = maps

        kernel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]],
                                dtype=torch.float32, device=generated.device).view(1, 1, 3, 3)
        kernel_y = kernel_x.transpose(-1, -2)

        def sobel_magnitude(m):
            dx = F.conv2d(m, kernel_x, padding=1)
            dy = F.conv2d(m, kernel_y, padding=1)
            return (dx**2 + dy**2 + 1e-8).sqrt()

        edges_gen = sobel_magnitude(dab_gen)
        edges_tgt = sobel_magnitude(dab_tgt)
        grad_term = F.l1_loss(edges_gen.mean(dim=[1, 2, 3]), edges_tgt.mean(dim=[1, 2, 3]))

        patch = 16

        def sorted_patch_variances(plane):
            rows, cols = plane.shape
            n_r, n_c = rows // patch, cols // patch
            tiles = (plane[:n_r*patch, :n_c*patch]
                     .reshape(n_r, patch, n_c, patch)
                     .permute(0, 2, 1, 3)
                     .reshape(-1, patch*patch))
            return tiles.var(dim=1).sort().values

        per_image = [
            F.l1_loss(sorted_patch_variances(dab_gen[i, 0]),
                      sorted_patch_variances(dab_tgt[i, 0]).detach())
            for i in range(dab_gen.shape[0])
        ]
        var_term = torch.stack(per_image).mean()
        return grad_term + var_term
def compute_he_edge_loss(self, generated, he_input):
    """Preserve H&E tissue structure via L1 between Sobel edge maps.

    How it works:
      - Both images are mapped with ``(x + 1) / 2`` (i.e. [-1, 1] to [0, 1])
        and averaged to grayscale.
      - Sobel gradient magnitudes are compared with L1 at full resolution and
        at half resolution (bilinear). The half size is derived from the
        last (width) dimension.
      - The two terms are averaged.
      - The H&E edge map is detached.
    """
    with torch.amp.autocast('cuda', enabled=False):
        kernel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]],
                                dtype=torch.float32, device=generated.device).view(1, 1, 3, 3)
        kernel_y = kernel_x.transpose(-1, -2)

        def sobel_magnitude(img):
            dx = F.conv2d(img, kernel_x, padding=1)
            dy = F.conv2d(img, kernel_y, padding=1)
            return (dx**2 + dy**2 + 1e-8).sqrt()

        gen_gray = ((generated.float() + 1) / 2).mean(dim=1, keepdim=True)
        he_gray = ((he_input.float() + 1) / 2).mean(dim=1, keepdim=True)
        full = gen_gray.shape[-1]

        total = 0.0
        for size in (full, full // 2):
            if size < full:
                g = F.interpolate(gen_gray, size=size, mode='bilinear', align_corners=False)
                h = F.interpolate(he_gray, size=size, mode='bilinear', align_corners=False)
            else:
                g, h = gen_gray, he_gray
            total = total + F.l1_loss(sobel_magnitude(g), sobel_magnitude(h).detach())
        return total / 2.0
def compute_background_loss(self, generated, he_input):
    """Push output pixels that are bright in the H&E input towards white.

    How it works:
      - A soft background mask ``sigmoid(20 * (brightness - bg_threshold))``
        is built from the channel-mean H&E brightness after ``(x + 1) / 2``.
      - The loss is the mask-weighted absolute distance of the output from
        1.0 (white), divided by the total mask weight times 3 (channels).
      - Returns 0 if that weight is not positive.
    """
    with torch.amp.autocast('cuda', enabled=False):
        out = generated.float()
        brightness = ((he_input.float() + 1) / 2).mean(dim=1, keepdim=True)
        weight = torch.sigmoid((brightness - self.hparams.bg_threshold) * 20.0)
        total_weight = weight.sum() * 3
        if not total_weight > 0:
            return torch.tensor(0.0, device=out.device, requires_grad=True)
        weighted_err = (out - torch.ones_like(out)).abs() * weight
        return weighted_err.sum() / total_weight
def compute_gram_style_loss(self, generated, target):
    """Texture loss: L1 between Gram matrices of VGG features at 256x256.

    Both images are resized to 256x256 (bilinear) and passed through
    ``self.vgg_extractor``. For each returned feature map, the L1 distance
    between ``gram_matrix`` of the generated and the detached target
    features is taken, and the result is averaged over feature maps.
    """
    with torch.amp.autocast('cuda', enabled=False):
        resized = [
            F.interpolate(img.float(), size=256, mode='bilinear', align_corners=False)
            for img in (generated, target)
        ]
        gen_feats = self.vgg_extractor(resized[0])
        tgt_feats = self.vgg_extractor(resized[1])
        per_layer = [
            F.l1_loss(gram_matrix(g_feat), gram_matrix(t_feat).detach())
            for g_feat, t_feat in zip(gen_feats, tgt_feats)
        ]
        return sum(per_layer, 0.0) / len(gen_feats)
def training_step(self, batch, batch_idx):
    """Run one manual-optimization GAN iteration.

    Order of operations:
      1. Resolve UNI features (precomputed, or extracted on the fly from
         sub-crops), then apply CFG dropout and the ablation switches.
      2. Generator update:
         - Reconstruction-style terms: LPIPS at ``image_size // 4`` plus
           optional ``// 2`` and full resolution; low-res L1; DAB intensity,
           contrast and sharpness; Fourier edge; Gram style; H&E edge;
           background; PatchNCE.
         - Once ``global_step >= adversarial_start_step``: adversarial terms
           from the conditional, crop and unconditional discriminators, plus
           feature matching.
         - The LR follows ``_get_lr_scale``, gradients are clipped to norm 1,
           and the EMA generator is updated after the step.
      3. Discriminator update (same step gate):
         - Hinge losses on real pairs vs. a fresh no-grad sample from the
           just-updated generator.
         - An R1 penalty every ``r1_every`` steps.

    At ``image_size == 1024`` the conditional and unconditional
    discriminators see 512x512 bilinear downsamples. The crop discriminator
    always sees full-resolution crops.

    Args:
        batch: ``(he, her2, uni_or_crops, labels, fnames)``.
        batch_idx: Unused.
    """
    he, her2, uni_or_crops, labels, fnames = batch
    opt_g, opt_d = self.optimizers()

    # On-the-fly UNI extraction: dataset returns [B, 16, 3, 224, 224] sub-crops
    if self._uni_extract_on_the_fly:
        uni = self._extract_uni_from_sub_crops(uni_or_crops)
    else:
        uni = uni_or_crops

    # Apply CFG dropout
    labels_dropped, uni_dropped = self._apply_cfg_dropout(labels, uni)
    # Ablation: zero out UNI features
    if self.hparams.disable_uni:
        uni_dropped = torch.zeros_like(uni_dropped)
    # Ablation: force all labels to null class
    if self.hparams.disable_class:
        labels_dropped = torch.full_like(labels_dropped, self.hparams.null_class)

    # ----------------------------------------------------------------
    # Generator step
    # ----------------------------------------------------------------
    generated = self.generator(he, uni_dropped, labels_dropped)

    # LPIPS main: 4x downsample (128 for 512 input, 256 for 1024)
    lpips_main_size = self.hparams.image_size // 4
    gen_lpips = F.interpolate(generated, size=lpips_main_size, mode='bilinear', align_corners=False)
    her2_lpips = F.interpolate(her2, size=lpips_main_size, mode='bilinear', align_corners=False)
    loss_lpips = self.lpips_fn(gen_lpips, her2_lpips).mean()
    loss_g = self.hparams.lpips_weight * loss_lpips

    # LPIPS fine: 2x downsample (256 for 512 input, 512 for 1024)
    if self.hparams.lpips_256_weight > 0:
        lpips_fine_size = self.hparams.image_size // 2
        gen_fine = F.interpolate(generated, size=lpips_fine_size, mode='bilinear', align_corners=False)
        her2_fine = F.interpolate(her2, size=lpips_fine_size, mode='bilinear', align_corners=False)
        loss_lpips_256 = self.lpips_fn(gen_fine, her2_fine).mean()
        loss_g = loss_g + self.hparams.lpips_256_weight * loss_lpips_256
        self.log('train/lpips_fine', loss_lpips_256, prog_bar=False)

    # LPIPS at full resolution (expensive)
    if self.hparams.lpips_512_weight > 0:
        loss_lpips_512 = self.lpips_fn(generated, her2).mean()
        loss_g = loss_g + self.hparams.lpips_512_weight * loss_lpips_512
        self.log('train/lpips_fullres', loss_lpips_512, prog_bar=False)

    # Low-resolution L1 (color fidelity, misalignment-robust at 64x64)
    if self.hparams.l1_lowres_weight > 0:
        gen_64 = F.interpolate(generated, size=64, mode='bilinear', align_corners=False)
        her2_64 = F.interpolate(her2, size=64, mode='bilinear', align_corners=False)
        loss_l1_lowres = F.l1_loss(gen_64, her2_64)
        loss_g = loss_g + self.hparams.l1_lowres_weight * loss_l1_lowres
        self.log('train/l1_lowres', loss_l1_lowres, prog_bar=False)

    # DAB losses (use original labels, not dropped)
    if self.hparams.dab_intensity_weight > 0:
        loss_dab = self.compute_dab_intensity_loss(generated, her2)
        loss_g = loss_g + self.hparams.dab_intensity_weight * loss_dab
        self.log('train/dab_intensity', loss_dab, prog_bar=False)
    if self.hparams.dab_contrast_weight > 0:
        loss_dab_contrast = self.compute_dab_contrast_loss(generated, labels)
        loss_g = loss_g + self.hparams.dab_contrast_weight * loss_dab_contrast
        self.log('train/dab_contrast', loss_dab_contrast, prog_bar=False)

    # Edge loss (boundary sharpness)
    if self.hparams.edge_weight > 0:
        loss_edge = self.compute_edge_loss(generated, her2)
        loss_g = loss_g + self.hparams.edge_weight * loss_edge
        self.log('train/edge_loss', loss_edge, prog_bar=False)

    # DAB sharpness loss (membrane-localized vs diffuse brown)
    if self.hparams.dab_sharpness_weight > 0:
        loss_dab_sharp = self.compute_dab_sharpness_loss(generated, her2)
        loss_g = loss_g + self.hparams.dab_sharpness_weight * loss_dab_sharp
        self.log('train/dab_sharpness', loss_dab_sharp, prog_bar=False)

    # Gram-matrix style loss
    if self.hparams.gram_style_weight > 0 and self.vgg_extractor is not None:
        loss_gram = self.compute_gram_style_loss(generated, her2)
        loss_g = loss_g + self.hparams.gram_style_weight * loss_gram
        self.log('train/gram_style', loss_gram, prog_bar=False)

    # H&E edge structure preservation (pixel-aligned)
    if self.hparams.he_edge_weight > 0:
        loss_he_edge = self.compute_he_edge_loss(generated, he)
        loss_g = loss_g + self.hparams.he_edge_weight * loss_he_edge
        self.log('train/he_edge', loss_he_edge, prog_bar=False)

    # Background white loss
    if self.hparams.bg_white_weight > 0:
        loss_bg = self.compute_background_loss(generated, he)
        loss_g = loss_g + self.hparams.bg_white_weight * loss_bg
        self.log('train/bg_white', loss_bg, prog_bar=False)

    # PatchNCE loss (contrastive: H&E input vs generated, never sees GT)
    if self.hparams.patchnce_weight > 0 and self.patchnce_loss is not None:
        feats_he = self.generator.encode(he)
        feats_gen = self.generator.encode(generated)
        loss_nce = self.patchnce_loss(feats_he, feats_gen)
        loss_g = loss_g + self.hparams.patchnce_weight * loss_nce
        self.log('train/patchnce', loss_nce, prog_bar=False)

    # Adversarial losses (after warmup)
    loss_adv = torch.tensor(0.0, device=self.device)
    loss_feat_match = torch.tensor(0.0, device=self.device)
    loss_crop_adv = torch.tensor(0.0, device=self.device)
    loss_uncond_adv = torch.tensor(0.0, device=self.device)
    any_adv = (self.hparams.adversarial_weight > 0 or
               self.hparams.uncond_disc_weight > 0 or
               self.hparams.crop_disc_weight > 0 or
               self.hparams.feat_match_weight > 0)
    img_sz = self.hparams.image_size
    # Pre-compute disc-resolution tensors (512 for 1024 input, identity for 512)
    if img_sz == 1024:
        he_for_disc = F.interpolate(he, size=512, mode='bilinear', align_corners=False)
        her2_for_disc = F.interpolate(her2, size=512, mode='bilinear', align_corners=False)
    else:
        he_for_disc = he
        her2_for_disc = her2

    if self.global_step >= self.hparams.adversarial_start_step and any_adv:
        if img_sz == 1024:
            gen_for_disc = F.interpolate(generated, size=512, mode='bilinear', align_corners=False)
        else:
            gen_for_disc = generated

        # Conditional discriminator (paired: generated+HE vs real_HER2+HE)
        if self.hparams.adversarial_weight > 0:
            fake_input = torch.cat([gen_for_disc, he_for_disc], dim=1)
            disc_outputs = self.discriminator(fake_input)
            loss_adv = sum(hinge_loss_g(out) for out in disc_outputs) / len(disc_outputs)
            loss_g = loss_g + self.hparams.adversarial_weight * loss_adv

        # Feature matching from unconditional disc
        if (self.hparams.feat_match_weight > 0 and
                self.uncond_discriminator is not None):
            _, fake_feats = self.uncond_discriminator(gen_for_disc, return_features=True)
            with torch.no_grad():
                _, real_feats = self.uncond_discriminator(her2_for_disc, return_features=True)
            loss_feat_match = feature_matching_loss(fake_feats, real_feats)
            loss_g = loss_g + self.hparams.feat_match_weight * loss_feat_match

        # Crop discriminator: random crops at full resolution
        if self.crop_discriminator is not None and self.hparams.crop_disc_weight > 0:
            fake_input_crop = torch.cat([generated, he], dim=1)
            cs = self.hparams.crop_size
            top, left = self._random_crop_offsets(img_sz, cs)
            fake_crop = fake_input_crop[:, :, top:top+cs, left:left+cs]
            loss_crop_adv = hinge_loss_g(self.crop_discriminator(fake_crop))
            loss_g = loss_g + self.hparams.crop_disc_weight * loss_crop_adv

        # Unconditional discriminator: HER2-only adversarial
        if self.uncond_discriminator is not None and self.hparams.uncond_disc_weight > 0:
            loss_uncond_adv = hinge_loss_g(self.uncond_discriminator(gen_for_disc))
            loss_g = loss_g + self.hparams.uncond_disc_weight * loss_uncond_adv

    # Generator backward + step
    lr_scale = self._get_lr_scale()
    for pg in opt_g.param_groups:
        pg['lr'] = self.hparams.gen_lr * lr_scale
    opt_g.zero_grad()
    self.manual_backward(loss_g)
    torch.nn.utils.clip_grad_norm_(self.generator.parameters(), 1.0)
    opt_g.step()
    # Update EMA
    self._update_ema()

    # ----------------------------------------------------------------
    # Discriminator step
    # ----------------------------------------------------------------
    loss_d = torch.tensor(0.0, device=self.device)
    loss_crop_d = torch.tensor(0.0, device=self.device)
    loss_uncond_d = torch.tensor(0.0, device=self.device)
    # NOTE(review): global_step is re-read after opt_g.step(); under Lightning
    # manual optimization it may already have advanced, so this gate can open
    # one step before the generator's adversarial gate - confirm intended.
    if self.global_step >= self.hparams.adversarial_start_step and any_adv:
        # Fresh sample from the just-updated generator; no graph needed.
        with torch.no_grad():
            fake_detached = self.generator(he, uni_dropped, labels_dropped)
        # For 1024, downsample for disc
        if img_sz == 1024:
            fake_det_disc = F.interpolate(fake_detached, size=512, mode='bilinear', align_corners=False)
        else:
            fake_det_disc = fake_detached

        # Conditional discriminator
        if self.hparams.adversarial_weight > 0:
            real_input = torch.cat([her2_for_disc, he_for_disc], dim=1)
            fake_input = torch.cat([fake_det_disc, he_for_disc], dim=1)
            disc_real = self.discriminator(real_input)
            disc_fake = self.discriminator(fake_input)
            loss_d = sum(
                hinge_loss_d(dr, df)
                for dr, df in zip(disc_real, disc_fake)
            ) / len(disc_real)

        # Crop discriminator (same window for real and fake)
        if self.crop_discriminator is not None and self.hparams.crop_disc_weight > 0:
            real_input_c = torch.cat([her2, he], dim=1)
            fake_input_c = torch.cat([fake_detached, he], dim=1)
            cs = self.hparams.crop_size
            top, left = self._random_crop_offsets(img_sz, cs)
            real_crop = real_input_c[:, :, top:top+cs, left:left+cs]
            fake_crop = fake_input_c[:, :, top:top+cs, left:left+cs]
            loss_crop_d = hinge_loss_d(
                self.crop_discriminator(real_crop),
                self.crop_discriminator(fake_crop),
            )
            loss_d = loss_d + self.hparams.crop_disc_weight * loss_crop_d

        # Unconditional discriminator
        if self.uncond_discriminator is not None and (
                self.hparams.uncond_disc_weight > 0 or self.hparams.feat_match_weight > 0):
            uncond_real_out = self.uncond_discriminator(her2_for_disc)
            uncond_fake_out = self.uncond_discriminator(fake_det_disc)
            loss_uncond_d = hinge_loss_d(uncond_real_out, uncond_fake_out)
            # Weight never below 1.0, so the disc keeps training even when it
            # is only used for feature matching (uncond_disc_weight == 0).
            loss_d = loss_d + max(self.hparams.uncond_disc_weight, 1.0) * loss_uncond_d

        # R1 gradient penalty (real inputs only, every r1_every steps)
        loss_r1 = torch.tensor(0.0, device=self.device)
        if self.global_step % self.hparams.r1_every == 0:
            with torch.amp.autocast('cuda', enabled=False):
                if self.hparams.adversarial_weight > 0:
                    real_input_r1 = torch.cat([her2_for_disc, he_for_disc], dim=1).float().detach().requires_grad_(True)
                    # Only the disc_512 branch of the multi-scale disc is penalized.
                    for disc in [self.discriminator.disc_512]:
                        d_real = disc(real_input_r1)
                        grad_real = torch.autograd.grad(
                            outputs=d_real.sum(), inputs=real_input_r1,
                            create_graph=True,
                        )[0]
                        loss_r1 = loss_r1 + self.hparams.r1_weight * grad_real.pow(2).mean()
                if self.uncond_discriminator is not None and (
                        self.hparams.uncond_disc_weight > 0 or self.hparams.feat_match_weight > 0):
                    her2_r1 = her2_for_disc.float().detach().requires_grad_(True)
                    d_real_uncond = self.uncond_discriminator(her2_r1)
                    grad_uncond = torch.autograd.grad(
                        outputs=d_real_uncond.sum(), inputs=her2_r1,
                        create_graph=True,
                    )[0]
                    loss_r1 = loss_r1 + self.hparams.r1_weight * grad_uncond.pow(2).mean()
            loss_d = loss_d + loss_r1
            self.log('train/r1_penalty', loss_r1, prog_bar=False)

        opt_d.zero_grad()
        self.manual_backward(loss_d)
        if self.hparams.adversarial_weight > 0:
            torch.nn.utils.clip_grad_norm_(self.discriminator.parameters(), 1.0)
        if self.crop_discriminator is not None:
            torch.nn.utils.clip_grad_norm_(self.crop_discriminator.parameters(), 1.0)
        if self.uncond_discriminator is not None:
            torch.nn.utils.clip_grad_norm_(self.uncond_discriminator.parameters(), 1.0)
        opt_d.step()

    # Logging
    self.log('train/loss_g', loss_g, prog_bar=True)
    self.log('train/loss_d', loss_d, prog_bar=True)
    self.log('train/lpips', loss_lpips, prog_bar=True)
    self.log('train/adversarial', loss_adv, prog_bar=False)
    self.log('train/lr_scale', lr_scale, prog_bar=False)
    if self.crop_discriminator is not None:
        self.log('train/crop_adv_g', loss_crop_adv, prog_bar=False)
        self.log('train/crop_adv_d', loss_crop_d, prog_bar=False)
    if self.uncond_discriminator is not None:
        self.log('train/uncond_adv_g', loss_uncond_adv, prog_bar=False)
        self.log('train/uncond_adv_d', loss_uncond_d, prog_bar=False)
    if self.hparams.feat_match_weight > 0:
        self.log('train/feat_match', loss_feat_match, prog_bar=False)

def _random_crop_offsets(self, img_sz, crop_size):
    """Sample a random ``(top, left)`` offset for a square crop.

    ``torch.randint``'s upper bound is exclusive, so ``img_sz - crop_size + 1``
    is needed for every valid offset to be reachable. This includes the last
    one, and the ``crop_size == img_sz`` case, which previously raised.

    Raises:
        ValueError: if the crop is larger than the image.
    """
    if crop_size > img_sz:
        raise ValueError(f"crop_size ({crop_size}) exceeds image_size ({img_sz})")
    high = img_sz - crop_size + 1
    top = torch.randint(0, high, (1,)).item()
    left = torch.randint(0, high, (1,)).item()
    return top, left
def on_validation_epoch_start(self):
    """Reset per-epoch validation state and pick the batch for the random grid.

    The random sample grid's batch index is drawn uniformly from ``[1, n - 1]``
    (batch 0 already gets the fixed grid), where ``n`` is the number of
    validation batches. ``trainer.num_val_batches`` is preferred because it
    counts batches and reflects ``limit_val_batches``. ``len(val_dataloaders)``
    is only a fallback: when it is a list of loaders it counts loaders, not
    batches.
    """
    n_val_batches = 1
    counts = getattr(self.trainer, 'num_val_batches', None)
    if isinstance(counts, (list, tuple)):
        # One entry per validation dataloader; validation_step uses a single one.
        counts = counts[0] if counts else None
    if isinstance(counts, (int, float)) and 1 <= counts < float('inf'):
        n_val_batches = int(counts)
    else:
        try:
            n_val_batches = max(1, len(self.trainer.val_dataloaders))
        except TypeError:
            # No len() available (e.g. None or an iterable-style loader):
            # keep n_val_batches = 1, i.e. effectively no random grid.
            pass
    # torch.randint's upper bound is exclusive; max(2, ...) keeps the range non-empty.
    self._random_val_batch_idx = torch.randint(1, max(2, n_val_batches), (1,)).item()
    # Per-label sample collectors (for multi-stain visual grids)
    self._val_per_label_samples = {}
def _log_sample_grid(self, he, her2_01, gen_01, key):
    """Log an ``H&E | Real | Gen`` grid (at most 4 rows) to wandb.

    Args:
        he: H&E batch; rescaled here with ``(x + 1) / 2`` and clamped to
            [0, 1] (so presumably supplied in [-1, 1]).
        her2_01: real IHC batch, used as-is (callers pass [0, 1] values).
        gen_01: generated IHC batch, used as-is (callers pass [0, 1] values).
        key: wandb log key for the grid.

    Returns early (logs nothing) when no logger is attached or the batch
    is empty.
    """
    n = min(4, len(he))
    # Skip grid construction entirely when there is nothing to log or nowhere
    # to log it; make_grid also raises on an empty list.
    if n == 0 or not self.logger:
        return
    # float32 on CPU: wandb.Image cannot convert bf16/fp16 tensors.
    he_01 = ((he[:n].detach().cpu().float() + 1) / 2).clamp(0, 1)
    grid_images = []
    for i in range(n):
        grid_images.extend([
            he_01[i],
            her2_01[i].detach().cpu().float(),
            gen_01[i].detach().cpu().float(),
        ])
    grid = torchvision.utils.make_grid(grid_images, nrow=3, padding=2)
    self.logger.experiment.log({
        key: [wandb.Image(grid, caption='H&E | Real | Gen')],
        'global_step': self.global_step,
    })
def validation_step(self, batch, batch_idx):
    """Evaluate the EMA generator on one validation batch.

    Logs ``val/lpips`` (on 4x-downsampled images), ``val/ssim`` and
    ``val/dab_mae`` (batch-mean absolute difference between the per-image
    means of the top-10% DAB intensities of generated and real images).
    Also buffers up to four samples per non-null label for the per-label
    grids, and logs the fixed (batch 0) and randomly chosen sample grids.
    """
    he, her2, uni_or_crops, labels, fnames = batch

    # UNI features are either precomputed or extracted here from sub-crops.
    if not self._uni_extract_on_the_fly:
        uni = uni_or_crops
    else:
        uni = self._extract_uni_from_sub_crops(uni_or_crops)
    if self.hparams.disable_uni:
        uni = torch.zeros_like(uni)
    if self.hparams.disable_class:
        labels = torch.full_like(labels, self.hparams.null_class)

    # Evaluation always uses the EMA weights.
    with torch.no_grad():
        generated = self.generator_ema(he, uni, labels)

    # LPIPS at a quarter of the configured resolution (128 for 512, 256 for 1024).
    side = self.hparams.image_size // 4

    def _shrink(images):
        return F.interpolate(images, size=side, mode='bilinear', align_corners=False)

    lpips_val = self.lpips_fn(_shrink(generated), _shrink(her2)).mean()

    # SSIM on images mapped from [-1, 1] to [0, 1].
    her2_01 = ((her2 + 1) / 2).clamp(0, 1)
    gen_01 = ((generated + 1) / 2).clamp(0, 1)
    from torchmetrics.functional.image import structural_similarity_index_measure
    ssim_val = structural_similarity_index_measure(gen_01, her2_01, data_range=1.0)

    # DAB MAE (canonical: mean of the top-10% intensities of each image).
    extract = self.dab_extractor.extract_dab_intensity
    dab_gen = extract(generated.float().cpu(), normalize="none")
    dab_real = extract(her2.float().cpu(), normalize="none")

    def _top_decile_mean(dab):
        values = dab.flatten()
        selected = values >= torch.quantile(values, 0.9)
        if selected.sum() > 0:
            return values[selected].mean().item()
        return values.mean().item()

    n_images = len(dab_gen)
    per_image_err = [
        abs(_top_decile_mean(dab_gen[k]) - _top_decile_mean(dab_real[k]))
        for k in range(n_images)
    ]
    dab_mae = sum(per_image_err) / n_images

    for metric_name, metric_value in (
        ('val/lpips', lpips_val),
        ('val/ssim', ssim_val),
        ('val/dab_mae', dab_mae),
    ):
        self.log(metric_name, metric_value, prog_bar=True, sync_dist=True)

    # Buffer a few samples per (non-null) label for the multi-stain grids.
    if hasattr(self, '_val_per_label_samples'):
        for k in range(len(labels)):
            lbl = labels[k].item()
            if lbl == self.hparams.null_class:
                continue
            bucket = self._val_per_label_samples.setdefault(
                lbl, {'he': [], 'real': [], 'gen': []})
            if len(bucket['he']) >= 4:
                continue
            bucket['he'].append(he[k].cpu())
            bucket['real'].append(her2_01[k].cpu())
            bucket['gen'].append(gen_01[k].cpu())

    # Sample grids: batch 0 every epoch, plus the batch chosen at epoch start.
    if batch_idx == 0:
        grid_key = 'val/samples_fixed'
    elif batch_idx == self._random_val_batch_idx:
        grid_key = 'val/samples_random'
    else:
        return
    self._log_sample_grid(he, her2_01, gen_01, grid_key)
def on_validation_epoch_end(self):
    """Log one sample grid per label when more than one label was collected.

    Grid keys are ``val/samples_<name>``. ``<name>`` is taken from
    ``hparams.label_names`` when that list covers the label id; otherwise it
    is the id itself. The per-label buffer is cleared on every path, so its
    CPU tensors are not held until the next validation epoch.
    """
    buckets = getattr(self, '_val_per_label_samples', None)
    if not buckets:
        return
    # Release the buffered tensors regardless of whether anything is logged.
    self._val_per_label_samples = {}
    if len(buckets) <= 1 or not self.logger:
        return
    label_names = getattr(self.hparams, 'label_names', None)
    for lbl, bucket in sorted(buckets.items()):
        if not bucket['he']:
            continue
        name = label_names[lbl] if label_names and lbl < len(label_names) else str(lbl)
        self._log_sample_grid(
            torch.stack(bucket['he']),
            torch.stack(bucket['real']),
            torch.stack(bucket['gen']),
            f'val/samples_{name}',
        )
@torch.no_grad()
def generate(self, he_images, uni_features, labels,
             num_inference_steps=None, guidance_scale=1.0, seed=None):
    """Generate IHC images from H&E input.

    Args:
        he_images: [B, 3, H, H] where H=512 or H=1024
        uni_features: [B, N, 1024] where N=16 (4x4 CLS) or N=1024 (32x32 patch)
        labels: [B] class/stain labels
        num_inference_steps: ignored (single forward pass)
        guidance_scale: CFG scale; values <= 1.0 mean no guidance
        seed: random seed (for reproducibility, though model is deterministic)

    Returns:
        Generated images. With guidance the result is clamped to [-1, 1];
        without guidance the generator output is returned unchanged.

    The ``disable_uni`` / ``disable_class`` hparams are applied the same way
    ``validation_step`` applies them, so inference uses the same conditioning
    as validation.
    """
    if seed is not None:
        torch.manual_seed(seed)
    gen = self.generator_ema if hasattr(self, 'generator_ema') else self.generator
    hparams = self.hparams
    if getattr(hparams, 'disable_uni', False):
        uni_features = torch.zeros_like(uni_features)
    if getattr(hparams, 'disable_class', False):
        labels = torch.full_like(labels, hparams.null_class)
    if guidance_scale <= 1.0:
        return gen(he_images, uni_features, labels)
    # Classifier-free guidance: extrapolate from the unconditional (null-class)
    # prediction toward the conditional one. The null class id lives in
    # hparams, as everywhere else in this module.
    null_labels = torch.full_like(labels, hparams.null_class)
    output_cond = gen(he_images, uni_features, labels)
    output_uncond = gen(he_images, uni_features, null_labels)
    output = output_uncond + guidance_scale * (output_cond - output_uncond)
    return output.clamp(-1, 1)
|