Spaces:
Running
Running
"""
UNI feature processors: transform UNI pathology features into multi-scale spatial maps.
- UNIFeatureProcessor: for CLS-token features (4x4 = 16 tokens)
- UNIFeatureProcessorHighRes: for patch-token features (32x32 = 1024 tokens)
"""
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
class UNIFeatureProcessor(nn.Module):
    """Turn UNI token features [B, 16, uni_dim] into multi-scale spatial maps.

    The 16 tokens are projected to ``base_channels`` and laid out as a 4x4
    grid (tokens taken in row-major order), then upsampled with stride-2
    transposed convolutions to 16, 32, 64, 128 and 256 pixels per side.
    Channel width stays at ``base_channels`` up to 32x32 and is halved at
    each later stage.

    Args:
        uni_dim: Feature size of each incoming token.
        base_channels: Channel width after projection. Must be at least 8 so
            that the 256x256 map (``base_channels // 8`` channels) is not
            empty.
    """

    # Side length of the square grid the input tokens are reshaped into.
    _GRID_SIZE = 4

    def __init__(self, uni_dim=1024, base_channels=512):
        super().__init__()
        self.base_channels = base_channels

        # Project UNI features to the generator channel dim (per token).
        self.proj = nn.Sequential(
            nn.Linear(uni_dim, base_channels),
            nn.LeakyReLU(0.2, inplace=True),
        )

        # Multi-scale upsamplers: 4x4 -> {16, 32, 64, 128, 256}.
        # Each stage doubles the spatial resolution. Attribute names and
        # layer positions inside each Sequential are part of the state_dict
        # layout, so they are kept stable.
        ch = base_channels
        ch_64 = base_channels // 2
        ch_128 = base_channels // 4
        ch_256 = base_channels // 8

        # 4 -> 8 -> 16 (two stages in one module)
        self.up_16 = nn.Sequential(
            *self._upsample_layers(ch, ch),
            *self._upsample_layers(ch, ch),
        )
        # 16 -> 32
        self.up_32 = nn.Sequential(*self._upsample_layers(ch, ch))
        # 32 -> 64
        self.up_64 = nn.Sequential(*self._upsample_layers(ch, ch_64))
        # 64 -> 128
        self.up_128 = nn.Sequential(*self._upsample_layers(ch_64, ch_128))
        # 128 -> 256
        self.up_256 = nn.Sequential(*self._upsample_layers(ch_128, ch_256))

    @staticmethod
    def _upsample_layers(in_ch, out_ch):
        """Return the layers of one 2x upsampling stage."""
        return [
            # kernel 4 / stride 2 / padding 1 exactly doubles height and width.
            nn.ConvTranspose2d(in_ch, out_ch, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        ]

    def forward(self, uni_features):
        """Compute the feature pyramid for a batch of UNI tokens.

        Args:
            uni_features: Tensor of shape [B, 16, uni_dim].

        Returns:
            dict mapping spatial size (16, 32, 64, 128, 256) to a tensor of
            shape [B, C, size, size], where C is ``base_channels`` for 16 and
            32, then ``base_channels // 2``, ``// 4`` and ``// 8``.

        Raises:
            RuntimeError: If the input is not 3-D or does not carry exactly
                16 tokens. RuntimeError is used because it is the type torch
                itself raises for shape mismatches.
        """
        grid = self._GRID_SIZE
        expected_tokens = grid * grid
        if uni_features.dim() != 3 or uni_features.shape[1] != expected_tokens:
            raise RuntimeError(
                f"UNIFeatureProcessor expects input of shape "
                f"[B, {expected_tokens}, uni_dim], "
                f"got {tuple(uni_features.shape)}"
            )

        batch = uni_features.shape[0]

        # Project, then move channels first and fold the tokens into a grid.
        x = self.proj(uni_features)  # [B, 16, C]
        x = x.permute(0, 2, 1).reshape(batch, self.base_channels, grid, grid)

        # Each map is computed from the previous (coarser) one.
        feat_16 = self.up_16(x)  # [B, C, 16, 16]
        feat_32 = self.up_32(feat_16)  # [B, C, 32, 32]
        feat_64 = self.up_64(feat_32)  # [B, C/2, 64, 64]
        feat_128 = self.up_128(feat_64)  # [B, C/4, 128, 128]
        feat_256 = self.up_256(feat_128)  # [B, C/8, 256, 256]

        return {
            16: feat_16,
            32: feat_32,
            64: feat_64,
            128: feat_128,
            256: feat_256,
        }
class UNIFeatureProcessorHighRes(nn.Module):
    """Turn UNI patch-token features [B, S*S, uni_dim] into multi-scale maps.

    The S*S tokens (S = ``spatial_size``, 32 by default, i.e. 1024 tokens) are
    projected to ``base_channels`` and laid out as an S x S grid (tokens taken
    in row-major order). The grid is refined with convolutions at its native
    size, downsampled once, and upsampled three (optionally four) times.

    Architecture for the default S = 32:
        32x32 input -> conv refine (+ residual) -> feat_32  (C channels)
        32 -> 16 strided conv                   -> feat_16  (C, bottleneck)
        32 -> 64 upsample + conv                -> feat_64  (C // 2)
        64 -> 128 upsample + conv               -> feat_128 (C // 4)
        128 -> 256 upsample + conv              -> feat_256 (C // 8)
        256 -> 512 upsample + conv (optional)   -> feat_512 (C // 16)

    Args:
        uni_dim: Feature size of each incoming token.
        base_channels: Channel width C after projection. The narrowest map
            uses ``C // 8`` channels (``C // 16`` with ``output_512``), so C
            must be large enough for that to be non-zero.
        spatial_size: Side length S of the token grid.
        output_512: If True, also build and return the extra 2x stage stored
            under key 512.
    """

    def __init__(self, uni_dim=1024, base_channels=512, spatial_size=32,
                 output_512=False):
        super().__init__()
        self.base_channels = base_channels
        self.spatial_size = spatial_size
        self.output_512 = output_512

        ch = base_channels
        ch_64 = ch // 2
        ch_128 = ch // 4
        ch_256 = ch // 8

        # Project each UNI token from uni_dim to ch features.
        self.proj = nn.Sequential(
            nn.Linear(uni_dim, ch),
            nn.LeakyReLU(0.2, inplace=True),
        )

        # Refine the projected features at native resolution. Attribute names
        # and layer positions inside each Sequential are part of the
        # state_dict layout, so they are kept stable.
        self.proc_32 = nn.Sequential(
            *self._conv_norm_act(nn.Conv2d(ch, ch, 3, padding=1)),
            *self._conv_norm_act(nn.Conv2d(ch, ch, 3, padding=1)),
        )

        # Halve the resolution (for bottleneck conditioning).
        self.down_16 = nn.Sequential(
            *self._conv_norm_act(nn.Conv2d(ch, ch, 4, stride=2, padding=1)),
        )

        # Upsample + refine stages; each doubles the resolution.
        self.up_64 = self._up_block(ch, ch_64)
        self.up_128 = self._up_block(ch_64, ch_128)
        self.up_256 = self._up_block(ch_128, ch_256)

        # Extra stage (for 1024 models with SPADE at dec1); only created on
        # request so default checkpoints carry no unused weights.
        if output_512:
            ch_512 = ch // 16
            self.up_512 = self._up_block(ch_256, ch_512)

    @staticmethod
    def _conv_norm_act(conv):
        """Return ``[conv, InstanceNorm2d, LeakyReLU]`` sized for ``conv``."""
        return [
            conv,
            nn.InstanceNorm2d(conv.out_channels),
            nn.LeakyReLU(0.2, inplace=True),
        ]

    @classmethod
    def _up_block(cls, in_ch, out_ch):
        """Build one 2x upsample stage followed by a 3x3 refinement conv."""
        return nn.Sequential(
            # kernel 4 / stride 2 / padding 1 exactly doubles height and width.
            *cls._conv_norm_act(
                nn.ConvTranspose2d(in_ch, out_ch, 4, stride=2, padding=1)
            ),
            *cls._conv_norm_act(nn.Conv2d(out_ch, out_ch, 3, padding=1)),
        )

    def forward(self, uni_features):
        """Compute the feature pyramid for a batch of UNI patch tokens.

        Args:
            uni_features: Tensor of shape [B, S*S, uni_dim] where
                S = ``spatial_size``.

        Returns:
            dict of feature maps keyed by 16, 32, 64, 128, 256 (plus 512 when
            ``output_512`` is set). The keys are fixed labels that equal the
            map side length for the default S = 32; in general the side
            lengths are S // 2, S, 2S, 4S, 8S (and 16S).

        Raises:
            RuntimeError: If the input is not 3-D or does not carry exactly
                S*S tokens. RuntimeError is used because it is the type torch
                itself raises for shape mismatches.
        """
        size = self.spatial_size
        expected_tokens = size * size
        if uni_features.dim() != 3 or uni_features.shape[1] != expected_tokens:
            raise RuntimeError(
                f"UNIFeatureProcessorHighRes expects input of shape "
                f"[B, {expected_tokens}, uni_dim] "
                f"(spatial_size={size}), got {tuple(uni_features.shape)}"
            )

        batch = uni_features.shape[0]

        # Project, then move channels first and fold the tokens into a grid.
        x = self.proj(uni_features)  # [B, S*S, C]
        x = x.permute(0, 2, 1).reshape(batch, self.base_channels, size, size)

        # Refine at native resolution; the skip keeps the raw projection.
        feat_32 = self.proc_32(x) + x  # [B, C, S, S]

        # Downsample for the bottleneck.
        feat_16 = self.down_16(feat_32)  # [B, C, S/2, S/2]

        # Upsample path; each map is computed from the previous one.
        feat_64 = self.up_64(feat_32)  # [B, C/2, 2S, 2S]
        feat_128 = self.up_128(feat_64)  # [B, C/4, 4S, 4S]
        feat_256 = self.up_256(feat_128)  # [B, C/8, 8S, 8S]

        out = {
            16: feat_16,
            32: feat_32,
            64: feat_64,
            128: feat_128,
            256: feat_256,
        }
        if self.output_512:
            out[512] = self.up_512(feat_256)  # [B, C/16, 16S, 16S]
        return out