| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | from typing import Callable, Dict, Iterable, List, NamedTuple, Optional, Tuple, Union |
| |
|
| | import torch |
| | from torch import nn |
| |
|
| | from timm.models import create_model, VisionTransformer |
| |
|
| | from .enable_cpe_support import enable_cpe |
| | from .input_conditioner import InputConditioner |
| | from .adaptor_base import AdaptorBase, RadioOutput, AdaptorInput |
| | from . import eradio_model |
| | from .enable_spectral_reparam import configure_spectral_reparam_from_args |
| | from .feature_normalizer import FeatureNormalizer, IntermediateFeatureNormalizer |
| |
|
| |
|
class Resolution(NamedTuple):
    """An image resolution expressed as a ``(height, width)`` pair."""

    # Number of rows.
    height: int
    # Number of columns.
    width: int
| |
|
| |
|
class RADIOModel(nn.Module):
    """Wrapper around a backbone that conditions inputs and splits outputs into summary and features.

    The wrapper runs ``input_conditioner`` on the input, calls the backbone's
    ``forward_features`` and converts the raw result into a ``RadioOutput``
    (summary + spatial features).  When adaptors are registered, each adaptor
    is additionally evaluated and the result is returned as a dict keyed by
    ``'backbone'`` plus the adaptor names.
    """

    def __init__(
        self,
        model: nn.Module,
        input_conditioner: InputConditioner,
        patch_size: int,
        max_resolution: int,
        preferred_resolution: Resolution,
        summary_idxs: Optional[torch.Tensor] = None,
        window_size: Optional[int] = None,
        adaptors: Optional[Dict[str, AdaptorBase]] = None,
        feature_normalizer: Optional[FeatureNormalizer] = None,
        inter_feature_normalizer: Optional[IntermediateFeatureNormalizer] = None,
    ):
        """
        Args:
            model: The backbone. Must expose ``forward_features``.
            input_conditioner: Module applied to the input before the backbone.
            patch_size: Patch size reported by :attr:`patch_size`. If ``None``, the
                value is looked up on the backbone instead.
            max_resolution: Value reported by :attr:`max_resolution`.
            preferred_resolution: Value reported by :attr:`preferred_resolution`.
            summary_idxs: Optional index tensor used to select summary tokens.
                Registered as a buffer when given.
            window_size: Optional window size; when set, it multiplies the patch
                size to form :attr:`min_resolution_step`.
            adaptors: Optional mapping of adaptor name to adaptor module.
            feature_normalizer: Optional module applied to the spatial features.
                Defaults to ``nn.Identity()``.
            inter_feature_normalizer: Optional normalizer forwarded to the backbone's
                ``forward_intermediates``.
        """
        super().__init__()

        self.model = model
        self.input_conditioner = input_conditioner
        if summary_idxs is not None:
            # Buffer registration keeps the indices on the module's device/state dict.
            self.register_buffer('summary_idxs', summary_idxs)
        else:
            self.summary_idxs = None

        self._preferred_resolution = preferred_resolution
        self._patch_size = patch_size
        self._max_resolution = max_resolution
        self._window_size = window_size

        adaptors = adaptors or dict()
        self.adaptors = nn.ModuleDict(adaptors)

        if feature_normalizer is None:
            feature_normalizer = nn.Identity()
        self.feature_normalizer = feature_normalizer
        self.inter_feature_normalizer = inter_feature_normalizer

    @property
    def num_summary_tokens(self) -> int:
        """Number of leading tokens to skip to reach the spatial features.

        Resolution order: ``model.num_summary_tokens`` if present, else
        ``model.patch_generator.num_skip`` if the backbone has a patch generator,
        else ``0`` when ``model.global_pool == 'avg'``, else ``1``.
        """
        if hasattr(self.model, 'num_summary_tokens'):
            return self.model.num_summary_tokens

        patch_gen = getattr(self.model, "patch_generator", None)
        if patch_gen is not None:
            return patch_gen.num_skip
        elif self.model.global_pool == 'avg':
            return 0
        return 1

    @property
    def num_cls_tokens(self) -> int:
        """Number of CLS tokens produced by the backbone.

        Resolution order: ``model.num_cls_tokens`` if present, else
        ``model.patch_generator.num_cls_tokens`` if the backbone has a patch
        generator, else ``0`` when ``model.global_pool == 'avg'``, else ``1``.
        """
        if hasattr(self.model, 'num_cls_tokens'):
            return self.model.num_cls_tokens

        patch_gen = getattr(self.model, 'patch_generator', None)
        if patch_gen is not None:
            return patch_gen.num_cls_tokens
        elif self.model.global_pool == 'avg':
            return 0
        return 1

    @property
    def patch_size(self) -> Optional[int]:
        """Patch size: the explicit constructor value, else the backbone's, else ``None``."""
        if self._patch_size is not None:
            return self._patch_size
        if hasattr(self.model, "patch_size"):
            return self.model.patch_size
        patch_gen = getattr(self.model, "patch_generator", None)
        if patch_gen is not None:
            return patch_gen.patch_size
        return None

    @property
    def max_resolution(self) -> int:
        """The ``max_resolution`` value supplied at construction."""
        return self._max_resolution

    @property
    def preferred_resolution(self) -> Resolution:
        """The ``preferred_resolution`` value supplied at construction."""
        return self._preferred_resolution

    @property
    def window_size(self) -> Optional[int]:
        """The ``window_size`` value supplied at construction (may be ``None``)."""
        return self._window_size

    @property
    def min_resolution_step(self) -> Optional[int]:
        """Granularity that input height/width must be a multiple of.

        Equal to :attr:`patch_size`, multiplied by :attr:`window_size` when a
        window size is set.
        """
        res = self.patch_size
        if self.window_size is not None:
            res *= self.window_size
        return res

    @property
    def blocks(self) -> Optional[Iterable[nn.Module]]:
        """The backbone's ``blocks`` attribute, or ``None`` if it has none."""
        return getattr(self.model, 'blocks', None)

    @property
    def embed_dim(self) -> int:
        """The backbone's ``embed_dim``."""
        return self.model.embed_dim

    def make_preprocessor_external(self) -> Callable[[torch.Tensor], torch.Tensor]:
        """Detach the input conditioner from the model and return it.

        After this call the model applies ``nn.Identity()`` to its input, so the
        caller is responsible for applying the returned conditioner itself.
        """
        ret = self.input_conditioner
        self.input_conditioner = nn.Identity()
        return ret

    def get_nearest_supported_resolution(self, height: int, width: int) -> Resolution:
        """Round ``height``/``width`` to the nearest multiple of :attr:`min_resolution_step`.

        Each dimension is clamped from below to one step so the result is never zero.
        """
        step = self.min_resolution_step

        height = int(round(height / step) * step)
        width = int(round(width / step) * step)

        height = max(height, step)
        width = max(width, step)

        return Resolution(height=height, width=width)

    def switch_to_deploy(self):
        """Call the backbone's ``switch_to_deploy`` hook if it defines one; otherwise do nothing."""
        fn = getattr(self.model, 'switch_to_deploy', None)
        if fn is not None:
            fn()

    def forward(self, x: torch.Tensor, feature_fmt: str = 'NLC') -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        '''
        Forward process for model.
        Args:
            x: Input tensor. Unless `make_preprocessor_external` has been called, then the dynamic range of `x` is expected to be `[0, 1]`,
               otherwise `x` is expected to be mean centered with unit standard deviation.
            feature_fmt: ['NLC', 'NCHW'] - The output format for the features.
        Returns:
            A `RadioOutput` (summary, features) when no adaptors are registered, otherwise a dict
            with key `'backbone'` plus one entry per adaptor name.
        Raises:
            ValueError: If the last two dimensions of `x` are not multiples of `self.min_resolution_step`,
                or if `feature_fmt` is not supported.
        '''
        res_step = self.min_resolution_step
        if res_step is not None and (x.shape[-2] % res_step != 0 or x.shape[-1] % res_step != 0):
            raise ValueError('The input resolution must be a multiple of `self.min_resolution_step`. '
                             '`self.get_nearest_supported_resolution(<height>, <width>) is provided as a convenience API. '
                             f'Input: {x.shape[-2:]}, Nearest: {self.get_nearest_supported_resolution(*x.shape[-2:])}')

        x = self.input_conditioner(x)
        y = self.model.forward_features(x)
        ret = self._extract_final(x, y, feature_fmt=feature_fmt)
        return ret

    def _extract_final(self, x: torch.Tensor, y: torch.Tensor, feature_fmt: str = 'NLC'):
        """Split the backbone output ``y`` into summary and features and run the adaptors.

        Args:
            x: The (already conditioned) input; only its last two dimensions are used
                to recover the spatial grid for the ``'NCHW'`` format, and it is passed
                to the adaptors as ``images``.
            y: Raw backbone output. Its layout depends on the backbone type, handled
                by the branches below.
            feature_fmt: ``'NLC'`` to return features as produced, ``'NCHW'`` to reshape
                them onto the patch grid.

        Returns:
            A ``RadioOutput`` when no adaptors are registered, otherwise a dict with
            key ``'backbone'`` plus one entry per adaptor name.

        Raises:
            ValueError: If ``feature_fmt`` is neither ``'NLC'`` nor ``'NCHW'``.
        """
        if isinstance(self.model, VisionTransformer):
            patch_gen = getattr(self.model, "patch_generator", None)
            if patch_gen is not None:
                all_summary = y[:, : patch_gen.num_cls_tokens]
                if self.summary_idxs is not None:
                    bb_summary = all_summary[:, self.summary_idxs]
                else:
                    bb_summary = all_summary
                all_feat = y[:, patch_gen.num_skip :]
            elif self.model.global_pool == "avg":
                # Summary is the mean over the non-prefix tokens.
                all_summary = y[:, self.model.num_prefix_tokens :].mean(dim=1)
                bb_summary = all_summary
                all_feat = y
            else:
                all_summary = y[:, 0]
                bb_summary = all_summary
                all_feat = y[:, 1:]
        elif isinstance(self.model, eradio_model.ERADIO):
            _, f = y
            # Flatten the spatial dims and move channels last.
            all_feat = f.flatten(2).transpose(1, 2)
            all_summary = all_feat.mean(dim=1)
            bb_summary = all_summary
        elif isinstance(y, (list, tuple)):
            all_summary, all_feat = y
            bb_summary = all_summary
        else:
            all_summary = y[:, :self.num_cls_tokens]
            # Index selection only applies when there is more than one summary token.
            if self.summary_idxs is not None and all_summary.shape[1] > 1:
                bb_summary = all_summary[:, self.summary_idxs]
            else:
                bb_summary = all_summary
            all_feat = y[:, self.num_summary_tokens:]

        all_feat = self.feature_normalizer(all_feat)

        if feature_fmt == 'NCHW':
            grid_h = x.shape[-2] // self.patch_size
            grid_w = x.shape[-1] // self.patch_size
            fmt_feat = (
                all_feat.reshape(all_feat.shape[0], grid_h, grid_w, all_feat.shape[2])
                        .permute(0, 3, 1, 2)
            )
        elif feature_fmt == 'NLC':
            fmt_feat = all_feat
        else:
            raise ValueError(f'Unsupported feature_fmt: {feature_fmt}. Must be one of ["NLC", "NCHW"]')

        ret = RadioOutput(bb_summary.flatten(1), fmt_feat)

        if self.adaptors:
            ret = dict(backbone=ret)
            for name, adaptor in self.adaptors.items():
                if all_summary.ndim == 3:
                    # One summary token per head: pick the one belonging to this adaptor.
                    summary = all_summary[:, adaptor.head_idx]
                else:
                    summary = all_summary
                ada_input = AdaptorInput(images=x, summary=summary.float(), features=all_feat, feature_fmt=feature_fmt, patch_size=self.patch_size)
                v = adaptor(ada_input).to(torch.float32)
                ret[name] = v

        return ret

    def forward_intermediates(
            self,
            x: torch.Tensor,
            indices: Optional[Union[int, List[int], Tuple[int]]] = None,
            return_prefix_tokens: bool = False,
            norm: bool = False,
            stop_early: bool = False,
            output_fmt: str = 'NCHW',
            intermediates_only: bool = False,
            aggregation: Optional[str] = "sparse",
            norm_alpha_scheme: Optional[str] = "post-alpha",
    ) -> List[RadioOutput]:
        """ Forward features that returns intermediates.
        Args:
            x: Input image tensor
            indices: Take last n blocks if int, select matching indices if sequence
            return_prefix_tokens: Return both prefix and spatial intermediate tokens
            norm: Apply norm layer to all intermediates
            stop_early: Stop iterating over blocks when last desired intermediate hit
            output_fmt: Shape of intermediate feature outputs. Options: NCHW, NLC
            intermediates_only: Only return intermediate features
            aggregation: intermediate layer aggregation method (sparse or dense).
                Dense accumulation is done by averaging the features in each group.
            norm_alpha_scheme: apply alpha before ("pre-alpha") or after accumulation ("post-alpha"), or don't normalize ("none")
                Only affects dense aggregation
        Returns:
            If `intermediates_only` is set, the list of intermediates; otherwise a tuple
            `(final, intermediates)` where `final` is the result of `_extract_final`.
            The intermediates are wrapped as RadioOutput objects when `return_prefix_tokens`
            is set, and are otherwise returned as produced by the backbone.
        """
        x = self.input_conditioner(x)
        intermediates = self.model.forward_intermediates(
            x,
            indices=indices,
            return_prefix_tokens=return_prefix_tokens,
            norm=norm,
            stop_early=stop_early,
            output_fmt=output_fmt,
            intermediates_only=intermediates_only,
            aggregation=aggregation,
            inter_feature_normalizer=self.inter_feature_normalizer,
            norm_alpha_scheme=norm_alpha_scheme,
        )

        if not intermediates_only:
            # The backbone returns (final_output, intermediates) in this mode.
            final, intermediates = intermediates

        def prepare_summary(summ: Optional[torch.Tensor]):
            """Select the configured summary indices (if any) and flatten per sample."""
            if summ is None:
                return summ
            if self.summary_idxs is not None and summ.shape[1] > 1:
                summ = summ[:, self.summary_idxs]
            return summ.flatten(1)

        if return_prefix_tokens:
            radio_outputs = [
                RadioOutput(prepare_summary(summary), features)
                for summary, features in intermediates
            ]
        else:
            radio_outputs = intermediates

        if intermediates_only:
            return radio_outputs
        else:
            final = self._extract_final(x, final, feature_fmt=output_fmt)
            return final, radio_outputs
