| from __future__ import annotations |
|
|
| from typing import Any, Dict, Optional, Tuple |
|
|
| from torch import Tensor, nn |
| from huggingface_hub import PyTorchModelHubMixin |
| from omegaconf import DictConfig, OmegaConf |
|
|
| from .slide_transformer import VisionTransformer |
| from .slide_encoder_head import WSIEncoderHead |
|
|
|
|
def _build_wsi_encoder(wsi_cfg: DictConfig) -> nn.Module:
    """Build the WSI encoder: a ``VisionTransformer`` wrapped in a ``WSIEncoderHead``.

    This is a minimal, WSI-only factory equivalent to
    ``MultiModalMetaModel._build_wsi_encoder`` but without importing the
    full multimodal meta model.

    Args:
        wsi_cfg: Configuration mapping. Only ``.get(key[, default])`` is used
            on it. Every key is optional; a missing key falls back to the
            default shown in the code below (e.g. ``embed_dim=768``,
            ``input_dim=768``, ``depth=12``, ``num_heads=12``). The
            ``pos_embed_rope_min_period`` and ``pos_embed_rope_max_period``
            keys default to ``None``.

    Returns:
        A single module: the ``WSIEncoderHead`` constructed from the
        transformer, ``input_dim`` and ``embed_dim``. (The function has only
        ever returned this one object, not a tuple, so the annotation is
        ``nn.Module``.)
    """
    # These two values are needed by both the transformer and the head.
    embed_dim = int(wsi_cfg.get("embed_dim", 768))
    input_dim = int(wsi_cfg.get("input_dim", 768))

    # Scalars are coerced explicitly so that values read from a config
    # (which may arrive as strings or other numeric types) reach the
    # transformer with a predictable Python type.
    transformer_kwargs = {
        "input_dim": input_dim,
        "patch_size": int(wsi_cfg.get("patch_size", 256)),
        "embed_use_norm": bool(wsi_cfg.get("embed_use_norm", True)),
        "embed_dim": embed_dim,
        "depth": int(wsi_cfg.get("depth", 12)),
        "num_heads": int(wsi_cfg.get("num_heads", 12)),
        "ffn_ratio": float(wsi_cfg.get("ffn_ratio", 4.0)),
        "qkv_bias": bool(wsi_cfg.get("qkv_bias", True)),
        "norm_layer": wsi_cfg.get("norm_layer", "layernorm"),
        "ffn_layer": wsi_cfg.get("ffn_layer", "swiglu128"),
        "ffn_bias": bool(wsi_cfg.get("ffn_bias", True)),
        "proj_bias": bool(wsi_cfg.get("proj_bias", True)),
        "ffn_drop": float(wsi_cfg.get("ffn_drop", 0.0)),
        "attn_drop": float(wsi_cfg.get("attn_drop", 0.0)),
        "n_storage_tokens": int(wsi_cfg.get("n_storage_tokens", 0)),
        "nope_interval": int(wsi_cfg.get("nope_interval", 2)),
        # Positional-embedding options are forwarded without coercion;
        # the min/max period entries are None when absent from the config.
        "pos_embed_rope_base": wsi_cfg.get("pos_embed_rope_base", 10000.0),
        "pos_embed_rope_min_period": wsi_cfg.get("pos_embed_rope_min_period"),
        "pos_embed_rope_max_period": wsi_cfg.get("pos_embed_rope_max_period"),
        "pos_embed_rope_dtype": wsi_cfg.get("pos_embed_rope_dtype", "fp32"),
    }

    transformer = VisionTransformer(**transformer_kwargs)
    # Weight initialisation is optional: run it only if the transformer
    # actually exposes an ``init_weights`` method.
    if hasattr(transformer, "init_weights"):
        transformer.init_weights()

    return WSIEncoderHead(
        transformer,
        input_dim,
        embed_dim,
    )
|
|
|
|
class WSIEncoder(nn.Module, PyTorchModelHubMixin):
    """Hub-compatible wrapper around the slide-level WSI encoder.

    The wrapped module is whatever ``_build_wsi_encoder`` returns for the
    given configuration (a ViT-based slide encoder used in EXAONE-Path for
    slide-level feature extraction). Mixing in
    :class:`PyTorchModelHubMixin` exposes the model through the Hugging Face
    Hub save/load API.

    A plain-dict copy of the minimal configuration (``wsi_cfg``) is kept on
    the instance so that it is available for serialization to
    ``config.json`` when :meth:`save_pretrained` is called.
    """

    def __init__(
        self,
        *,
        wsi_cfg: Dict[str, Any],
    ) -> None:
        """Store the configuration and build the underlying encoder.

        Args:
            wsi_cfg: Keyword-only mapping of encoder settings; it is
                shallow-copied into ``self.wsi_cfg``.
        """
        super().__init__()

        # Shallow copy so the instance owns its own top-level mapping.
        self.wsi_cfg: Dict[str, Any] = dict(wsi_cfg)

        # The factory expects an OmegaConf object; resolve interpolations
        # in place before handing it over.
        resolved_cfg: DictConfig = OmegaConf.create(self.wsi_cfg)
        if isinstance(resolved_cfg, DictConfig):
            OmegaConf.resolve(resolved_cfg)

        self.wsi_encoder = _build_wsi_encoder(resolved_cfg)

    def forward(
        self,
        patch_features: Tensor,
        patch_mask: Tensor,
        patch_coords: Optional[Tensor] = None,
        patch_contour_index: Optional[Tensor] = None,
    ) -> Dict[str, Tensor]:
        """Delegate to the wrapped encoder, passing every input by keyword.

        Args:
            patch_features: [B, N, C]
            patch_mask: [B, N] with 1 for valid tokens
            patch_coords: optional [B, N, 2] coords (for RoPE)
            patch_contour_index: optional [B, N] contour indices

        Returns:
            Whatever the wrapped encoder returns for these inputs.
        """
        encoder_inputs = {
            "patch_features": patch_features,
            "patch_mask": patch_mask,
            "patch_coords": patch_coords,
            "patch_contour_index": patch_contour_index,
        }
        return self.wsi_encoder(**encoder_inputs)

    @classmethod
    def from_wsi_config(
        cls,
        wsi_cfg: Dict[str, Any],
    ) -> WSIEncoder:
        """Alternate constructor taking ``wsi_cfg`` positionally."""
        return cls(wsi_cfg=wsi_cfg)