| from __future__ import annotations |
|
|
| from typing import Dict, Optional |
|
|
| import torch |
| from torch import Tensor, nn |
|
|
| from .slide_transformer import VisionTransformer |
|
|
# Names exported by ``from <this module> import *``.
__all__ = [
    "WSIEncoderHead",
]
|
|
|
|
class WSIEncoderHead(nn.Module):
    """Adapter around ``VisionTransformer`` that aggregates patch tokens.

    Inputs:
        - patch_features: [B, N, C_in] raw per-patch features.
        - patch_mask: [B, N] with 1 (or True) for valid tokens and 0 for padding.
          Required: it is forwarded to the transformer and used for the masked means.
        - patch_coords: optional [B, N, 2] integer coords for RoPE, forwarded to the
          transformer as ``coords``.
        - patch_contour_index: optional tensor forwarded unchanged to the transformer
          as ``contour_index`` (shape/semantics defined by the transformer).

    Returns:
        dict with exactly two keys:
        - patch_embedding: [B, N, C_in + C]
          concat(raw_patch_features, transformer_patch_tokens)
        - slide_embedding: [B, C_in + C]
          concat(masked_mean(raw_patch_features), masked_mean(transformer_patch_tokens))
    """

    def __init__(
        self,
        transformer: VisionTransformer,
        input_dim: int,
        embed_dim: int,
    ) -> None:
        """Wrap ``transformer``.

        Args:
            transformer: module called as
                ``transformer(x, masks=..., coords=..., contour_index=...)``; its result
                is indexed with ``"x_norm_patchtokens"`` to obtain the patch tokens.
            input_dim: channel count of the raw patch features (C_in). Stored as
                ``self.input_dim``; not used by this class's forward pass.
            embed_dim: channel count of the transformer patch tokens (C). Stored as
                ``self.embed_dim``; not used by this class's forward pass.
        """
        super().__init__()
        self.transformer = transformer
        self.embed_dim = int(embed_dim)
        self.input_dim = int(input_dim)

    def _masked_mean(self, tokens: Tensor, mask: Optional[Tensor]) -> Tensor:
        """Average ``tokens`` over the sequence dimension (dim 1), counting only valid positions.

        Args:
            tokens: [B, N, C] tensor to average.
            mask: [B, N] with 1 (or True) for valid positions and 0 for invalid ones,
                or ``None`` to fall back to a plain (unmasked) mean.

        Returns:
            [B, C] tensor in ``tokens.dtype``. A row with no valid positions yields a
            zero vector (sum 0 divided by a count clamped to 1).
        """
        if mask is None:
            return tokens.mean(dim=1)
        if tokens.dtype in (torch.float16, torch.bfloat16):
            # Accumulate in float32: with many patches a float16 sum can overflow to
            # inf, and half-precision formats cannot represent large integer counts
            # exactly. Cast back so callers still get the input dtype.
            return self._masked_mean(tokens.float(), mask).to(tokens.dtype)
        valid = mask.to(dtype=tokens.dtype).unsqueeze(-1)  # [B, N, 1] for broadcasting
        sums = (tokens * valid).sum(dim=1)
        # Clamp so fully-masked rows divide by 1 instead of 0 (no NaNs).
        counts = valid.sum(dim=1).clamp_min(1.0)
        return sums / counts

    def forward(
        self,
        patch_features: Tensor,
        patch_mask: Tensor,
        patch_coords: Optional[Tensor] = None,
        patch_contour_index: Optional[Tensor] = None,
    ) -> Dict[str, Tensor]:
        """Encode patches and build patch-level and slide-level embeddings.

        Args:
            patch_features: [B, N, C_in] raw patch features.
            patch_mask: [B, N] validity mask (1/True = valid). Must not be ``None``.
            patch_coords: optional coordinates forwarded to the transformer as ``coords``.
            patch_contour_index: optional tensor forwarded to the transformer as
                ``contour_index``.

        Returns:
            ``{"patch_embedding": [B, N, C_in + C], "slide_embedding": [B, C_in + C]}``.

        Raises:
            ValueError: if ``patch_mask`` is ``None``.
        """
        if patch_mask is None:
            raise ValueError("WSIEncoderHead requires patch_mask (shape [B, N]) to be provided.")

        # Put the mask on the features' device so the transformer and the masked
        # means below operate on a single device.
        mask = patch_mask.to(device=patch_features.device)

        encoded = self.transformer(
            patch_features,
            masks=mask,
            coords=patch_coords,
            contour_index=patch_contour_index,
        )
        # Assumes "x_norm_patchtokens" is [B, N, C], aligned position-wise with
        # patch_features (no extra class/register tokens) — TODO confirm.
        patch_tokens = encoded["x_norm_patchtokens"]

        # Per-patch embedding keeps the raw features next to the contextualized tokens.
        patch_embedding = torch.cat([patch_features, patch_tokens], dim=-1)

        # Slide embedding: masked means so padded positions do not contribute.
        raw_patch_mean = self._masked_mean(patch_features, mask)
        token_mean = self._masked_mean(patch_tokens, mask)
        slide_embedding = torch.cat([raw_patch_mean, token_mean], dim=-1)

        return {
            "patch_embedding": patch_embedding,
            "slide_embedding": slide_embedding,
        }
|
|