File size: 3,875 Bytes
4db9215
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
"""
Building blocks for UNIStainNet generator.

- SPADEBlock: SPADE + FiLM normalization (UNI spatial + class channel modulation)
- ResBlock: Residual block with InstanceNorm
- SelfAttention: Self-attention for global context at bottleneck
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


class SPADEBlock(nn.Module):
    """SPADE + FiLM normalization block.

    Combines spatially-adaptive normalization driven by UNI features
    (SPADE-style: per-pixel gamma/beta maps) with channel-wise affine
    modulation driven by a class embedding (FiLM-style: per-channel
    gamma/beta scalars).

    Output is ``(gamma_spade + gamma_film) * InstanceNorm(x)
    + (beta_spade + beta_film)``. At initialization the SPADE branch
    contributes exactly 0 and the FiLM gamma is exactly 1, so the block
    starts out as a plain (unmodulated) InstanceNorm and the conditioning
    signals fade in gradually during training (ControlNet-style).
    """

    def __init__(self, norm_channels, uni_channels, class_dim=64):
        """
        Args:
            norm_channels: channels of the feature map being normalized.
            uni_channels: channels of the UNI spatial feature input.
            class_dim: dimensionality of the class embedding.
        """
        super().__init__()
        self.norm = nn.InstanceNorm2d(norm_channels, affine=False)

        # SPADE: learn spatial gamma/beta maps from UNI features.
        hidden = min(128, norm_channels)
        self.spade_shared = nn.Sequential(
            nn.Conv2d(uni_channels, hidden, 3, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.spade_gamma = nn.Conv2d(hidden, norm_channels, 3, padding=1)
        self.spade_beta = nn.Conv2d(hidden, norm_channels, 3, padding=1)

        # FiLM: learn channel-wise gamma/beta from the class embedding.
        self.film_gamma = nn.Linear(class_dim, norm_channels)
        self.film_beta = nn.Linear(class_dim, norm_channels)

        # Init SPADE gamma/beta to exactly zero (ControlNet-style gradual
        # activation): the spatial modulation starts silent.
        nn.init.zeros_(self.spade_gamma.weight)
        nn.init.zeros_(self.spade_gamma.bias)
        nn.init.zeros_(self.spade_beta.weight)
        nn.init.zeros_(self.spade_beta.bias)

        # Init FiLM gamma to exactly 1 and beta to exactly 0, independent of
        # the class embedding. (Zero weight + ones bias gives gamma == 1;
        # a ones WEIGHT would instead make gamma = sum(class_emb), which is
        # arbitrary and defeats the identity-at-init design.)
        nn.init.zeros_(self.film_gamma.weight)
        nn.init.ones_(self.film_gamma.bias)
        nn.init.zeros_(self.film_beta.weight)
        nn.init.zeros_(self.film_beta.bias)

    def forward(self, x, uni_spatial, class_emb):
        """
        Args:
            x: [B, C, H, W] feature map
            uni_spatial: [B, uni_ch, H, W] UNI features at matching resolution
            class_emb: [B, class_dim] class embedding

        Returns:
            [B, C, H, W] modulated feature map.
        """
        normalized = self.norm(x)

        # SPADE modulation from UNI features (spatial gamma/beta maps).
        shared = self.spade_shared(uni_spatial)
        gamma_s = self.spade_gamma(shared)
        beta_s = self.spade_beta(shared)

        # FiLM modulation from class (per-channel scalars, broadcast over H, W).
        gamma_c = self.film_gamma(class_emb).unsqueeze(-1).unsqueeze(-1)  # [B, C, 1, 1]
        beta_c = self.film_beta(class_emb).unsqueeze(-1).unsqueeze(-1)

        # Combined: (gamma_spade + gamma_film) * norm(x) + (beta_spade + beta_film)
        return (gamma_s + gamma_c) * normalized + (beta_s + beta_c)


class ResBlock(nn.Module):
    """Residual block: two 3x3 convolutions with InstanceNorm.

    Computes ``LeakyReLU(x + F(x))`` where F is the conv branch; the
    activation is applied after the skip addition.
    """

    def __init__(self, channels):
        super().__init__()
        # Conv branch: conv -> norm -> act -> conv -> norm.
        layers = [
            nn.Conv2d(channels, channels, 3, padding=1),
            nn.InstanceNorm2d(channels),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(channels, channels, 3, padding=1),
            nn.InstanceNorm2d(channels),
        ]
        self.block = nn.Sequential(*layers)
        self.act = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x):
        residual = x
        out = self.block(x)
        return self.act(out + residual)


class SelfAttention(nn.Module):
    """Single-head self-attention over spatial positions.

    Flattens the H*W positions into a sequence, performs scaled
    dot-product attention with key dim = ``channels``, and adds the
    projected result back to the input (pre-norm residual form).
    """

    def __init__(self, channels):
        super().__init__()
        # NOTE(review): GroupNorm(32, ...) requires channels % 32 == 0 —
        # assumed to hold at the bottleneck where this layer is used.
        self.norm = nn.GroupNorm(32, channels)
        self.qkv = nn.Conv2d(channels, channels * 3, 1)  # fused q/k/v projection
        self.proj = nn.Conv2d(channels, channels, 1)
        self.scale = channels ** -0.5  # 1/sqrt(d_k), d_k == channels

    def forward(self, x):
        B, C, H, W = x.shape
        n = H * W
        normed = self.norm(x)
        # Split the fused projection into q, k, v, each [B, C, n].
        q, k, v = self.qkv(normed).reshape(B, 3, C, n).unbind(dim=1)
        # weights[b, i, j]: attention paid by position i to position j.
        weights = torch.softmax(torch.bmm(q.transpose(1, 2), k) * self.scale, dim=-1)
        # out[b, c, i] = sum_j v[b, c, j] * weights[b, i, j]
        out = torch.bmm(v, weights.transpose(1, 2)).reshape(B, C, H, W)
        return x + self.proj(out)