UNIStainNet / src /models /edge_encoder.py
faceless-void's picture
Upload folder using huggingface_hub
4db9215 verified
"""
Edge encoders for UNIStainNet: parallel structure pathway from H&E edges.
- EdgeEncoder (v1): Sequential Sobel β†’ multi-scale CNN
- MultiScaleEdgeEncoder (v2): Independent per-scale edge extraction with RGB input
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class EdgeEncoder(nn.Module):
"""Lightweight encoder that extracts multi-scale edge features from H&E input.
Extracts Sobel edges from grayscale H&E, then encodes them through a small
CNN to produce multi-scale feature maps. These are concatenated with the
main encoder's skip connections in the decoder, giving the generator an
explicit structural signal.
Key insight: H&E input and generated output share the exact same spatial
frame (no misalignment). So edge features from H&E are pixel-aligned with
the decoder's output β€” unlike real HER2 ground truth.
"""
def __init__(self, base_ch=32):
super().__init__()
# Sobel kernels (fixed, not learned)
sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]],
dtype=torch.float32).view(1, 1, 3, 3)
sobel_y = sobel_x.transpose(-1, -2)
self.register_buffer('sobel_x', sobel_x)
self.register_buffer('sobel_y', sobel_y)
# Edge feature encoder: 2ch (grad_x, grad_y) β†’ multi-scale features
# Mirrors the main encoder's spatial hierarchy
self.enc1 = nn.Sequential( # 512β†’256, out: base_ch
nn.Conv2d(2, base_ch, 4, stride=2, padding=1),
nn.LeakyReLU(0.2, inplace=True),
)
self.enc2 = nn.Sequential( # 256β†’128, out: base_ch*2
nn.Conv2d(base_ch, base_ch * 2, 4, stride=2, padding=1),
nn.InstanceNorm2d(base_ch * 2),
nn.LeakyReLU(0.2, inplace=True),
)
self.enc3 = nn.Sequential( # 128β†’64, out: base_ch*4
nn.Conv2d(base_ch * 2, base_ch * 4, 4, stride=2, padding=1),
nn.InstanceNorm2d(base_ch * 4),
nn.LeakyReLU(0.2, inplace=True),
)
self.enc4 = nn.Sequential( # 64β†’32, out: base_ch*4
nn.Conv2d(base_ch * 4, base_ch * 4, 4, stride=2, padding=1),
nn.InstanceNorm2d(base_ch * 4),
nn.LeakyReLU(0.2, inplace=True),
)
def forward(self, he_images):
"""
Args:
he_images: [B, 3, 512, 512] in [-1, 1]
Returns:
dict of edge features at each decoder resolution:
256: [B, base_ch, 256, 256]
128: [B, base_ch*2, 128, 128]
64: [B, base_ch*4, 64, 64]
32: [B, base_ch*4, 32, 32]
"""
# Convert to grayscale [0, 1]
gray = ((he_images + 1) / 2).mean(dim=1, keepdim=True) # [B, 1, 512, 512]
# Sobel edge detection
gx = F.conv2d(gray, self.sobel_x, padding=1)
gy = F.conv2d(gray, self.sobel_y, padding=1)
edges = torch.cat([gx, gy], dim=1) # [B, 2, 512, 512]
# Multi-scale encoding
e1 = self.enc1(edges) # [B, base_ch, 256, 256]
e2 = self.enc2(e1) # [B, base_ch*2, 128, 128]
e3 = self.enc3(e2) # [B, base_ch*4, 64, 64]
e4 = self.enc4(e3) # [B, base_ch*4, 32, 32]
return {256: e1, 128: e2, 64: e3, 32: e4}
class MultiScaleEdgeEncoder(nn.Module):
"""Multi-scale edge encoder with independent per-scale edge extraction.
Improvements over EdgeEncoder:
1. RGB-aware: Learnable first layer on full RGB (can discover stain-specific
edges β€” e.g., hematoxylin boundaries vs eosin boundaries carry different
information for HER2 staining).
2. Multi-scale Sobel: Extracts edges independently at each resolution before
encoding. Fine 2-5px edges don't get lost through sequential downsampling.
3. Edge features at 512: Provides features at output resolution for fine
structure preservation (cell walls, membrane patterns).
"""
def __init__(self, base_ch=32):
super().__init__()
# Fixed Sobel kernels for structural prior
sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]],
dtype=torch.float32).view(1, 1, 3, 3)
sobel_y = sobel_x.transpose(-1, -2)
self.register_buffer('sobel_x', sobel_x)
self.register_buffer('sobel_y', sobel_y)
# Per-scale feature extractors
# Input: 3ch RGB + 2ch Sobel = 5ch at each scale
in_ch = 5
# 512β†’512 (edge features at output resolution)
self.scale_512 = nn.Sequential(
nn.Conv2d(in_ch, base_ch, 3, padding=1),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(base_ch, base_ch, 3, padding=1),
nn.LeakyReLU(0.2, inplace=True),
)
# 256Γ—256
self.scale_256 = nn.Sequential(
nn.Conv2d(in_ch, base_ch, 3, padding=1),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(base_ch, base_ch, 3, padding=1),
nn.LeakyReLU(0.2, inplace=True),
)
# 128Γ—128
self.scale_128 = nn.Sequential(
nn.Conv2d(in_ch, base_ch * 2, 3, padding=1),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(base_ch * 2, base_ch * 2, 3, padding=1),
nn.LeakyReLU(0.2, inplace=True),
)
# 64Γ—64
self.scale_64 = nn.Sequential(
nn.Conv2d(in_ch, base_ch * 4, 3, padding=1),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(base_ch * 4, base_ch * 4, 3, padding=1),
nn.LeakyReLU(0.2, inplace=True),
)
# 32Γ—32
self.scale_32 = nn.Sequential(
nn.Conv2d(in_ch, base_ch * 4, 3, padding=1),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(base_ch * 4, base_ch * 4, 3, padding=1),
nn.LeakyReLU(0.2, inplace=True),
)
def _extract_edges_at_scale(self, he_01, size):
"""Downsample H&E, extract Sobel edges, return RGB+edges."""
if size < 512:
h = F.interpolate(he_01, size=size, mode='bilinear', align_corners=False)
else:
h = he_01
gray = h.mean(dim=1, keepdim=True)
gx = F.conv2d(gray, self.sobel_x, padding=1)
gy = F.conv2d(gray, self.sobel_y, padding=1)
return torch.cat([h, gx, gy], dim=1) # [B, 5, size, size]
def forward(self, he_images):
"""
Args:
he_images: [B, 3, 512, 512] in [-1, 1]
Returns:
dict of edge features at each decoder resolution:
512: [B, base_ch, 512, 512]
256: [B, base_ch, 256, 256]
128: [B, base_ch*2, 128, 128]
64: [B, base_ch*4, 64, 64]
32: [B, base_ch*4, 32, 32]
"""
he_01 = (he_images + 1) / 2 # [0, 1] for consistent edge magnitudes
return {
512: self.scale_512(self._extract_edges_at_scale(he_01, 512)),
256: self.scale_256(self._extract_edges_at_scale(he_01, 256)),
128: self.scale_128(self._extract_edges_at_scale(he_01, 128)),
64: self.scale_64(self._extract_edges_at_scale(he_01, 64)),
32: self.scale_32(self._extract_edges_at_scale(he_01, 32)),
}