| | |
| | |
| | |
| | |
| | from typing import Union, Tuple |
| |
|
| | import copy |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| |
|
| | __all__ = ["MobileOneBlock", "reparameterize_model"] |
| |
|
| |
|
class SEBlock(nn.Module):
    """Squeeze and Excite module.

    Pytorch implementation of `Squeeze-and-Excitation Networks` -
    https://arxiv.org/pdf/1709.01507.pdf

    The input is globally average-pooled, passed through a 1x1 bottleneck
    (reduce -> ReLU -> expand -> sigmoid) and the resulting per-channel gate
    multiplies the input.
    """

    def __init__(self, in_channels: int, rd_ratio: float = 0.0625) -> None:
        """Construct a Squeeze and Excite Module.

        Args:
            in_channels: Number of input channels.
            rd_ratio: Input channel reduction ratio. The bottleneck width is
                ``int(in_channels * rd_ratio)``, clamped to at least 1 so that
                narrow inputs (e.g. 8 channels at the default ratio) still get a
                non-empty excitation path instead of zero-size convolutions.
        """
        super(SEBlock, self).__init__()
        reduced_channels = max(1, int(in_channels * rd_ratio))
        self.reduce = nn.Conv2d(
            in_channels=in_channels,
            out_channels=reduced_channels,
            kernel_size=1,
            stride=1,
            bias=True,
        )
        self.expand = nn.Conv2d(
            in_channels=reduced_channels,
            out_channels=in_channels,
            kernel_size=1,
            stride=1,
            bias=True,
        )

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Apply forward pass.

        Args:
            inputs: 4-D tensor unpacked as ``(batch, channels, height, width)``.

        Returns:
            Tensor of the same shape as ``inputs``, scaled channel-wise.
        """
        _, c, h, w = inputs.size()
        # Global average pool over the full spatial extent -> (B, C, 1, 1).
        x = F.avg_pool2d(inputs, kernel_size=[h, w])
        x = self.reduce(x)
        x = F.relu(x)
        x = self.expand(x)
        x = torch.sigmoid(x)
        x = x.view(-1, c, 1, 1)
        return inputs * x
| |
|
| |
|
class MobileOneBlock(nn.Module):
    """MobileOne building block.

    This block has a multi-branched architecture at train-time
    and plain-CNN style architecture at inference time.
    For more details, please refer to our paper:
    `An Improved One millisecond Mobile Backbone` -
    https://arxiv.org/pdf/2206.04040.pdf

    Train-time branches (summed, then passed through SE and activation):
      * ``rbr_skip``  - BatchNorm-only identity, present only when
        ``in_channels == out_channels`` and ``stride == 1``.
      * ``rbr_scale`` - 1x1 conv-BN, present only when the kernel height is
        greater than 1 and ``use_scale_branch`` is set.
      * ``rbr_conv``  - ``num_conv_branches`` parallel kxk conv-BN branches.

    :meth:`reparameterize` folds all branches into a single ``reparam_conv``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        stride: int = 1,
        padding: int = 0,
        dilation: int = 1,
        groups: int = 1,
        inference_mode: bool = False,
        use_se: bool = False,
        use_act: bool = True,
        use_scale_branch: bool = True,
        num_conv_branches: int = 1,
        activation: Union[nn.Module, None] = None,
    ) -> None:
        """Construct a MobileOneBlock module.

        Args:
            in_channels: Number of channels in the input.
            out_channels: Number of channels produced by the block.
            kernel_size: Size of the convolution kernel (int or (h, w) pair).
            stride: Stride size.
            padding: Zero-padding size. When the identity or 1x1 scale branch
                exists, branch outputs are only summable (and the fusion only
                exact) if this equals ``dilation * (kernel_size - 1) // 2``.
            dilation: Kernel dilation factor (applied to the kxk conv branches
                and to the fused inference conv).
            groups: Group number.
            inference_mode: If True, instantiates model in inference mode.
            use_se: Whether to use SE-ReLU activations.
            use_act: Whether to use activation. Default: ``True``
            use_scale_branch: Whether to use scale branch. Default: ``True``
            num_conv_branches: Number of linear conv branches.
            activation: Activation module applied after SE when ``use_act`` is
                True. ``None`` (default) creates a fresh ``nn.GELU()`` for this
                block, so blocks never share one default module instance.
        """
        super(MobileOneBlock, self).__init__()
        self.inference_mode = inference_mode
        self.groups = groups
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.kernel_size = kernel_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_conv_branches = num_conv_branches

        # Optional squeeze-and-excite applied to the summed branch output.
        if use_se:
            self.se = SEBlock(out_channels)
        else:
            self.se = nn.Identity()

        if not use_act:
            self.activation = nn.Identity()
        elif activation is None:
            # Fresh instance per block: a default-argument module would be
            # one object shared (and registered) by every block.
            self.activation = nn.GELU()
        else:
            self.activation = activation

        if inference_mode:
            self.reparam_conv = nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
                bias=True,
            )
        else:
            # Re-parameterizable skip connection (BN only); it can only be
            # added to the output when the block preserves shape.
            self.rbr_skip = (
                nn.BatchNorm2d(num_features=in_channels)
                if out_channels == in_channels and stride == 1
                else None
            )

            # Re-parameterizable kxk conv branches.
            if num_conv_branches > 0:
                self.rbr_conv = nn.ModuleList(
                    [
                        self._conv_bn(kernel_size=kernel_size, padding=padding)
                        for _ in range(num_conv_branches)
                    ]
                )
            else:
                self.rbr_conv = None

            # Re-parameterizable 1x1 scale branch.
            self.rbr_scale = None
            kernel_h = kernel_size if isinstance(kernel_size, int) else kernel_size[0]
            if kernel_h > 1 and use_scale_branch:
                self.rbr_scale = self._conv_bn(kernel_size=1, padding=0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply forward pass."""
        # Inference-time: single fused conv.
        if self.inference_mode:
            return self.activation(self.se(self.reparam_conv(x)))

        # Train-time: sum every existing branch (absent ones contribute 0).
        identity_out = 0
        if self.rbr_skip is not None:
            identity_out = self.rbr_skip(x)

        scale_out = 0
        if self.rbr_scale is not None:
            scale_out = self.rbr_scale(x)

        out = scale_out + identity_out
        if self.rbr_conv is not None:
            for branch in self.rbr_conv:
                out = out + branch(x)

        return self.activation(self.se(out))

    def reparameterize(self) -> None:
        """Following works like `RepVGG: Making VGG-style ConvNets Great Again` -
        https://arxiv.org/pdf/2101.03697.pdf. We re-parameterize multi-branched
        architecture used at training time to obtain a plain CNN-like structure
        for inference.

        The fusion uses BatchNorm running statistics, so it reproduces the
        train-time block's eval-mode output. No-op if already in inference mode.
        """
        if self.inference_mode:
            return
        with torch.no_grad():
            kernel, bias = self._get_kernel_bias()
        self.reparam_conv = nn.Conv2d(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            groups=self.groups,
            bias=True,
        )
        # Assigning ``.data`` (rather than ``copy_``) adopts the fused tensors'
        # device and dtype, so the new conv lives where the branches lived.
        self.reparam_conv.weight.data = kernel
        self.reparam_conv.bias.data = bias

        # Detach all parameters (this also sets requires_grad=False on them)
        # and drop the train-time branches now folded into reparam_conv.
        for para in self.parameters():
            para.detach_()
        del self.rbr_conv
        del self.rbr_scale
        if hasattr(self, "rbr_skip"):
            del self.rbr_skip

        self.inference_mode = True

    def _kernel_hw(self) -> Tuple[int, int]:
        """Return ``self.kernel_size`` as a ``(height, width)`` pair."""
        if isinstance(self.kernel_size, int):
            return self.kernel_size, self.kernel_size
        return self.kernel_size[0], self.kernel_size[1]

    def _get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Method to obtain re-parameterized kernel and bias.
        Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L83

        Returns:
            Tuple of (kernel, bias) after fusing branches.
        """
        kernel_h, kernel_w = self._kernel_hw()

        # 1x1 scale branch, zero-padded so its single tap sits at the centre
        # of the kxk kernel. Skipped entirely when the branch does not exist.
        kernel_scale = 0
        bias_scale = 0
        if self.rbr_scale is not None:
            kernel_scale, bias_scale = self._fuse_bn_tensor(self.rbr_scale)
            pad_h, pad_w = kernel_h // 2, kernel_w // 2
            # F.pad lists the last dim first: [left, right, top, bottom].
            kernel_scale = F.pad(kernel_scale, [pad_w, pad_w, pad_h, pad_h])

        # BN-only identity branch.
        kernel_identity = 0
        bias_identity = 0
        if self.rbr_skip is not None:
            kernel_identity, bias_identity = self._fuse_bn_tensor(self.rbr_skip)

        # kxk conv-BN branches.
        kernel_conv = 0
        bias_conv = 0
        if self.rbr_conv is not None:
            for branch in self.rbr_conv:
                branch_kernel, branch_bias = self._fuse_bn_tensor(branch)
                kernel_conv = kernel_conv + branch_kernel
                bias_conv = bias_conv + branch_bias

        kernel_final = kernel_conv + kernel_scale + kernel_identity
        bias_final = bias_conv + bias_scale + bias_identity
        return kernel_final, bias_final

    def _identity_kernel(self, like: torch.Tensor) -> torch.Tensor:
        """Return a conv kernel that maps each input channel to itself.

        The result is cached on ``self.id_tensor``. Because that is a plain
        attribute (not a registered buffer), ``.to()`` does not move it, so it
        is rebuilt whenever its device or dtype no longer matches ``like``.
        """
        cached = getattr(self, "id_tensor", None)
        if (
            cached is not None
            and cached.device == like.device
            and cached.dtype == like.dtype
        ):
            return cached

        input_dim = self.in_channels // self.groups
        kernel_h, kernel_w = self._kernel_hw()
        kernel_value = torch.zeros(
            (self.in_channels, input_dim, kernel_h, kernel_w),
            dtype=like.dtype,
            device=like.device,
        )
        # Output channel i reads input channel i, which is index i % input_dim
        # inside its group; only the centre tap is set.
        channels = torch.arange(self.in_channels, device=like.device)
        kernel_value[channels, channels % input_dim, kernel_h // 2, kernel_w // 2] = 1
        self.id_tensor = kernel_value
        return kernel_value

    def _fuse_bn_tensor(
        self, branch: Union[nn.Sequential, nn.BatchNorm2d]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Method to fuse batchnorm layer with preceding conv layer.
        Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L95

        Args:
            branch: Conv-BN ``nn.Sequential`` (with ``conv`` and ``bn``
                children) or a bare ``nn.BatchNorm2d`` identity branch.

        Returns:
            Tuple of (kernel, bias) after fusing batchnorm.

        Raises:
            TypeError: If ``branch`` is neither supported type.
        """
        if isinstance(branch, nn.Sequential):
            kernel = branch.conv.weight
            bn = branch.bn
        elif isinstance(branch, nn.BatchNorm2d):
            # A BN-only branch behaves as an identity conv followed by BN.
            kernel = self._identity_kernel(branch.weight)
            bn = branch
        else:
            raise TypeError(
                "Expected nn.Sequential or nn.BatchNorm2d, got "
                f"{type(branch).__name__}"
            )
        std = (bn.running_var + bn.eps).sqrt()
        t = (bn.weight / std).reshape(-1, 1, 1, 1)
        return kernel * t, bn.bias - bn.running_mean * bn.weight / std

    def _conv_bn(self, kernel_size: int, padding: int) -> nn.Sequential:
        """Helper method to construct conv-batchnorm layers.

        The conv uses the block's stride, groups and dilation so each branch
        matches the fused ``reparam_conv`` (dilation is a no-op for 1x1).

        Args:
            kernel_size: Size of the convolution kernel.
            padding: Zero-padding size.

        Returns:
            Conv-BN module.
        """
        mod_list = nn.Sequential()
        mod_list.add_module(
            "conv",
            nn.Conv2d(
                in_channels=self.in_channels,
                out_channels=self.out_channels,
                kernel_size=kernel_size,
                stride=self.stride,
                padding=padding,
                dilation=self.dilation,
                groups=self.groups,
                bias=False,
            ),
        )
        mod_list.add_module("bn", nn.BatchNorm2d(num_features=self.out_channels))
        return mod_list
| |
|
| |
|
def reparameterize_model(model: torch.nn.Module) -> nn.Module:
    """Fuse every re-parameterizable submodule of a copy of ``model``.

    The input model is deep-copied first, so the caller's train-time model
    is left unchanged. Every module in the copy (the root included) that
    exposes a ``reparameterize`` attribute has it called, collapsing its
    multi-branched training structure into a single inference branch.

    Args:
        model: MobileOne model in train mode.

    Returns:
        MobileOne model in inference mode (a new object).
    """
    fused = copy.deepcopy(model)
    for submodule in fused.modules():
        if not hasattr(submodule, "reparameterize"):
            continue
        submodule.reparameterize()
    return fused