Spaces:
Running
Running
| """ | |
| Edge encoders for UNIStainNet: parallel structure pathway from H&E edges. | |
| - EdgeEncoder (v1): Sequential Sobel → multi-scale CNN | |
| - MultiScaleEdgeEncoder (v2): Independent per-scale edge extraction with RGB input | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
class EdgeEncoder(nn.Module):
    """Encode Sobel edges of an H&E image into a four-level feature pyramid.

    The RGB input is collapsed to one grayscale channel, filtered with fixed
    horizontal and vertical Sobel kernels, and the resulting 2-channel
    gradient map is pushed through four stride-2 convolution stages. Each
    stage halves the spatial size and the output of every stage is returned.

    Design intent (from the original author's notes): the maps serve as an
    explicit structural signal for the generator's decoder. The H&E input and
    the generated output share the same spatial frame, so these features are
    pixel-aligned with the decoder output, unlike real HER2 ground truth.
    """

    def __init__(self, base_ch=32):
        """Create the fixed Sobel buffers and the four encoder stages.

        Args:
            base_ch: Channel width of the first stage. The following stages
                use 2x, 4x and 4x this width.
        """
        super().__init__()

        # The Sobel kernels are buffers, not parameters: they are fixed
        # filters and must not be updated by the optimiser.
        horizontal = torch.tensor(
            [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32
        ).view(1, 1, 3, 3)
        vertical = horizontal.transpose(-2, -1)
        self.register_buffer('sobel_x', horizontal)
        self.register_buffer('sobel_y', vertical)

        # 2 input channels (grad_x, grad_y) -> progressively wider features.
        # Only the first stage has no normalisation layer.
        widths = (base_ch, base_ch * 2, base_ch * 4, base_ch * 4)
        self.enc1 = self._down_block(2, widths[0], normalize=False)
        self.enc2 = self._down_block(widths[0], widths[1])
        self.enc3 = self._down_block(widths[1], widths[2])
        self.enc4 = self._down_block(widths[2], widths[3])

    @staticmethod
    def _down_block(in_ch, out_ch, normalize=True):
        """Return one stage: 4x4 stride-2 conv, optional InstanceNorm, LeakyReLU."""
        layers = [nn.Conv2d(in_ch, out_ch, 4, stride=2, padding=1)]
        if normalize:
            layers.append(nn.InstanceNorm2d(out_ch))
        layers.append(nn.LeakyReLU(0.2, inplace=True))
        return nn.Sequential(*layers)

    def forward(self, he_images):
        """Compute edge features at four resolutions.

        Args:
            he_images: [B, 3, 512, 512] in [-1, 1]

        Returns:
            dict of edge features keyed by nominal resolution (the shapes
            below hold for a 512x512 input):
                256: [B, base_ch,   256, 256]
                128: [B, base_ch*2, 128, 128]
                64:  [B, base_ch*4, 64, 64]
                32:  [B, base_ch*4, 32, 32]
        """
        # Map [-1, 1] to [0, 1], then average the colour channels.
        luma = ((he_images + 1) / 2).mean(dim=1, keepdim=True)

        # One gradient channel per Sobel kernel -> [B, 2, H, W].
        gradients = torch.cat(
            [F.conv2d(luma, kernel, padding=1)
             for kernel in (self.sobel_x, self.sobel_y)],
            dim=1,
        )

        # Each stage consumes the previous stage's output.
        stages = (
            (256, self.enc1),
            (128, self.enc2),
            (64, self.enc3),
            (32, self.enc4),
        )
        features = {}
        current = gradients
        for resolution, stage in stages:
            current = stage(current)
            features[resolution] = current
        return features
class MultiScaleEdgeEncoder(nn.Module):
    """Edge encoder that extracts edges independently at every scale.

    Differences from a sequential Sobel-then-CNN encoder:
      1. RGB-aware: each branch's first learnable layer sees the full RGB
         image next to the Sobel gradients, so it can learn stain-specific
         edges (e.g. hematoxylin vs. eosin boundaries).
      2. Multi-scale Sobel: the image is resized first and the Sobel filter
         is applied at each resolution, so fine 2-5 px edges are not lost
         through a chain of downsampling layers.
      3. Features at 512: one branch runs at the full input resolution for
         fine structure (cell walls, membrane patterns).
    """

    def __init__(self, base_ch=32):
        """Create the fixed Sobel buffers and one conv branch per scale.

        Args:
            base_ch: Channel width of the 512 and 256 branches. The 128
                branch uses 2x and the 64/32 branches use 4x this width.
        """
        super().__init__()

        # Fixed (non-trainable) Sobel kernels acting as a structural prior.
        horizontal = torch.tensor(
            [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32
        ).view(1, 1, 3, 3)
        vertical = horizontal.transpose(-2, -1)
        self.register_buffer('sobel_x', horizontal)
        self.register_buffer('sobel_y', vertical)

        # Every branch receives 3 RGB channels + 2 Sobel channels.
        in_ch = 3 + 2
        self.scale_512 = self._branch(in_ch, base_ch)      # full resolution
        self.scale_256 = self._branch(in_ch, base_ch)      # 256×256
        self.scale_128 = self._branch(in_ch, base_ch * 2)  # 128×128
        self.scale_64 = self._branch(in_ch, base_ch * 4)   # 64×64
        self.scale_32 = self._branch(in_ch, base_ch * 4)   # 32×32

    @staticmethod
    def _branch(in_ch, out_ch):
        """Return two 3x3 same-padding convs, each followed by LeakyReLU(0.2)."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        )

    def _extract_edges_at_scale(self, he_01, size):
        """Resize the H&E image, run Sobel on it, and return RGB + edges.

        Args:
            he_01: RGB image batch [B, 3, H, W], expected in [0, 1].
            size: Target side length. Values below 512 trigger a bilinear
                resize to (size, size); for 512 and above the input is used
                unchanged, whatever its actual spatial size.

        Returns:
            [B, 5, h, w] tensor: the (possibly resized) RGB channels
            followed by the horizontal and vertical Sobel responses of
            their channel mean.
        """
        rgb = he_01
        if size < 512:
            rgb = F.interpolate(he_01, size=size, mode='bilinear',
                                align_corners=False)
        luma = rgb.mean(dim=1, keepdim=True)
        grad_maps = [F.conv2d(luma, kernel, padding=1)
                     for kernel in (self.sobel_x, self.sobel_y)]
        return torch.cat([rgb, *grad_maps], dim=1)

    def forward(self, he_images):
        """Compute edge features at five resolutions.

        Args:
            he_images: [B, 3, 512, 512] in [-1, 1]

        Returns:
            dict of edge features at each decoder resolution:
                512: [B, base_ch,   512, 512]
                256: [B, base_ch,   256, 256]
                128: [B, base_ch*2, 128, 128]
                64:  [B, base_ch*4, 64, 64]
                32:  [B, base_ch*4, 32, 32]
        """
        # Shift to [0, 1] so edge magnitudes are consistent across scales.
        he_01 = (he_images + 1) / 2

        branches = (
            (512, self.scale_512),
            (256, self.scale_256),
            (128, self.scale_128),
            (64, self.scale_64),
            (32, self.scale_32),
        )
        return {
            resolution: branch(self._extract_edges_at_scale(he_01, resolution))
            for resolution, branch in branches
        }