| import torch |
| import torch.nn as nn |
|
|
| class GroupChannelShuffle(nn.Module): |
| """ |
| group-based channel shuffle / interleave. |
| groups: number of source groups you want to interleave (e.g. 4 for c1..c4) |
| optional cyclic shift (percent) to add deterministic rotation after shuffle. |
| """ |
| def __init__(self, groups: int = 4, cyclic_percent: float = 0.0): |
| super().__init__() |
| assert groups >= 1 |
| self.groups = groups |
| self.cyclic_percent = cyclic_percent |
|
|
| def forward(self, x): |
| |
| B, C, H, W = x.shape |
| g = self.groups |
| assert C % g == 0, f"channels {C} not divisible by groups {g}" |
| gc = C // g |
| |
| x = x.view(B, g, gc, H, W) |
| |
| x = x.transpose(1, 2).contiguous() |
| x = x.view(B, C, H, W) |
| |
| if self.cyclic_percent and 0 < self.cyclic_percent < 1.0: |
| shift = int(C * self.cyclic_percent) |
| x = torch.roll(x, shifts=shift, dims=1) |
| return x |
|
|
| class ISF_Module(nn.Module): |
| """ |
| A lightweight module that wraps shuffle + depthwise conv + group-wise scaling + residual. |
| - channels: total channels of x |
| - groups: number of logical groups (must divide channels) |
| """ |
| def __init__(self, channels: int, groups: int = 4, kernel_size: int = 3, cyclic_percent: float = 0.0): |
| super().__init__() |
| assert channels % groups == 0 |
| self.groups = groups |
| self.channels = channels |
| self.shuffle = GroupChannelShuffle(groups=groups, cyclic_percent=cyclic_percent) |
|
|
| |
| self.dw = nn.Conv2d(channels, channels, kernel_size=kernel_size, padding=kernel_size//2, groups=channels, bias=False) |
| self.bn = nn.BatchNorm2d(channels) |
| self.act = nn.ReLU(inplace=True) |
|
|
| |
| self.group_scale = nn.Parameter(torch.ones(groups), requires_grad=True) |
|
|
| |
| |
|
|
| def forward(self, x): |
| |
| B, C, H, W = x.shape |
| |
| y = self.shuffle(x) |
|
|
| |
| y = self.dw(y) |
| y = self.bn(y) |
| y = self.act(y) |
|
|
| |
| gc = C // self.groups |
| |
| scale = self.group_scale.to(x.device) |
| scale = scale.repeat_interleave(gc).view(1, C, 1, 1) |
| y = y * scale |
|
|
| |
| out = x + y |
| return out |
|
|
|
|