File size: 578 Bytes
5acc7ae | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 | import torch
import torch.nn as nn
class DGG_Module(nn.Module):
    """Gate each channel group of a 4-D feature map with a learned, input-dependent scalar.

    The input's channel axis is split into ``groups`` contiguous groups of
    ``C // groups`` channels. Each group is summarised by its mean over
    (channels-in-group, H, W). The resulting ``(B, groups)`` descriptor goes
    through a ``groups -> groups`` linear layer and a sigmoid. Every element
    of a group is then multiplied by that group's gate in ``(0, 1)``.

    Args:
        channels: Expected number of input channels. It is stored for
            reference and ``repr`` only. ``forward`` accepts any channel
            count that is divisible by ``groups``.
        groups: Number of channel groups. This is also the in/out size of
            the gating linear layer.
    """

    def __init__(self, channels, groups):
        super().__init__()
        self.channels = channels
        self.groups = groups
        # Mixes the per-group descriptors so each gate can depend on all groups.
        self.fc = nn.Linear(groups, groups)

    def extra_repr(self):
        return f"channels={self.channels}, groups={self.groups}"

    def forward(self, x):
        """Apply per-group gating.

        Args:
            x: Tensor unpacked as ``(B, C, H, W)``. ``C`` must be divisible
                by ``self.groups``.

        Returns:
            Tensor of shape ``(B, C, H, W)``: ``x`` with every channel group
            scaled by its sigmoid gate.

        Raises:
            ValueError: if ``C`` is not divisible by ``self.groups``.
        """
        B, C, H, W = x.shape
        if C % self.groups != 0:
            # Fail with a clear message instead of an opaque view/reshape size error.
            raise ValueError(
                f"input channels ({C}) must be divisible by groups ({self.groups})"
            )
        gc = C // self.groups

        # Split channels into (groups, gc). This is computed once and reused
        # for both the descriptor and the gating multiply.
        xg = x.reshape(B, self.groups, gc, H, W)

        # One scalar per (sample, group): mean over in-group channels and space.
        descriptor = xg.mean(dim=(2, 3, 4))  # (B, groups)
        gates = torch.sigmoid(self.fc(descriptor))  # (B, groups)

        # Broadcast each group's gate over its channels and spatial positions.
        out = xg * gates[:, :, None, None, None]
        return out.reshape(B, C, H, W)
|