| |
|
| |
|
def create_model_from_args(args) -> nn.Module:
    """Build the backbone described by ``args`` via ``create_model``.

    The input channel count comes from ``args.in_chans``, falling back to
    ``args.input_size[0]`` and finally to ``3``.  After construction the
    model's ``head`` is replaced with ``nn.Identity()``, and so is its ``norm``
    (when the model has one) unless ``args.model_norm`` is truthy.  When
    ``args.cpe_max_size`` is set, ``enable_cpe`` is applied to the model.

    Note:
        ``"weight_init"`` is popped from ``args.model_kwargs``, so that dict is
        modified in place.

    Args:
        args: Namespace-like object carrying the model configuration attributes
            read below.

    Returns:
        The constructed model.
    """
    if args.in_chans is not None:
        in_chans = args.in_chans
    elif args.input_size is not None:
        in_chans = args.input_size[0]
    else:
        in_chans = 3

    # Weight initialization defaults to "skip" unless the caller supplied a value.
    weight_init = args.model_kwargs.pop("weight_init", "skip")

    explicit_kwargs = dict(
        pretrained=args.pretrained,
        in_chans=in_chans,
        num_classes=args.num_classes,
        drop_rate=args.drop,
        drop_path_rate=args.drop_path,
        drop_block_rate=args.drop_block,
        global_pool=args.gp,
        bn_momentum=args.bn_momentum,
        bn_eps=args.bn_eps,
        scriptable=args.torchscript,
        checkpoint_path=args.initial_checkpoint,
        weight_init=weight_init,
    )
    model = create_model(args.model, **explicit_kwargs, **args.model_kwargs)

    keep_norm = getattr(args, 'model_norm', False)
    if hasattr(model, 'norm') and not keep_norm:
        model.norm = nn.Identity()

    model.head = nn.Identity()

    assert (
        not args.cls_token_per_teacher or args.cpe_max_size is not None
    ), "CPE must be enabled for multiple CLS tokens!"

    if args.cpe_max_size is not None:
        teacher_names = {teacher['name'] for teacher in args.teachers}
        if args.cls_token_per_teacher:
            # One CLS token per distinct teacher name.
            cls_token_count = len(teacher_names)
        else:
            cls_token_count = 1
        enable_cpe(
            model,
            args.cpe_max_size,
            num_cls_tokens=cls_token_count,
            register_multiple=getattr(args, 'register_multiple', None),
            num_registers=getattr(args, 'cpe_num_registers', None),
        )

    return model
| |
|