"""
Building blocks for UNIStainNet generator.
- SPADEBlock: SPADE + FiLM normalization (UNI spatial + class channel modulation)
- ResBlock: Residual block with InstanceNorm
- SelfAttention: Self-attention for global context at bottleneck
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class SPADEBlock(nn.Module):
    """SPADE + FiLM normalization block.

    Combines spatially-adaptive normalization from UNI features (SPADE)
    with channel-wise affine modulation from a class embedding (FiLM).
    Initialized so the block is an identity on top of instance norm:
    total gamma == 1 and total beta == 0 until training moves the weights.
    """

    def __init__(self, norm_channels, uni_channels, class_dim=64):
        super().__init__()
        # Parameter-free normalization; the learned SPADE/FiLM terms below
        # supply the affine transform instead.
        self.norm = nn.InstanceNorm2d(norm_channels, affine=False)
        # SPADE: learn spatial gamma/beta maps from UNI features.
        hidden = min(128, norm_channels)
        self.spade_shared = nn.Sequential(
            nn.Conv2d(uni_channels, hidden, 3, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.spade_gamma = nn.Conv2d(hidden, norm_channels, 3, padding=1)
        self.spade_beta = nn.Conv2d(hidden, norm_channels, 3, padding=1)
        # FiLM: learn per-channel gamma/beta from the class embedding.
        self.film_gamma = nn.Linear(class_dim, norm_channels)
        self.film_beta = nn.Linear(class_dim, norm_channels)
        # Init SPADE gamma/beta to exactly zero (ControlNet-style gradual
        # activation): the UNI branch contributes nothing at step 0.
        nn.init.zeros_(self.spade_gamma.weight)
        nn.init.zeros_(self.spade_gamma.bias)
        nn.init.zeros_(self.spade_beta.weight)
        nn.init.zeros_(self.spade_beta.bias)
        # Init FiLM so gamma == 1 and beta == 0 for ANY class embedding:
        # zero the weight matrices and set gamma's bias to 1.
        # (Setting the *weights* to ones — as a naive "gamma near 1" init —
        # would yield gamma = sum(class_emb), not 1, breaking identity-at-init.)
        nn.init.zeros_(self.film_gamma.weight)
        nn.init.ones_(self.film_gamma.bias)
        nn.init.zeros_(self.film_beta.weight)
        nn.init.zeros_(self.film_beta.bias)

    def forward(self, x, uni_spatial, class_emb):
        """Apply combined SPADE + FiLM modulation.

        Args:
            x: [B, C, H, W] feature map.
            uni_spatial: [B, uni_ch, H, W] UNI features at matching resolution.
            class_emb: [B, class_dim] class embedding.

        Returns:
            [B, C, H, W] modulated feature map. At initialization this
            reduces to plain instance normalization of ``x``.
        """
        normalized = self.norm(x)
        # SPADE modulation from UNI features (zero at init).
        shared = self.spade_shared(uni_spatial)
        gamma_s = self.spade_gamma(shared)
        beta_s = self.spade_beta(shared)
        # FiLM modulation from class embedding, broadcast over H and W.
        gamma_c = self.film_gamma(class_emb).unsqueeze(-1).unsqueeze(-1)  # [B, C, 1, 1]
        beta_c = self.film_beta(class_emb).unsqueeze(-1).unsqueeze(-1)
        # Combined: (gamma_spade + gamma_film) * norm(x) + (beta_spade + beta_film)
        return (gamma_s + gamma_c) * normalized + (beta_s + beta_c)
class ResBlock(nn.Module):
    """Residual block: two 3x3 convs with InstanceNorm, LeakyReLU on the sum."""

    def __init__(self, channels):
        super().__init__()
        # Conv -> IN -> LeakyReLU -> Conv -> IN, applied to the residual path.
        layers = [
            nn.Conv2d(channels, channels, 3, padding=1),
            nn.InstanceNorm2d(channels),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(channels, channels, 3, padding=1),
            nn.InstanceNorm2d(channels),
        ]
        self.block = nn.Sequential(*layers)
        self.act = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x):
        residual = self.block(x)
        # Skip connection followed by the output activation.
        return self.act(x + residual)
class SelfAttention(nn.Module):
    """Self-attention layer for global context at the bottleneck."""

    def __init__(self, channels):
        super().__init__()
        # NOTE(review): GroupNorm(32, ...) assumes channels is divisible
        # by 32 — confirm against the generator's bottleneck width.
        self.norm = nn.GroupNorm(32, channels)
        # Single 1x1 conv produces q, k, v stacked along the channel dim.
        self.qkv = nn.Conv2d(channels, channels * 3, 1)
        self.proj = nn.Conv2d(channels, channels, 1)
        self.scale = channels ** -0.5

    def forward(self, x):
        b, c, height, width = x.shape
        tokens = height * width
        q, k, v = self.qkv(self.norm(x)).chunk(3, dim=1)
        q = q.reshape(b, c, tokens)
        k = k.reshape(b, c, tokens)
        v = v.reshape(b, c, tokens)
        # attn[i, j]: how much query position i attends to key position j.
        attn = torch.bmm(q.transpose(1, 2), k) * self.scale
        attn = attn.softmax(dim=-1)
        # Attention-weighted sum of values, restored to spatial layout.
        out = torch.bmm(v, attn.transpose(1, 2)).reshape(b, c, height, width)
        # Residual connection around the attention branch.
        return x + self.proj(out)