import torch
import torch.nn as nn


class DGG_Module(nn.Module):
    """Gate channel groups of a 4-D feature map with learned per-group scalars.

    The channel axis is split into ``groups`` consecutive chunks. Each chunk
    is averaged over its channels and both spatial axes to a single
    descriptor, the descriptors are mixed by a linear layer, squashed with a
    sigmoid, and the resulting gate in (0, 1) rescales every element of the
    corresponding chunk.

    Args:
        channels: Nominal channel count of the input. It is stored for
            introspection only; ``forward`` works from the actual input shape.
        groups: Number of channel groups; must be a positive integer.

    Raises:
        ValueError: If ``groups`` is less than 1.
    """

    def __init__(self, channels: int, groups: int) -> None:
        super().__init__()
        if groups < 1:
            raise ValueError(f"groups must be a positive integer, got {groups}")
        self.channels = channels
        self.groups = groups
        # Mixes the per-group descriptors so each gate can depend on all groups.
        self.fc = nn.Linear(groups, groups)

    def extra_repr(self) -> str:
        """Return the constructor arguments for ``repr(module)``."""
        return f"channels={self.channels}, groups={self.groups}"

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the group gates to ``x``.

        Args:
            x: Tensor of shape ``(B, C, H, W)`` with ``C`` divisible by
                ``self.groups``.

        Returns:
            Tensor of shape ``(B, C, H, W)``: ``x`` with each channel group
            multiplied by its gate.

        Raises:
            ValueError: If ``x`` is not 4-D or ``C`` is not divisible by
                ``self.groups``.
        """
        batch, num_channels, height, width = x.shape
        if num_channels % self.groups != 0:
            # Fail with an explicit message instead of an opaque view-shape error.
            raise ValueError(
                f"input has {num_channels} channels, which is not divisible "
                f"by groups={self.groups}"
            )
        per_group = num_channels // self.groups

        # Splitting the channel axis is always expressible as a view; build it
        # once and reuse it for both pooling and gating.
        grouped = x.view(batch, self.groups, per_group, height, width)
        pooled = grouped.mean(dim=(2, 3, 4))  # (B, groups)
        gates = torch.sigmoid(self.fc(pooled))  # (B, groups)

        # Broadcast each gate over its group's channels and spatial positions.
        gated = grouped * gates[:, :, None, None, None]  # (B, groups, gc, H, W)
        return gated.reshape(batch, num_channels, height, width)