| | |
| | from typing import List, Tuple |
| |
|
| | import torch |
| | import torch.nn.functional as F |
| | from mmcv.cnn import ConvModule |
| | from mmengine.model import BaseModule |
| |
|
| | from mmdet.registry import MODELS |
| | from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig |
| |
|
| |
|
class SSHContextModule(BaseModule):
    """Context branch used inside the SSH detection module.

    Implements the context module from `SSH: Single Stage Headless Face
    Detector <https://arxiv.org/pdf/1708.03979.pdf>`_.

    A first 3x3 conv produces a shared feature with ``out_channels // 4``
    channels. That feature feeds two branches built only from 3x3 convs:
    the "5x5" branch applies one more conv, the "7x7" branch applies two
    more. The final conv of each branch is created with ``act_cfg=None``;
    the other convs keep ``ConvModule``'s default activation.

    Args:
        in_channels (int): Number of channels of the input feature map.
        out_channels (int): Channel budget of the enclosing detection
            module. Must be divisible by 4; every conv in this module
            outputs ``out_channels // 4`` channels.
        conv_cfg (:obj:`ConfigDict` or dict, optional): Config dict for
            convolution layer. Defaults to None.
        norm_cfg (:obj:`ConfigDict` or dict): Config dict for normalization
            layer. Defaults to dict(type='BN').
        init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
            list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 conv_cfg: OptConfigType = None,
                 norm_cfg: ConfigType = dict(type='BN'),
                 init_cfg: OptMultiConfig = None):
        super().__init__(init_cfg=init_cfg)
        assert out_channels % 4 == 0

        self.in_channels = in_channels
        self.out_channels = out_channels

        # Every conv below is 3x3 / stride 1 / padding 1 and shares the
        # same conv and norm configuration.
        branch_channels = out_channels // 4
        shared_kwargs = dict(
            stride=1, padding=1, conv_cfg=conv_cfg, norm_cfg=norm_cfg)

        # NOTE: keep the construction order and attribute names stable;
        # they define the parameter names stored in checkpoints.
        self.conv5x5_1 = ConvModule(
            in_channels, branch_channels, 3, **shared_kwargs)
        self.conv5x5_2 = ConvModule(
            branch_channels, branch_channels, 3, act_cfg=None,
            **shared_kwargs)
        self.conv7x7_2 = ConvModule(
            branch_channels, branch_channels, 3, **shared_kwargs)
        self.conv7x7_3 = ConvModule(
            branch_channels, branch_channels, 3, act_cfg=None,
            **shared_kwargs)

    def forward(self, x: torch.Tensor) -> tuple:
        """Compute both context branches.

        Args:
            x (torch.Tensor): Input feature map.

        Returns:
            tuple: ``(conv5x5, conv7x7)``, the outputs of the "5x5" and
            "7x7" branches, in that order.
        """
        shared_feat = self.conv5x5_1(x)
        out_5x5 = self.conv5x5_2(shared_feat)
        out_7x7 = self.conv7x7_3(self.conv7x7_2(shared_feat))
        return out_5x5, out_7x7
| |
|
| |
|
class SSHDetModule(BaseModule):
    """Single-scale SSH detection module.

    Implements the detection module from `SSH: Single Stage Headless Face
    Detector <https://arxiv.org/pdf/1708.03979.pdf>`_.

    The input goes through a plain 3x3 conv (``out_channels // 2``
    channels) and, in parallel, through an :class:`SSHContextModule`
    (two outputs of ``out_channels // 4`` channels each). The three results
    are concatenated along the channel dimension and passed through ReLU,
    giving ``out_channels`` channels in total.

    Args:
        in_channels (int): Number of channels of the input feature map.
        out_channels (int): Number of channels of the output feature map.
            Must be divisible by 4.
        conv_cfg (:obj:`ConfigDict` or dict, optional): Config dict for
            convolution layer. Defaults to None.
        norm_cfg (:obj:`ConfigDict` or dict): Config dict for normalization
            layer. Defaults to dict(type='BN').
        init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
            list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 conv_cfg: OptConfigType = None,
                 norm_cfg: ConfigType = dict(type='BN'),
                 init_cfg: OptMultiConfig = None):
        super().__init__(init_cfg=init_cfg)
        assert out_channels % 4 == 0

        self.in_channels = in_channels
        self.out_channels = out_channels

        # No activation here: ReLU is applied once in ``forward`` after
        # the branches have been concatenated.
        self.conv3x3 = ConvModule(
            in_channels,
            out_channels // 2,
            3,
            stride=1,
            padding=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=None)

        self.context_module = SSHContextModule(
            in_channels=in_channels,
            out_channels=out_channels,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Fuse the 3x3 branch with the two context branches.

        Args:
            x (torch.Tensor): Input feature map.

        Returns:
            torch.Tensor: ReLU of the channel-wise concatenation of the
            3x3, "5x5" and "7x7" branch outputs.
        """
        # The 3x3 branch is evaluated first, then the context module.
        branches = [self.conv3x3(x), *self.context_module(x)]
        return F.relu(torch.cat(branches, dim=1))
| |
|
| |
|
@MODELS.register_module()
class SSH(BaseModule):
    """`SSH Neck` used in `SSH: Single Stage Headless Face Detector.

    <https://arxiv.org/pdf/1708.03979.pdf>`_.

    One :class:`SSHDetModule` is created per scale and registered as
    ``ssh_module{idx}``. In ``forward`` the ``idx``-th input feature map
    is processed by the ``idx``-th module.

    Args:
        num_scales (int): The number of scales / stages. Must equal the
            length of both ``in_channels`` and ``out_channels``.
        in_channels (list[int]): The number of input channels per scale.
        out_channels (list[int]): The number of output channels per scale.
            Each entry must be divisible by 4.
        conv_cfg (:obj:`ConfigDict` or dict, optional): Config dict for
            convolution layer. Defaults to None.
        norm_cfg (:obj:`ConfigDict` or dict): Config dict for normalization
            layer. Defaults to dict(type='BN').
        init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or
            list[dict], optional): Initialization config dict. Defaults to
            dict(type='Xavier', layer='Conv2d', distribution='uniform').

    Example:
        >>> import torch
        >>> in_channels = [8, 16, 32, 64]
        >>> out_channels = [16, 32, 64, 128]
        >>> scales = [340, 170, 84, 43]
        >>> inputs = [torch.rand(1, c, s, s)
        ...           for c, s in zip(in_channels, scales)]
        >>> self = SSH(num_scales=4, in_channels=in_channels,
        ...            out_channels=out_channels)
        >>> outputs = self.forward(inputs)
        >>> for i in range(len(outputs)):
        ...     print(f'outputs[{i}].shape = {outputs[i].shape}')
        outputs[0].shape = torch.Size([1, 16, 340, 340])
        outputs[1].shape = torch.Size([1, 32, 170, 170])
        outputs[2].shape = torch.Size([1, 64, 84, 84])
        outputs[3].shape = torch.Size([1, 128, 43, 43])
    """

    def __init__(self,
                 num_scales: int,
                 in_channels: List[int],
                 out_channels: List[int],
                 conv_cfg: OptConfigType = None,
                 norm_cfg: ConfigType = dict(type='BN'),
                 init_cfg: OptMultiConfig = dict(
                     type='Xavier', layer='Conv2d', distribution='uniform')):
        super().__init__(init_cfg=init_cfg)
        assert num_scales == len(in_channels) == len(out_channels), (
            f'num_scales ({num_scales}) must equal len(in_channels) '
            f'({len(in_channels)}) and len(out_channels) '
            f'({len(out_channels)})')
        self.num_scales = num_scales
        self.in_channels = in_channels
        self.out_channels = out_channels

        # One detection module per scale; the ``ssh_module{idx}`` names
        # are part of the checkpoint key layout and must stay stable.
        for idx, (in_c, out_c) in enumerate(zip(in_channels, out_channels)):
            self.add_module(
                f'ssh_module{idx}',
                SSHDetModule(
                    in_channels=in_c,
                    out_channels=out_c,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg))

    def forward(
            self,
            inputs: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
        """Apply the per-scale SSH detection modules.

        Args:
            inputs (tuple[torch.Tensor]): One feature map per scale, in the
                same order as ``in_channels``. Any sized sequence of
                tensors works; its length must equal ``num_scales``.

        Returns:
            tuple[torch.Tensor]: One output feature map per scale, in
            input order.
        """
        assert len(inputs) == self.num_scales, (
            f'expected {self.num_scales} feature maps, got {len(inputs)}')

        return tuple(
            getattr(self, f'ssh_module{idx}')(feat)
            for idx, feat in enumerate(inputs))
| |
|