File size: 3,050 Bytes
5acc7ae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import torch
import torch.nn as nn

class GroupChannelShuffle(nn.Module):
    """Interleave channels across groups, then optionally roll them.

    The channel axis (dim 1) is treated as ``groups`` contiguous blocks of
    ``C // groups`` channels each. The output takes the first channel of every
    block, then the second channel of every block, and so on. For ``groups=2``
    and ``C=6`` the channel order ``0 1 2 3 4 5`` becomes ``0 3 1 4 2 5``.

    Args:
        groups: number of source groups to interleave (e.g. 4 for c1..c4).
            Must be >= 1 and must divide the channel count seen in
            ``forward``.
        cyclic_percent: fraction of the channel count by which the shuffled
            channels are rolled along dim 1. Only values strictly between
            0 and 1 have an effect; any other value (including 0 and None)
            disables the roll.
    """

    def __init__(self, groups: int = 4, cyclic_percent: float = 0.0):
        super().__init__()
        assert groups >= 1
        self.groups = groups
        self.cyclic_percent = cyclic_percent

    def forward(self, x):
        """Shuffle the channels of ``x``.

        Args:
            x: tensor of shape ``(B, C, ...)``. The usual case is
                ``(B, C, H, W)``, but any number of trailing dimensions
                (including none) is accepted because only dim 1 is reordered.

        Returns:
            A tensor with the same shape as ``x`` and its channels permuted.

        Raises:
            ValueError: if ``x`` has fewer than two dimensions.
            AssertionError: if ``C`` is not divisible by ``groups``.
        """
        if x.dim() < 2:
            raise ValueError(
                f"expected input of shape (B, C, ...), got {tuple(x.shape)}"
            )
        B, C = x.shape[0], x.shape[1]
        trailing = x.shape[2:]
        g = self.groups
        assert C % g == 0, f"channels {C} not divisible by groups {g}"
        gc = C // g
        # Split the channel axis: (B, groups, group_channels, *trailing).
        x = x.view(B, g, gc, *trailing)
        # Swap the two channel sub-axes so that merging them back interleaves
        # the groups: (B, group_channels, groups, *trailing).
        x = x.transpose(1, 2).contiguous()
        x = x.view(B, C, *trailing)
        # Optional deterministic rotation of the shuffled channels. The leading
        # truthiness test lets a None/0 setting skip the range comparison.
        if self.cyclic_percent and 0 < self.cyclic_percent < 1.0:
            shift = int(C * self.cyclic_percent)
            x = torch.roll(x, shifts=shift, dims=1)
        return x

class ISF_Module(nn.Module):
    """Lightweight block: channel shuffle -> depthwise conv -> block scaling -> residual.

    ``forward`` computes ``x + scale * ReLU(BN(DWConv(shuffle(x))))``.

    Args:
        channels: total number of channels of the input ``x``.
        groups: number of logical groups; must divide ``channels``.
        kernel_size: side length of the square depthwise kernel. Must be a
            positive odd integer, because the convolution pads by
            ``kernel_size // 2`` and only an odd kernel then keeps the spatial
            size that the residual add requires.
        cyclic_percent: forwarded to ``GroupChannelShuffle``.

    Raises:
        AssertionError: if ``channels`` is not divisible by ``groups``.
        ValueError: if ``kernel_size`` is not a positive odd integer.
    """

    def __init__(self, channels: int, groups: int = 4, kernel_size: int = 3, cyclic_percent: float = 0.0):
        super().__init__()
        assert channels % groups == 0
        # Fail at construction time: with an even kernel the conv output is one
        # pixel larger per spatial dim than the input, so ``x + y`` in forward
        # would either raise or silently broadcast.
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(
                f"kernel_size must be a positive odd integer, got {kernel_size}"
            )
        self.groups = groups
        self.channels = channels
        self.shuffle = GroupChannelShuffle(groups=groups, cyclic_percent=cyclic_percent)

        # Depthwise conv (groups == channels): per-channel local spatial enhancement.
        self.dw = nn.Conv2d(
            channels,
            channels,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            groups=channels,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(channels)
        self.act = nn.ReLU(inplace=True)

        # One learnable scalar per group, initialised to 1 (tiny parameter overhead).
        self.group_scale = nn.Parameter(torch.ones(groups), requires_grad=True)

        # NOTE: a 1x1 pointwise conv to re-calibrate channels was deliberately
        # left out to keep the module ultra-light.

    def forward(self, x):
        """Apply the block to ``x`` of shape ``(B, C, H, W)``; returns the same shape."""
        # The 4-way unpack enforces a 4-D input; only the channel count is used.
        _, C, _, _ = x.shape

        # 1) deterministic channel interleave
        y = self.shuffle(x)               # (B, C, H, W)

        # 2) per-channel spatial refinement
        y = self.dw(y)
        y = self.bn(y)
        y = self.act(y)

        # 3) block-wise scaling: each of the ``groups`` scalars is repeated over
        #    C // groups consecutive channels of the *shuffled* tensor.
        gc = C // self.groups
        scale = self.group_scale.to(x.device)  # align the scale with the input's device
        scale = scale.repeat_interleave(gc).view(1, C, 1, 1)
        y = y * scale

        # 4) residual add to preserve the original information
        return x + y