import torch
import torch.nn as nn
class GroupChannelShuffle(nn.Module):
    """Interleave the channels of ``groups`` equally sized channel groups.

    The channel axis (dim 1) is read as ``groups`` consecutive blocks of
    ``C // groups`` channels.  The output lists the first channel of every
    block, then the second channel of every block, and so on: the channel at
    input position ``i * (C // groups) + k`` moves to ``k * groups + i``.
    With 8 channels and 4 groups the channel order becomes ``0 2 4 6 1 3 5 7``.

    Args:
        groups: Number of source groups to interleave (e.g. 4 for c1..c4).
            Must be at least 1.
        cyclic_percent: Optional deterministic rotation applied after the
            interleave.  Despite the name this is a fraction, not a
            percentage: only values strictly between 0 and 1 have an effect,
            in which case the channels are rolled by
            ``int(C * cyclic_percent)`` positions.  Any other value
            (including the default ``0.0``) leaves the rotation disabled.

    Raises:
        AssertionError: If ``groups`` is smaller than 1.
    """

    def __init__(self, groups: int = 4, cyclic_percent: float = 0.0):
        super().__init__()
        # Raised explicitly instead of via `assert` so the check is not
        # stripped under `python -O`; the exception type stays the same.
        if groups < 1:
            raise AssertionError(f"groups must be >= 1, got {groups}")
        self.groups = groups
        self.cyclic_percent = cyclic_percent

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Shuffle the channels of ``x`` and return a tensor of the same shape.

        Args:
            x: Tensor laid out as ``(B, C, *spatial)``, e.g. ``(B, C, H, W)``.
                Any number of trailing dimensions is accepted; they are
                carried through untouched.

        Raises:
            AssertionError: If ``C`` is not divisible by ``groups``.
            ValueError: If ``x`` has fewer than two dimensions.
        """
        batch, channels, *spatial = x.shape
        g = self.groups
        if channels % g != 0:
            raise AssertionError(f"channels {channels} not divisible by groups {g}")
        gc = channels // g

        # (B, groups, group_channels, ...) -> (B, group_channels, groups, ...)
        x = x.view(batch, g, gc, *spatial)
        x = x.transpose(1, 2).contiguous()
        # Collapse back to (B, C, ...); the channels are now interleaved.
        x = x.view(batch, channels, *spatial)

        # Optional deterministic rotation along the channel axis.  The leading
        # truthiness test lets 0 / None disable it without a comparison.
        if self.cyclic_percent and 0 < self.cyclic_percent < 1.0:
            shift = int(channels * self.cyclic_percent)
            x = torch.roll(x, shifts=shift, dims=1)
        return x
class ISF_Module(nn.Module):
    """Lightweight residual block: shuffle, depthwise conv, group scaling.

    The forward pass computes ``x + scale * ReLU(BN(DWConv(shuffle(x))))``.

    Args:
        channels: Total number of channels of the input ``x``.
        groups: Number of logical groups; must be >= 1 and divide ``channels``.
        kernel_size: Side length of the depthwise kernel.  Must be a positive
            odd integer: with ``padding = kernel_size // 2`` only odd kernels
            keep the spatial size, which the residual addition needs.
        cyclic_percent: Forwarded unchanged to ``GroupChannelShuffle``.

    Raises:
        AssertionError: If ``groups`` is < 1 or does not divide ``channels``.
        ValueError: If ``kernel_size`` is not a positive odd integer.
    """

    def __init__(self, channels: int, groups: int = 4, kernel_size: int = 3, cyclic_percent: float = 0.0):
        super().__init__()
        # Raised explicitly instead of via `assert` so the check survives
        # `python -O`; AssertionError is kept so the exception type is unchanged.
        if groups < 1 or channels % groups != 0:
            raise AssertionError(
                f"channels {channels} must be divisible by groups {groups} (groups >= 1)"
            )
        # An even kernel with padding = k // 2 grows each spatial dim by one,
        # so the residual `x + y` could not line up.  Fail at build time.
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(
                f"kernel_size must be a positive odd integer, got {kernel_size}"
            )
        self.groups = groups
        self.channels = channels
        self.shuffle = GroupChannelShuffle(groups=groups, cyclic_percent=cyclic_percent)
        # Depthwise conv (groups == channels): per-channel local spatial enhancement.
        self.dw = nn.Conv2d(
            channels,
            channels,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            groups=channels,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(channels)
        self.act = nn.ReLU(inplace=True)
        # One learnable scalar per group, initialised to 1 (tiny param overhead).
        self.group_scale = nn.Parameter(torch.ones(groups), requires_grad=True)
        # A 1x1 pointwise conv could re-calibrate channels here; it is left
        # out on purpose to keep the module ultra-light.

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the block to ``x`` of shape ``(B, C, H, W)``; the shape is kept."""
        # The unpacking also rejects inputs that are not 4-D.
        _, C, _, _ = x.shape

        # 1) deterministic channel interleave, 2) per-channel spatial refine
        y = self.act(self.bn(self.dw(self.shuffle(x))))

        # 3) group-wise scaling: entry j of group_scale multiplies channels
        #    [j * gc, (j + 1) * gc) of y.  These are positions *after* the
        #    shuffle, so in general they are not the input's original groups.
        gc = C // self.groups
        scale = self.group_scale.to(x.device)  # no-op when already on x's device
        scale = scale.repeat_interleave(gc).view(1, C, 1, 1)

        # 4) residual add to preserve the original information
        return x + y * scale
|