"""
UNI feature processors: transform UNI pathology features into multi-scale spatial maps.
- UNIFeatureProcessor: for CLS-token features (4x4 = 16 tokens)
- UNIFeatureProcessorHighRes: for patch-token features (32x32 = 1024 tokens)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class UNIFeatureProcessor(nn.Module):
    """Expand UNI token features [B, 16, uni_dim] into multi-scale spatial maps.

    The 16 tokens are treated as a 4x4 grid. Each token is linearly projected
    to ``base_channels`` and the resulting [B, base_channels, 4, 4] map is
    repeatedly upsampled with stride-2 transposed convolutions so that one
    feature map is available per decoder resolution (16, 32, 64, 128, 256).

    Args:
        uni_dim: Size of each incoming UNI token vector.
        base_channels: Channel width after projection. The 16 and 32 maps keep
            this width; the 64, 128 and 256 maps use 1/2, 1/4 and 1/8 of it.
    """

    def __init__(self, uni_dim=1024, base_channels=512):
        super().__init__()
        self.base_channels = base_channels

        def activation():
            # In-place is safe here: every activation directly follows the
            # layer that produced its input, and that tensor is not reused.
            return nn.LeakyReLU(0.2, inplace=True)

        def doubling_layers(in_ch, out_ch):
            # kernel 4 / stride 2 / padding 1 exactly doubles height and width.
            return [
                nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1),
                activation(),
            ]

        # Channel widths per output resolution (defaults: 512, 256, 128, 64).
        wide = base_channels
        width_64 = base_channels // 2
        width_128 = base_channels // 4
        width_256 = base_channels // 8

        # Per-token projection from the UNI embedding to the generator width.
        self.proj = nn.Sequential(nn.Linear(uni_dim, wide), activation())

        # Upsampling chain: 4x4 -> {16, 32, 64, 128, 256}.
        # 4 -> 8 -> 16 (two doublings in a single stage)
        self.up_16 = nn.Sequential(
            *doubling_layers(wide, wide),
            *doubling_layers(wide, wide),
        )
        # 16 -> 32
        self.up_32 = nn.Sequential(*doubling_layers(wide, wide))
        # 32 -> 64
        self.up_64 = nn.Sequential(*doubling_layers(wide, width_64))
        # 64 -> 128
        self.up_128 = nn.Sequential(*doubling_layers(width_64, width_128))
        # 128 -> 256
        self.up_256 = nn.Sequential(*doubling_layers(width_128, width_256))

    def forward(self, uni_features):
        """Compute the feature pyramid for a batch of UNI tokens.

        Args:
            uni_features: Tensor of shape [B, 16, uni_dim].

        Returns:
            Dict mapping resolution (16, 32, 64, 128, 256) to a feature map of
            shape [B, C, res, res]; C is ``base_channels`` for 16 and 32 and
            halves at each later level.
        """
        batch = uni_features.shape[0]

        # [B, 16, uni_dim] -> [B, 16, C] -> [B, C, 16] -> [B, C, 4, 4]
        tokens = self.proj(uni_features)
        grid = tokens.permute(0, 2, 1).reshape(batch, self.base_channels, 4, 4)

        # Each stage consumes the output of the previous one.
        map_16 = self.up_16(grid)        # [B, C,    16,  16]
        map_32 = self.up_32(map_16)      # [B, C,    32,  32]
        map_64 = self.up_64(map_32)      # [B, C/2,  64,  64]
        map_128 = self.up_128(map_64)    # [B, C/4, 128, 128]
        map_256 = self.up_256(map_128)   # [B, C/8, 256, 256]

        return {16: map_16, 32: map_32, 64: map_64, 128: map_128, 256: map_256}
class UNIFeatureProcessorHighRes(nn.Module):
    """Turn high-res UNI patch tokens [B, S*S, uni_dim] into multi-scale maps.

    With patch-token extraction, UNI yields 1024 tokens (a 32x32 grid) of
    1024-dim features -- 64x more spatial positions than the 4x4 grid. Because
    the input already starts at 32x32, it is refined with plain convolutions
    at that resolution and only upsampled for the finer levels.

    Architecture (with the default ``spatial_size=32``):
        32x32 input -> conv refine (+ residual) -> feat_32  (C channels)
        32 -> 16  strided conv                  -> feat_16  (C, bottleneck)
        32 -> 64  upsample + conv               -> feat_64  (C/2)
        64 -> 128 upsample + conv               -> feat_128 (C/4)
        128 -> 256 upsample + conv              -> feat_256 (C/8)
        256 -> 512 upsample + conv (optional)   -> feat_512 (C/16)

    Args:
        uni_dim: Size of each incoming UNI token vector.
        base_channels: Channel width C after projection.
        spatial_size: Side length S of the token grid; forward() expects
            S*S tokens per sample.
        output_512: When True, an extra upsampling stage is built and its
            output is returned under key 512.
    """

    def __init__(self, uni_dim=1024, base_channels=512, spatial_size=32,
                 output_512=False):
        super().__init__()
        self.base_channels = base_channels
        self.spatial_size = spatial_size
        self.output_512 = output_512

        def activation():
            return nn.LeakyReLU(0.2, inplace=True)

        def refine_layers(channels):
            # Resolution-preserving 3x3 conv + instance norm + activation.
            return [
                nn.Conv2d(channels, channels, kernel_size=3, padding=1),
                nn.InstanceNorm2d(channels),
                activation(),
            ]

        def up_stage(in_ch, out_ch):
            # Stride-2 transposed conv doubles H and W, then one refine pass.
            return nn.Sequential(
                nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1),
                nn.InstanceNorm2d(out_ch),
                activation(),
                *refine_layers(out_ch),
            )

        width = base_channels

        # Per-token projection from the UNI embedding to the generator width.
        self.proj = nn.Sequential(nn.Linear(uni_dim, width), activation())

        # Two refine passes at the native grid resolution.
        self.proc_32 = nn.Sequential(*refine_layers(width), *refine_layers(width))

        # 32 -> 16 strided conv (bottleneck conditioning).
        self.down_16 = nn.Sequential(
            nn.Conv2d(width, width, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(width),
            activation(),
        )

        # Upsampling chain; channel width halves at every level.
        width_64 = width // 2     # 256 with defaults
        width_128 = width // 4    # 128 with defaults
        width_256 = width // 8    # 64 with defaults
        self.up_64 = up_stage(width, width_64)          # 32 -> 64
        self.up_128 = up_stage(width_64, width_128)     # 64 -> 128
        self.up_256 = up_stage(width_128, width_256)    # 128 -> 256

        # 256 -> 512 (for 1024 models with SPADE at dec1); only built on demand.
        if output_512:
            width_512 = width // 16   # 32 with defaults
            self.up_512 = up_stage(width_256, width_512)

    def forward(self, uni_features):
        """Compute the feature pyramid for a batch of UNI patch tokens.

        Args:
            uni_features: Tensor of shape [B, S*S, uni_dim] where
                S = ``spatial_size`` (default 32).

        Returns:
            Dict of spatial feature maps keyed 16, 32, 64, 128, 256, plus 512
            when ``output_512`` is set. The keys are fixed labels; they equal
            the side length of each map only when ``spatial_size`` is 32.
        """
        batch = uni_features.shape[0]
        side = self.spatial_size

        # [B, S*S, uni_dim] -> [B, S*S, C] -> [B, C, S*S] -> [B, C, S, S]
        projected = self.proj(uni_features)
        grid = projected.permute(0, 2, 1).reshape(
            batch, self.base_channels, side, side
        )

        # Refine at the native resolution, keeping a residual path to the
        # unrefined projection.
        native = self.proc_32(grid) + grid

        # Coarser map for the bottleneck.
        coarse = self.down_16(native)        # [B, C, S/2, S/2]

        # Progressively finer maps, each derived from the previous level.
        fine_64 = self.up_64(native)         # [B, C/2, 2S, 2S]
        fine_128 = self.up_128(fine_64)      # [B, C/4, 4S, 4S]
        fine_256 = self.up_256(fine_128)     # [B, C/8, 8S, 8S]

        pyramid = {
            16: coarse,
            32: native,
            64: fine_64,
            128: fine_128,
            256: fine_256,
        }
        if self.output_512:
            pyramid[512] = self.up_512(fine_256)   # [B, C/16, 16S, 16S]
        return pyramid
|