Spaces:
Sleeping
Sleeping
Add nets modules
Browse files- nets/Common.py +311 -0
- nets/__init__.py +1 -0
- nets/backbone.py +105 -0
- nets/model.py +121 -0
- nets/yolo_training.py +348 -0
nets/Common.py
ADDED
|
@@ -0,0 +1,311 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from thop import profile
|
| 4 |
+
|
| 5 |
+
class SiLU(nn.Module):
|
| 6 |
+
@staticmethod
|
| 7 |
+
def forward(x):
|
| 8 |
+
return x * torch.sigmoid(x)
|
| 9 |
+
|
| 10 |
+
def autopad(k, p=None):
|
| 11 |
+
if p is None:
|
| 12 |
+
p = k // 2 if isinstance(k, int) else [x // 2 for x in k]
|
| 13 |
+
return p
|
| 14 |
+
|
| 15 |
+
class Conv(nn.Module):
|
| 16 |
+
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=nn.LeakyReLU(0.1, inplace=True)): # ch_in, ch_out, kernel, stride, padding, groups
|
| 17 |
+
super(Conv, self).__init__()
|
| 18 |
+
self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
|
| 19 |
+
self.bn = nn.BatchNorm2d(c2, eps=0.001, momentum=0.03)
|
| 20 |
+
self.act = nn.LeakyReLU(0.1, inplace=True) if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def forward(self, x):
|
| 24 |
+
return self.act(self.bn(self.conv(x)))
|
| 25 |
+
|
| 26 |
+
def fuseforward(self, x):
|
| 27 |
+
return self.act(self.conv(x))
|
| 28 |
+
|
| 29 |
+
class BasicConv(nn.Module):
|
| 30 |
+
def __init__(
|
| 31 |
+
self,
|
| 32 |
+
in_planes,
|
| 33 |
+
out_planes,
|
| 34 |
+
kernel_size,
|
| 35 |
+
stride=1,
|
| 36 |
+
padding=0,
|
| 37 |
+
dilation=1,
|
| 38 |
+
groups=1,
|
| 39 |
+
relu=True,
|
| 40 |
+
bn=True,
|
| 41 |
+
bias=False,
|
| 42 |
+
):
|
| 43 |
+
super(BasicConv, self).__init__()
|
| 44 |
+
self.out_channels = out_planes
|
| 45 |
+
self.conv = nn.Conv2d(
|
| 46 |
+
in_planes,
|
| 47 |
+
out_planes,
|
| 48 |
+
kernel_size=kernel_size,
|
| 49 |
+
stride=stride,
|
| 50 |
+
padding=padding,
|
| 51 |
+
dilation=dilation,
|
| 52 |
+
groups=groups,
|
| 53 |
+
bias=bias,
|
| 54 |
+
)
|
| 55 |
+
self.bn = (
|
| 56 |
+
nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True)
|
| 57 |
+
if bn
|
| 58 |
+
else None
|
| 59 |
+
)
|
| 60 |
+
self.relu = nn.ReLU() if relu else None
|
| 61 |
+
|
| 62 |
+
def forward(self, x):
|
| 63 |
+
x = self.conv(x)
|
| 64 |
+
if self.bn is not None:
|
| 65 |
+
x = self.bn(x)
|
| 66 |
+
if self.relu is not None:
|
| 67 |
+
x = self.relu(x)
|
| 68 |
+
return x
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class ChannelPool(nn.Module):
|
| 72 |
+
def forward(self, x):
|
| 73 |
+
return torch.cat(
|
| 74 |
+
(torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
class SpatialGate(nn.Module):
|
| 79 |
+
def __init__(self):
|
| 80 |
+
super(SpatialGate, self).__init__()
|
| 81 |
+
kernel_size = 7
|
| 82 |
+
self.compress = ChannelPool()
|
| 83 |
+
self.spatial = BasicConv(
|
| 84 |
+
2, 1, kernel_size, stride=1, padding=(kernel_size - 1) // 2, relu=False
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
def forward(self, x):
|
| 88 |
+
x_compress = self.compress(x)
|
| 89 |
+
x_out = self.spatial(x_compress)
|
| 90 |
+
scale = torch.sigmoid_(x_out)
|
| 91 |
+
return x * scale
|
| 92 |
+
|
| 93 |
+
def autopad(k, p=None):
|
| 94 |
+
if p is None:
|
| 95 |
+
p = k // 2 if isinstance(k, int) else [x // 2 for x in k]
|
| 96 |
+
return p
|
| 97 |
+
|
| 98 |
+
class Conv(nn.Module):
|
| 99 |
+
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=nn.LeakyReLU(0.1, inplace=True)): # ch_in, ch_out, kernel, stride, padding, groups
|
| 100 |
+
super(Conv, self).__init__()
|
| 101 |
+
self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
|
| 102 |
+
self.bn = nn.BatchNorm2d(c2, eps=0.001, momentum=0.03)
|
| 103 |
+
self.act = nn.LeakyReLU(0.1, inplace=True) if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def forward(self, x):
|
| 107 |
+
return self.act(self.bn(self.conv(x)))
|
| 108 |
+
#lighting dehaze network
|
| 109 |
+
class LMDNet(nn.Module):
|
| 110 |
+
|
| 111 |
+
def __init__(self):
|
| 112 |
+
super(LMDNet, self).__init__()
|
| 113 |
+
# mainNet Architecture
|
| 114 |
+
self.AAM = nn.Sequential(
|
| 115 |
+
nn.Conv2d(64, 3, 1, 1),
|
| 116 |
+
nn.LeakyReLU(inplace=True),
|
| 117 |
+
nn.Dropout(0.5)
|
| 118 |
+
)
|
| 119 |
+
self.AAM_1 = nn.Sequential(
|
| 120 |
+
nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
|
| 121 |
+
nn.Conv2d(128, 32, 1, 1),
|
| 122 |
+
nn.LeakyReLU(inplace=True),
|
| 123 |
+
nn.Dropout(0.5)
|
| 124 |
+
)
|
| 125 |
+
self.AAM_2 = nn.Sequential(
|
| 126 |
+
nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True),
|
| 127 |
+
nn.Conv2d(256, 32, 1, 1),
|
| 128 |
+
nn.LeakyReLU(inplace=True),
|
| 129 |
+
nn.Dropout(0.5)
|
| 130 |
+
)
|
| 131 |
+
self.TA = TripletAttention(64)
|
| 132 |
+
|
| 133 |
+
self.conv = Conv(64, 3, 3, 1)
|
| 134 |
+
|
| 135 |
+
self.up4 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
|
| 136 |
+
self.relu = nn.LeakyReLU(0.1, inplace=True)
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def forward(self, f1, f2, f3):
|
| 140 |
+
t = self.AAM(f1)
|
| 141 |
+
f2 = self.AAM_1(f2)
|
| 142 |
+
f3 = self.AAM_2(f3)
|
| 143 |
+
x1 = f1
|
| 144 |
+
x2 = torch.cat([f2, f3], dim=1)
|
| 145 |
+
|
| 146 |
+
x = x1 + x2
|
| 147 |
+
x = self.TA(x)
|
| 148 |
+
x = self.conv(x)
|
| 149 |
+
|
| 150 |
+
dehaze = ((x * t) - x + 1)
|
| 151 |
+
|
| 152 |
+
out = self.up4(dehaze)
|
| 153 |
+
out = self.relu(out)
|
| 154 |
+
|
| 155 |
+
return out
|
| 156 |
+
|
| 157 |
+
class TripletAttention(nn.Module):
|
| 158 |
+
def __init__(
|
| 159 |
+
self,
|
| 160 |
+
in_channels,
|
| 161 |
+
reduction_ratio=16,
|
| 162 |
+
pool_types=["avg", "max"],
|
| 163 |
+
no_spatial=False,
|
| 164 |
+
):
|
| 165 |
+
super(TripletAttention, self).__init__()
|
| 166 |
+
self.ChannelGateH = SpatialGate()
|
| 167 |
+
self.ChannelGateW = SpatialGate()
|
| 168 |
+
self.no_spatial = no_spatial
|
| 169 |
+
if not no_spatial:
|
| 170 |
+
self.SpatialGate = SpatialGate()
|
| 171 |
+
|
| 172 |
+
def forward(self, x):
|
| 173 |
+
x_perm1 = x.permute(0, 2, 1, 3).contiguous()
|
| 174 |
+
x_out1 = self.ChannelGateH(x_perm1)
|
| 175 |
+
x_out11 = x_out1.permute(0, 2, 1, 3).contiguous()
|
| 176 |
+
x_perm2 = x.permute(0, 3, 2, 1).contiguous()
|
| 177 |
+
x_out2 = self.ChannelGateW(x_perm2)
|
| 178 |
+
x_out21 = x_out2.permute(0, 3, 2, 1).contiguous()
|
| 179 |
+
if not self.no_spatial:
|
| 180 |
+
x_out = self.SpatialGate(x)
|
| 181 |
+
x_out = (1 / 3) * (x_out + x_out11 + x_out21)
|
| 182 |
+
else:
|
| 183 |
+
x_out = (1 / 2) * (x_out11 + x_out21)
|
| 184 |
+
return x_out
|
| 185 |
+
|
| 186 |
+
class SiLU(nn.Module):
|
| 187 |
+
@staticmethod
|
| 188 |
+
def forward(x):
|
| 189 |
+
return x * torch.sigmoid(x)
|
| 190 |
+
|
| 191 |
+
def autopad(k, p=None):
|
| 192 |
+
if p is None:
|
| 193 |
+
p = k // 2 if isinstance(k, int) else [x // 2 for x in k]
|
| 194 |
+
return p
|
| 195 |
+
|
| 196 |
+
class Conv(nn.Module):
|
| 197 |
+
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=nn.LeakyReLU(0.1, inplace=True)):
|
| 198 |
+
super(Conv, self).__init__()
|
| 199 |
+
self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
|
| 200 |
+
self.bn = nn.BatchNorm2d(c2, eps=0.001, momentum=0.03)
|
| 201 |
+
self.act = nn.LeakyReLU(0.1, inplace=True) if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
def forward(self, x):
|
| 205 |
+
return self.act(self.bn(self.conv(x)))
|
| 206 |
+
|
| 207 |
+
def fuseforward(self, x):
|
| 208 |
+
return self.act(self.conv(x))
|
| 209 |
+
|
| 210 |
+
class GIE(torch.nn.Module):
|
| 211 |
+
def __init__(self, epsilon=1e-8):
|
| 212 |
+
super(GIE, self).__init__()
|
| 213 |
+
self.epsilon = epsilon
|
| 214 |
+
|
| 215 |
+
def forward(self, x):
|
| 216 |
+
# Step 1: Pixel Mean Squared
|
| 217 |
+
x_mean = torch.mean(x, dim=(2, 3), keepdim=True)
|
| 218 |
+
epsilon = (x - x_mean) ** 2
|
| 219 |
+
# Step 2: Global Extraction
|
| 220 |
+
epsilon_mean = torch.mean(epsilon, dim=(2, 3), keepdim=False)
|
| 221 |
+
epsilon_mean += self.epsilon
|
| 222 |
+
# Step 3: Gamma Calculation (Global Extraction)
|
| 223 |
+
gamma_t_c = epsilon / epsilon_mean.unsqueeze(-1).unsqueeze(-1)
|
| 224 |
+
sigmoid_gamma = torch.sigmoid(gamma_t_c)
|
| 225 |
+
output = x * sigmoid_gamma
|
| 226 |
+
return output
|
| 227 |
+
|
| 228 |
+
# Multi-branch Pooling Information Fusion Module
|
| 229 |
+
class MPIF(nn.Module):
|
| 230 |
+
def __init__(self, c1, c2, c3, s=2, n=4, e=1, ids=[0]):
|
| 231 |
+
super(MPIF, self).__init__()
|
| 232 |
+
c_ = int(c2 * e)
|
| 233 |
+
|
| 234 |
+
self.ids = ids
|
| 235 |
+
if s == 1:
|
| 236 |
+
self.m1 = nn.MaxPool2d(kernel_size=3, stride=s, padding=1)
|
| 237 |
+
self.m2 = nn.AvgPool2d(kernel_size=3, stride=s, padding=1)
|
| 238 |
+
else:
|
| 239 |
+
self.m1 = nn.MaxPool2d(kernel_size=2, stride=s)
|
| 240 |
+
self.m2 = nn.AvgPool2d(kernel_size=2, stride=s)
|
| 241 |
+
|
| 242 |
+
self.cv1 = Conv(c1, c_, 1, 1)
|
| 243 |
+
self.cv2 = Conv(c1, c_, 1, 1)
|
| 244 |
+
self.cv3 = nn.ModuleList(
|
| 245 |
+
[Conv(c_ if i ==0 else c2, c2, 3, 1) for i in range(n)]
|
| 246 |
+
)
|
| 247 |
+
self.cv4 = Conv(c_ * 2 + c2 * (len(ids) - 2), c3, 1, 1)
|
| 248 |
+
|
| 249 |
+
self.GIE = GIE(c1)
|
| 250 |
+
|
| 251 |
+
def forward(self, x):
|
| 252 |
+
x1 = self.m1(x)
|
| 253 |
+
x2 = self.m2(x)
|
| 254 |
+
x = x1 + x2
|
| 255 |
+
x_1 = self.cv1(x)
|
| 256 |
+
x_1 = self.GIE(x_1)
|
| 257 |
+
x_2 = self.cv2(x)
|
| 258 |
+
|
| 259 |
+
x_all = [x_1, x_2]
|
| 260 |
+
|
| 261 |
+
for i in range(len(self.cv3)):
|
| 262 |
+
x_2 = self.cv3[i](x_2)
|
| 263 |
+
x_all.append(x_2)
|
| 264 |
+
|
| 265 |
+
out = self.cv4(torch.cat([x_all[id] for id in self.ids], 1))
|
| 266 |
+
|
| 267 |
+
return out
|
| 268 |
+
|
| 269 |
+
class DilatedConvNet(nn.Module):
|
| 270 |
+
def __init__(self, in_channels, out_channels, dilation, padding, kernel_size):
|
| 271 |
+
super(DilatedConvNet, self).__init__()
|
| 272 |
+
self.dilated_conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=1, padding=padding, dilation=dilation)
|
| 273 |
+
self.relu = nn.ReLU(inplace=False)
|
| 274 |
+
|
| 275 |
+
def forward(self, x):
|
| 276 |
+
|
| 277 |
+
x = self.dilated_conv(x)
|
| 278 |
+
x = self.relu(x)
|
| 279 |
+
|
| 280 |
+
return x
|
| 281 |
+
|
| 282 |
+
class SPPELAN(nn.Module):
|
| 283 |
+
def __init__(self, c1, c2, c3=16):
|
| 284 |
+
super().__init__()
|
| 285 |
+
self.c = c3
|
| 286 |
+
self.cv1 = Conv(c1, c3, 1, 1)
|
| 287 |
+
self.cv2 = DilatedConvNet(c3, c3, kernel_size=3, padding=1, dilation=1)
|
| 288 |
+
self.cv3 = DilatedConvNet(c3, c3, kernel_size=3, padding=2, dilation=2)
|
| 289 |
+
self.cv4 = DilatedConvNet(c3, c3, kernel_size=3, padding=3, dilation=3)
|
| 290 |
+
self.cv5 = Conv(4*c3, c2, 1, 1)
|
| 291 |
+
|
| 292 |
+
def forward(self, x):
|
| 293 |
+
y = [self.cv1(x)]
|
| 294 |
+
y.extend(m(y[-1]) for m in [self.cv2, self.cv3, self.cv4])
|
| 295 |
+
return self.cv5(torch.cat(y, 1))
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
def print_model_flops_and_params(model, inputs):
|
| 300 |
+
flops, params = profile(model, inputs=inputs)
|
| 301 |
+
print(f"FLOPs: {flops / 1e9:.2f} GFLOPs")
|
| 302 |
+
print(f"Parameters: {params / 1e6:.2f} M")
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
if __name__ == "__main__":
|
| 306 |
+
feat1 = torch.randn(1, 64, 160, 160)
|
| 307 |
+
feat2 = torch.randn(1, 128, 80, 80)
|
| 308 |
+
feat3 = torch.randn(1, 256, 40, 40)
|
| 309 |
+
encoder = LMDNet()
|
| 310 |
+
print_model_flops_and_params(encoder, (feat1, feat2, feat3))
|
| 311 |
+
|
nets/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
#
|
nets/backbone.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from nets.Common import GIE, LMDNet
|
| 4 |
+
|
| 5 |
+
def autopad(k, p=None):
|
| 6 |
+
if p is None:
|
| 7 |
+
p = k // 2 if isinstance(k, int) else [x // 2 for x in k]
|
| 8 |
+
return p
|
| 9 |
+
|
| 10 |
+
class Conv(nn.Module):
|
| 11 |
+
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=nn.LeakyReLU(0.1, inplace=True)): # ch_in, ch_out, kernel, stride, padding, groups
|
| 12 |
+
super(Conv, self).__init__()
|
| 13 |
+
self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
|
| 14 |
+
self.bn = nn.BatchNorm2d(c2, eps=0.001, momentum=0.03)
|
| 15 |
+
self.act = nn.LeakyReLU(0.1, inplace=True) if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
|
| 16 |
+
|
| 17 |
+
def forward(self, x):
|
| 18 |
+
return self.act(self.bn(self.conv(x)))
|
| 19 |
+
|
| 20 |
+
def fuseforward(self, x):
|
| 21 |
+
return self.act(self.conv(x))
|
| 22 |
+
|
| 23 |
+
# Multi-branch Pooling Information Fusion Module (Multi_Concat_Block + MP)#
|
| 24 |
+
# ------------------------------------------------------------------------- #
|
| 25 |
+
class Multi_Concat_Block(nn.Module):
|
| 26 |
+
def __init__(self, c1, c2, c3, n=4, e=1, ids=[0]):
|
| 27 |
+
super(Multi_Concat_Block, self).__init__()
|
| 28 |
+
c_ = int(c2 * e)
|
| 29 |
+
self.ids = ids
|
| 30 |
+
self.cv1 = Conv(c1, c_, 1, 1)
|
| 31 |
+
self.cv2 = Conv(c1, c_, 1, 1)
|
| 32 |
+
self.cv3 = nn.ModuleList(
|
| 33 |
+
[Conv(c_ if i ==0 else c2, c2, 3, 1) for i in range(n)]
|
| 34 |
+
)
|
| 35 |
+
self.cv4 = Conv(c_ * 2 + c2 * (len(ids) - 2), c3, 1, 1)
|
| 36 |
+
self.GIE = GIE(c1)
|
| 37 |
+
|
| 38 |
+
def forward(self, x):
|
| 39 |
+
x_1 = self.cv1(x)
|
| 40 |
+
x_1 = self.GIE(x_1)
|
| 41 |
+
x_2 = self.cv2(x)
|
| 42 |
+
x_all = [x_1, x_2]
|
| 43 |
+
|
| 44 |
+
for i in range(len(self.cv3)):
|
| 45 |
+
x_2 = self.cv3[i](x_2)
|
| 46 |
+
x_all.append(x_2)
|
| 47 |
+
|
| 48 |
+
out = self.cv4(torch.cat([x_all[id] for id in self.ids], 1))
|
| 49 |
+
return out
|
| 50 |
+
|
| 51 |
+
class MP(nn.Module):
|
| 52 |
+
def __init__(self, k=2):
|
| 53 |
+
super(MP, self).__init__()
|
| 54 |
+
self.m1 = nn.MaxPool2d(kernel_size=k, stride=k)
|
| 55 |
+
self.m2 = nn.AvgPool2d(kernel_size=k, stride=k)
|
| 56 |
+
|
| 57 |
+
def forward(self, x):
|
| 58 |
+
x1 = self.m1(x)
|
| 59 |
+
x2 = self.m2(x)
|
| 60 |
+
return x1 + x2
|
| 61 |
+
# ------------------------------------------------------------------------- #
|
| 62 |
+
|
| 63 |
+
class Backbone(nn.Module):
|
| 64 |
+
def __init__(self, transition_channels, block_channels, n):
|
| 65 |
+
super().__init__()
|
| 66 |
+
ids = [-1, -2, -3, -4]
|
| 67 |
+
|
| 68 |
+
self.stem = Conv(3, transition_channels * 2, 3, 2)
|
| 69 |
+
self.dehze = LMDNet()
|
| 70 |
+
self.dark2 = nn.Sequential(
|
| 71 |
+
Conv(transition_channels * 2, transition_channels * 4, 3, 2),
|
| 72 |
+
Multi_Concat_Block(transition_channels * 4, block_channels * 2, transition_channels * 4, n=n, ids=ids),
|
| 73 |
+
)
|
| 74 |
+
self.dark3 = nn.Sequential(
|
| 75 |
+
MP(),
|
| 76 |
+
Multi_Concat_Block(transition_channels * 4, block_channels * 4, transition_channels * 8, n=n, ids=ids),
|
| 77 |
+
)
|
| 78 |
+
self.dark4 = nn.Sequential(
|
| 79 |
+
MP(),
|
| 80 |
+
Multi_Concat_Block(transition_channels * 8, block_channels * 8, transition_channels * 16, n=n, ids=ids),
|
| 81 |
+
)
|
| 82 |
+
self.dark5 = nn.Sequential(
|
| 83 |
+
MP(),
|
| 84 |
+
Multi_Concat_Block(transition_channels * 16, block_channels * 16, transition_channels * 32, n=n, ids=ids),
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
def forward(self, x):
|
| 88 |
+
if self.training:
|
| 89 |
+
x, clear_x = x.split((8, 8), dim=0)
|
| 90 |
+
x = self.stem(x)
|
| 91 |
+
x = self.dark2(x)
|
| 92 |
+
f1 = x
|
| 93 |
+
x = self.dark3(x)
|
| 94 |
+
feat1 = x
|
| 95 |
+
f2 = x
|
| 96 |
+
x = self.dark4(x)
|
| 97 |
+
feat2 = x
|
| 98 |
+
f3 = x
|
| 99 |
+
x = self.dark5(x)
|
| 100 |
+
feat3 = x
|
| 101 |
+
dehazing = self.dehze(f1, f2, f3)
|
| 102 |
+
|
| 103 |
+
if self.training:
|
| 104 |
+
return feat1, feat2, feat3, dehazing
|
| 105 |
+
return feat1, feat2, feat3
|
nets/model.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from nets.Common import Conv, SPPELAN
|
| 4 |
+
from nets.backbone import Backbone, Multi_Concat_Block
|
| 5 |
+
|
| 6 |
+
def fuse_conv_and_bn(conv, bn):
|
| 7 |
+
fusedconv = nn.Conv2d(conv.in_channels,
|
| 8 |
+
conv.out_channels,
|
| 9 |
+
kernel_size=conv.kernel_size,
|
| 10 |
+
stride=conv.stride,
|
| 11 |
+
padding=conv.padding,
|
| 12 |
+
groups=conv.groups,
|
| 13 |
+
bias=True).requires_grad_(False).to(conv.weight.device)
|
| 14 |
+
|
| 15 |
+
w_conv = conv.weight.clone().view(conv.out_channels, -1)
|
| 16 |
+
w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
|
| 17 |
+
fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape).detach())
|
| 18 |
+
|
| 19 |
+
b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
|
| 20 |
+
b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
|
| 21 |
+
fusedconv.bias.copy_((torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn).detach())
|
| 22 |
+
return fusedconv
|
| 23 |
+
|
| 24 |
+
class MP(nn.Module):
|
| 25 |
+
def __init__(self, k=2):
|
| 26 |
+
super(MP, self).__init__()
|
| 27 |
+
self.m1 = nn.MaxPool2d(kernel_size=k, stride=k)
|
| 28 |
+
self.m2 = nn.AvgPool2d(kernel_size=k, stride=k)
|
| 29 |
+
self.up = nn.Upsample(scale_factor=2)
|
| 30 |
+
|
| 31 |
+
def forward(self, x):
|
| 32 |
+
x1 = self.m1(x)
|
| 33 |
+
x2 = self.m2(x)
|
| 34 |
+
return self.up(x1 + x2)
|
| 35 |
+
|
| 36 |
+
class YoloBody(nn.Module):
|
| 37 |
+
def __init__(self, anchors_mask, num_classes):
|
| 38 |
+
super(YoloBody, self).__init__()
|
| 39 |
+
transition_channels = 16
|
| 40 |
+
block_channels = 16
|
| 41 |
+
panet_channels = 16
|
| 42 |
+
e = 1
|
| 43 |
+
n = 2
|
| 44 |
+
ids = [-1, -2, -3, -4]
|
| 45 |
+
|
| 46 |
+
self.backbone = Backbone(transition_channels, block_channels, n)
|
| 47 |
+
|
| 48 |
+
self.upsample = nn.Upsample(scale_factor=2, mode="nearest")
|
| 49 |
+
|
| 50 |
+
self.sppelan = SPPELAN(transition_channels * 32, transition_channels * 16)
|
| 51 |
+
self.conv_for_P5 = Conv(transition_channels * 16, transition_channels * 8)
|
| 52 |
+
self.conv_for_feat2 = Conv(transition_channels * 16, transition_channels * 8)
|
| 53 |
+
self.conv3_for_upsample1 = Multi_Concat_Block(transition_channels * 16, panet_channels * 4, transition_channels * 8, e=e, n=n, ids=ids)
|
| 54 |
+
|
| 55 |
+
self.conv_for_P4 = Conv(transition_channels * 8, transition_channels * 4)
|
| 56 |
+
self.conv_for_feat1 = Conv(transition_channels * 8, transition_channels * 4)
|
| 57 |
+
self.conv3_for_upsample2 = Multi_Concat_Block(transition_channels * 8, panet_channels * 2, transition_channels * 4, e=e, n=n, ids=ids)
|
| 58 |
+
|
| 59 |
+
self.down_sample1 = Conv(transition_channels * 4, transition_channels * 8, k=3, s=2)
|
| 60 |
+
self.conv3_for_downsample1 = Multi_Concat_Block(transition_channels * 16, panet_channels * 4, transition_channels * 8, e=e, n=n, ids=ids)
|
| 61 |
+
|
| 62 |
+
self.down_sample2 = Conv(transition_channels * 8, transition_channels * 16, k=3, s=2)
|
| 63 |
+
self.conv3_for_downsample2 = Multi_Concat_Block(transition_channels * 32, panet_channels * 8, transition_channels * 16, e=e, n=n, ids=ids)
|
| 64 |
+
|
| 65 |
+
self.pf = MP()
|
| 66 |
+
|
| 67 |
+
self.rep_conv_1 = Conv(transition_channels * 4, transition_channels * 8, 3, 1)
|
| 68 |
+
self.rep_conv_2 = Conv(transition_channels * 8, transition_channels * 16, 3, 1)
|
| 69 |
+
self.rep_conv_3 = Conv(transition_channels * 16, transition_channels * 32, 3, 1)
|
| 70 |
+
|
| 71 |
+
self.yolo_head_P3 = nn.Conv2d(transition_channels * 8, len(anchors_mask[2]) * (5 + num_classes), 1)
|
| 72 |
+
self.yolo_head_P4 = nn.Conv2d(transition_channels * 16, len(anchors_mask[1]) * (5 + num_classes), 1)
|
| 73 |
+
self.yolo_head_P5 = nn.Conv2d(transition_channels * 32, len(anchors_mask[0]) * (5 + num_classes), 1)
|
| 74 |
+
|
| 75 |
+
def fuse(self):
|
| 76 |
+
print('Fusing layers... ')
|
| 77 |
+
for m in self.modules():
|
| 78 |
+
if type(m) is Conv and hasattr(m, 'bn'):
|
| 79 |
+
m.conv = fuse_conv_and_bn(m.conv, m.bn)
|
| 80 |
+
delattr(m, 'bn')
|
| 81 |
+
m.forward = m.fuseforward
|
| 82 |
+
return self
|
| 83 |
+
|
| 84 |
+
def forward(self, x):
|
| 85 |
+
if self.training:
|
| 86 |
+
feat1, feat2, feat3, dehazing = self.backbone.forward(x)
|
| 87 |
+
else:
|
| 88 |
+
feat1, feat2, feat3 = self.backbone.forward(x)
|
| 89 |
+
|
| 90 |
+
P5 = self.sppelan(feat3)
|
| 91 |
+
P5_conv = self.conv_for_P5(P5)
|
| 92 |
+
P5_upsample = self.upsample(P5_conv)
|
| 93 |
+
P4 = torch.cat([self.conv_for_feat2(feat2), P5_upsample], 1)
|
| 94 |
+
P4 = self.conv3_for_upsample1(P4)
|
| 95 |
+
|
| 96 |
+
P4_conv = self.conv_for_P4(P4)
|
| 97 |
+
P4_upsample = self.upsample(P4_conv)
|
| 98 |
+
P3 = torch.cat([self.conv_for_feat1(feat1), P4_upsample], 1)
|
| 99 |
+
P3 = self.conv3_for_upsample2(P3)
|
| 100 |
+
|
| 101 |
+
P3_downsample = self.down_sample1(P3)
|
| 102 |
+
P4 = torch.cat([P3_downsample, P4], 1)
|
| 103 |
+
P4 = self.conv3_for_downsample1(P4)
|
| 104 |
+
P4 = self.pf(P4)
|
| 105 |
+
|
| 106 |
+
P4_downsample = self.down_sample2(P4)
|
| 107 |
+
P5 = torch.cat([P4_downsample, P5], 1)
|
| 108 |
+
P5 = self.conv3_for_downsample2(P5)
|
| 109 |
+
|
| 110 |
+
P3 = self.rep_conv_1(P3)
|
| 111 |
+
P4 = self.rep_conv_2(P4)
|
| 112 |
+
P5 = self.rep_conv_3(P5)
|
| 113 |
+
|
| 114 |
+
out2 = self.yolo_head_P3(P3)
|
| 115 |
+
out1 = self.yolo_head_P4(P4)
|
| 116 |
+
out0 = self.yolo_head_P5(P5)
|
| 117 |
+
|
| 118 |
+
if self.training:
|
| 119 |
+
return [out0, out1, out2, dehazing]
|
| 120 |
+
else:
|
| 121 |
+
return [out0, out1, out2]
|
nets/yolo_training.py
ADDED
|
@@ -0,0 +1,348 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from copy import deepcopy
|
| 3 |
+
from functools import partial
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
def smooth_BCE(eps=0.1):
|
| 9 |
+
return 1.0 - 0.5 * eps, 0.5 * eps
|
| 10 |
+
class YOLOLoss(nn.Module):
|
| 11 |
+
def __init__(self, anchors, num_classes, input_shape, anchors_mask = [[6,7,8], [3,4,5], [0,1,2]], label_smoothing = 0):
|
| 12 |
+
super(YOLOLoss, self).__init__()
|
| 13 |
+
self.anchors = [anchors[mask] for mask in anchors_mask]
|
| 14 |
+
self.num_classes = num_classes
|
| 15 |
+
self.input_shape = input_shape
|
| 16 |
+
self.anchors_mask = anchors_mask
|
| 17 |
+
self.balance = [0.4, 1.0, 4]
|
| 18 |
+
self.stride = [32, 16, 8]
|
| 19 |
+
self.box_ratio = 0.05
|
| 20 |
+
self.obj_ratio = 1 * (input_shape[0] * input_shape[1]) / (640 ** 2)
|
| 21 |
+
self.cls_ratio = 0.5 * (num_classes / 80)
|
| 22 |
+
self.threshold = 4
|
| 23 |
+
self.cp, self.cn = smooth_BCE(eps=label_smoothing)
|
| 24 |
+
self.BCEcls, self.BCEobj, self.gr = nn.BCEWithLogitsLoss(), nn.BCEWithLogitsLoss(), 1
|
| 25 |
+
def bbox_iou(self, box1, box2, x1y1x2y2=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7):
|
| 26 |
+
box2 = box2.T
|
| 27 |
+
if x1y1x2y2:
|
| 28 |
+
b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
|
| 29 |
+
b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]
|
| 30 |
+
else:
|
| 31 |
+
b1_x1, b1_x2 = box1[0] - box1[2] / 2, box1[0] + box1[2] / 2
|
| 32 |
+
b1_y1, b1_y2 = box1[1] - box1[3] / 2, box1[1] + box1[3] / 2
|
| 33 |
+
b2_x1, b2_x2 = box2[0] - box2[2] / 2, box2[0] + box2[2] / 2
|
| 34 |
+
b2_y1, b2_y2 = box2[1] - box2[3] / 2, box2[1] + box2[3] / 2
|
| 35 |
+
inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * \
|
| 36 |
+
(torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)).clamp(0)
|
| 37 |
+
w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
|
| 38 |
+
w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
|
| 39 |
+
union = w1 * h1 + w2 * h2 - inter + eps
|
| 40 |
+
iou = inter / union
|
| 41 |
+
if GIoU or DIoU or CIoU:
|
| 42 |
+
cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1)
|
| 43 |
+
ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1)
|
| 44 |
+
if CIoU or DIoU:
|
| 45 |
+
c2 = cw ** 2 + ch ** 2 + eps
|
| 46 |
+
rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 +
|
| 47 |
+
(b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4
|
| 48 |
+
if DIoU:
|
| 49 |
+
return iou - rho2 / c2
|
| 50 |
+
elif CIoU:
|
| 51 |
+
v = (4 / math.pi ** 2) * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
|
| 52 |
+
with torch.no_grad():
|
| 53 |
+
alpha = v / (v - iou + (1 + eps))
|
| 54 |
+
return iou - (rho2 / c2 + v * alpha)
|
| 55 |
+
else:
|
| 56 |
+
c_area = cw * ch + eps
|
| 57 |
+
return iou - (c_area - union) / c_area
|
| 58 |
+
else:
|
| 59 |
+
return iou
|
| 60 |
+
def __call__(self, predictions, targets, imgs):
|
| 61 |
+
for i in range(len(predictions)):
|
| 62 |
+
bs, _, h, w = predictions[i].size()
|
| 63 |
+
predictions[i] = predictions[i].view(bs, len(self.anchors_mask[i]), -1, h, w).permute(0, 1, 3, 4, 2).contiguous()
|
| 64 |
+
device = targets.device
|
| 65 |
+
cls_loss, box_loss, obj_loss = torch.zeros(1, device = device), torch.zeros(1, device = device), torch.zeros(1, device = device)
|
| 66 |
+
bs, as_, gjs, gis, targets, anchors = self.build_targets(predictions, targets, imgs)
|
| 67 |
+
feature_map_sizes = [torch.tensor(prediction.shape, device=device)[[3, 2, 3, 2]].type_as(prediction) for prediction in predictions]
|
| 68 |
+
for i, prediction in enumerate(predictions):
|
| 69 |
+
b, a, gj, gi = bs[i], as_[i], gjs[i], gis[i]
|
| 70 |
+
tobj = torch.zeros_like(prediction[..., 0], device=device)
|
| 71 |
+
n = b.shape[0]
|
| 72 |
+
if n:
|
| 73 |
+
prediction_pos = prediction[b, a, gj, gi]
|
| 74 |
+
grid = torch.stack([gi, gj], dim=1)
|
| 75 |
+
xy = prediction_pos[:, :2].sigmoid() * 2. - 0.5
|
| 76 |
+
wh = (prediction_pos[:, 2:4].sigmoid() * 2) ** 2 * anchors[i]
|
| 77 |
+
box = torch.cat((xy, wh), 1)
|
| 78 |
+
selected_tbox = targets[i][:, 2:6] * feature_map_sizes[i]
|
| 79 |
+
selected_tbox[:, :2] -= grid.type_as(prediction)
|
| 80 |
+
iou = self.bbox_iou(box.T, selected_tbox, x1y1x2y2=False, CIoU=True)
|
| 81 |
+
box_loss += (1.0 - iou).mean()
|
| 82 |
+
tobj[b, a, gj, gi] = (1.0 - self.gr) + self.gr * iou.detach().clamp(0).type(tobj.dtype)
|
| 83 |
+
selected_tcls = targets[i][:, 1].long()
|
| 84 |
+
t = torch.full_like(prediction_pos[:, 5:], self.cn, device=device)
|
| 85 |
+
t[range(n), selected_tcls] = self.cp
|
| 86 |
+
cls_loss += self.BCEcls(prediction_pos[:, 5:], t)
|
| 87 |
+
obj_loss += self.BCEobj(prediction[..., 4], tobj) * self.balance[i]
|
| 88 |
+
box_loss *= self.box_ratio
|
| 89 |
+
obj_loss *= self.obj_ratio
|
| 90 |
+
cls_loss *= self.cls_ratio
|
| 91 |
+
bs = tobj.shape[0]
|
| 92 |
+
loss = box_loss + obj_loss + cls_loss
|
| 93 |
+
return loss
|
| 94 |
+
def xywh2xyxy(self, x):
|
| 95 |
+
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
|
| 96 |
+
y[:, 0] = x[:, 0] - x[:, 2] / 2
|
| 97 |
+
y[:, 1] = x[:, 1] - x[:, 3] / 2
|
| 98 |
+
y[:, 2] = x[:, 0] + x[:, 2] / 2
|
| 99 |
+
y[:, 3] = x[:, 1] + x[:, 3] / 2
|
| 100 |
+
return y
|
| 101 |
+
def box_iou(self, box1, box2):
|
| 102 |
+
"""
|
| 103 |
+
Return intersection-over-union (Jaccard index) of boxes.
|
| 104 |
+
Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
|
| 105 |
+
Arguments:
|
| 106 |
+
box1 (Tensor[N, 4])
|
| 107 |
+
box2 (Tensor[M, 4])
|
| 108 |
+
Returns:
|
| 109 |
+
iou (Tensor[N, M]): the NxM matrix containing the pairwise
|
| 110 |
+
IoU values for every element in boxes1 and boxes2
|
| 111 |
+
"""
|
| 112 |
+
def box_area(box):
|
| 113 |
+
return (box[2] - box[0]) * (box[3] - box[1])
|
| 114 |
+
area1 = box_area(box1.T)
|
| 115 |
+
area2 = box_area(box2.T)
|
| 116 |
+
inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) - torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2)
|
| 117 |
+
return inter / (area1[:, None] + area2 - inter)
|
| 118 |
+
def build_targets(self, predictions, targets, imgs):
|
| 119 |
+
indices, anch = self.find_3_positive(predictions, targets)
|
| 120 |
+
matching_bs = [[] for _ in predictions]
|
| 121 |
+
matching_as = [[] for _ in predictions]
|
| 122 |
+
matching_gjs = [[] for _ in predictions]
|
| 123 |
+
matching_gis = [[] for _ in predictions]
|
| 124 |
+
matching_targets = [[] for _ in predictions]
|
| 125 |
+
matching_anchs = [[] for _ in predictions]
|
| 126 |
+
num_layer = len(predictions)
|
| 127 |
+
for batch_idx in range(predictions[0].shape[0]):
|
| 128 |
+
b_idx = targets[:, 0]==batch_idx
|
| 129 |
+
this_target = targets[b_idx]
|
| 130 |
+
if this_target.shape[0] == 0:
|
| 131 |
+
continue
|
| 132 |
+
txywh = this_target[:, 2:6] * imgs[batch_idx].shape[1]
|
| 133 |
+
txyxy = self.xywh2xyxy(txywh)
|
| 134 |
+
pxyxys = []
|
| 135 |
+
p_cls = []
|
| 136 |
+
p_obj = []
|
| 137 |
+
from_which_layer = []
|
| 138 |
+
all_b = []
|
| 139 |
+
all_a = []
|
| 140 |
+
all_gj = []
|
| 141 |
+
all_gi = []
|
| 142 |
+
all_anch = []
|
| 143 |
+
for i, prediction in enumerate(predictions):
|
| 144 |
+
b, a, gj, gi = indices[i]
|
| 145 |
+
idx = (b == batch_idx)
|
| 146 |
+
b, a, gj, gi = b[idx], a[idx], gj[idx], gi[idx]
|
| 147 |
+
all_b.append(b)
|
| 148 |
+
all_a.append(a)
|
| 149 |
+
all_gj.append(gj)
|
| 150 |
+
all_gi.append(gi)
|
| 151 |
+
all_anch.append(anch[i][idx])
|
| 152 |
+
from_which_layer.append(torch.ones(size=(len(b),)) * i)
|
| 153 |
+
fg_pred = prediction[b, a, gj, gi]
|
| 154 |
+
p_obj.append(fg_pred[:, 4:5])
|
| 155 |
+
p_cls.append(fg_pred[:, 5:])
|
| 156 |
+
grid = torch.stack([gi, gj], dim=1).type_as(fg_pred)
|
| 157 |
+
pxy = (fg_pred[:, :2].sigmoid() * 2. - 0.5 + grid) * self.stride[i]
|
| 158 |
+
pwh = (fg_pred[:, 2:4].sigmoid() * 2) ** 2 * anch[i][idx] * self.stride[i]
|
| 159 |
+
pxywh = torch.cat([pxy, pwh], dim=-1)
|
| 160 |
+
pxyxy = self.xywh2xyxy(pxywh)
|
| 161 |
+
pxyxys.append(pxyxy)
|
| 162 |
+
pxyxys = torch.cat(pxyxys, dim=0)
|
| 163 |
+
if pxyxys.shape[0] == 0:
|
| 164 |
+
continue
|
| 165 |
+
p_obj = torch.cat(p_obj, dim=0)
|
| 166 |
+
p_cls = torch.cat(p_cls, dim=0)
|
| 167 |
+
from_which_layer = torch.cat(from_which_layer, dim=0)
|
| 168 |
+
all_b = torch.cat(all_b, dim=0)
|
| 169 |
+
all_a = torch.cat(all_a, dim=0)
|
| 170 |
+
all_gj = torch.cat(all_gj, dim=0)
|
| 171 |
+
all_gi = torch.cat(all_gi, dim=0)
|
| 172 |
+
all_anch = torch.cat(all_anch, dim=0)
|
| 173 |
+
pair_wise_iou = self.box_iou(txyxy, pxyxys)
|
| 174 |
+
pair_wise_iou_loss = -torch.log(pair_wise_iou + 1e-8)
|
| 175 |
+
top_k, _ = torch.topk(pair_wise_iou, min(20, pair_wise_iou.shape[1]), dim=1)
|
| 176 |
+
dynamic_ks = torch.clamp(top_k.sum(1).int(), min=1)
|
| 177 |
+
gt_cls_per_image = F.one_hot(this_target[:, 1].to(torch.int64), self.num_classes).float().unsqueeze(1).repeat(1, pxyxys.shape[0], 1)
|
| 178 |
+
num_gt = this_target.shape[0]
|
| 179 |
+
cls_preds_ = p_cls.float().unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_() * p_obj.unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_()
|
| 180 |
+
y = cls_preds_.sqrt_()
|
| 181 |
+
pair_wise_cls_loss = F.binary_cross_entropy_with_logits(torch.log(y / (1 - y)), gt_cls_per_image, reduction="none").sum(-1)
|
| 182 |
+
del cls_preds_
|
| 183 |
+
cost = (
|
| 184 |
+
pair_wise_cls_loss
|
| 185 |
+
+ 3.0 * pair_wise_iou_loss
|
| 186 |
+
)
|
| 187 |
+
matching_matrix = torch.zeros_like(cost)
|
| 188 |
+
for gt_idx in range(num_gt):
|
| 189 |
+
_, pos_idx = torch.topk(cost[gt_idx], k=dynamic_ks[gt_idx].item(), largest=False)
|
| 190 |
+
matching_matrix[gt_idx][pos_idx] = 1.0
|
| 191 |
+
del top_k, dynamic_ks
|
| 192 |
+
anchor_matching_gt = matching_matrix.sum(0)
|
| 193 |
+
if (anchor_matching_gt > 1).sum() > 0:
|
| 194 |
+
_, cost_argmin = torch.min(cost[:, anchor_matching_gt > 1], dim=0)
|
| 195 |
+
matching_matrix[:, anchor_matching_gt > 1] *= 0.0
|
| 196 |
+
matching_matrix[cost_argmin, anchor_matching_gt > 1] = 1.0
|
| 197 |
+
fg_mask_inboxes = matching_matrix.sum(0) > 0.0
|
| 198 |
+
matched_gt_inds = matching_matrix[:, fg_mask_inboxes].argmax(0)
|
| 199 |
+
from_which_layer = from_which_layer.to(fg_mask_inboxes.device)[fg_mask_inboxes]
|
| 200 |
+
all_b = all_b[fg_mask_inboxes]
|
| 201 |
+
all_a = all_a[fg_mask_inboxes]
|
| 202 |
+
all_gj = all_gj[fg_mask_inboxes]
|
| 203 |
+
all_gi = all_gi[fg_mask_inboxes]
|
| 204 |
+
all_anch = all_anch[fg_mask_inboxes]
|
| 205 |
+
this_target = this_target[matched_gt_inds]
|
| 206 |
+
for i in range(num_layer):
|
| 207 |
+
layer_idx = from_which_layer == i
|
| 208 |
+
matching_bs[i].append(all_b[layer_idx])
|
| 209 |
+
matching_as[i].append(all_a[layer_idx])
|
| 210 |
+
matching_gjs[i].append(all_gj[layer_idx])
|
| 211 |
+
matching_gis[i].append(all_gi[layer_idx])
|
| 212 |
+
matching_targets[i].append(this_target[layer_idx])
|
| 213 |
+
matching_anchs[i].append(all_anch[layer_idx])
|
| 214 |
+
for i in range(num_layer):
|
| 215 |
+
matching_bs[i] = torch.cat(matching_bs[i], dim=0) if len(matching_bs[i]) != 0 else torch.Tensor(matching_bs[i])
|
| 216 |
+
matching_as[i] = torch.cat(matching_as[i], dim=0) if len(matching_as[i]) != 0 else torch.Tensor(matching_as[i])
|
| 217 |
+
matching_gjs[i] = torch.cat(matching_gjs[i], dim=0) if len(matching_gjs[i]) != 0 else torch.Tensor(matching_gjs[i])
|
| 218 |
+
matching_gis[i] = torch.cat(matching_gis[i], dim=0) if len(matching_gis[i]) != 0 else torch.Tensor(matching_gis[i])
|
| 219 |
+
matching_targets[i] = torch.cat(matching_targets[i], dim=0) if len(matching_targets[i]) != 0 else torch.Tensor(matching_targets[i])
|
| 220 |
+
matching_anchs[i] = torch.cat(matching_anchs[i], dim=0) if len(matching_anchs[i]) != 0 else torch.Tensor(matching_anchs[i])
|
| 221 |
+
return matching_bs, matching_as, matching_gjs, matching_gis, matching_targets, matching_anchs
|
| 222 |
+
def find_3_positive(self, predictions, targets):
|
| 223 |
+
num_anchor, num_gt = len(self.anchors_mask[0]), targets.shape[0]
|
| 224 |
+
indices, anchors = [], []
|
| 225 |
+
gain = torch.ones(7, device=targets.device)
|
| 226 |
+
ai = torch.arange(num_anchor, device=targets.device).float().view(num_anchor, 1).repeat(1, num_gt)
|
| 227 |
+
targets = torch.cat((targets.repeat(num_anchor, 1, 1), ai[:, :, None]), 2)
|
| 228 |
+
g = 0.5
|
| 229 |
+
off = torch.tensor([
|
| 230 |
+
[0, 0],
|
| 231 |
+
[1, 0], [0, 1], [-1, 0], [0, -1],
|
| 232 |
+
], device=targets.device).float() * g
|
| 233 |
+
for i in range(len(predictions)):
|
| 234 |
+
anchors_i = torch.from_numpy(self.anchors[i] / self.stride[i]).type_as(predictions[i])
|
| 235 |
+
anchors_i, shape = torch.from_numpy(self.anchors[i] / self.stride[i]).type_as(predictions[i]), predictions[i].shape
|
| 236 |
+
gain[2:6] = torch.tensor(predictions[i].shape)[[3, 2, 3, 2]]
|
| 237 |
+
t = targets * gain
|
| 238 |
+
if num_gt:
|
| 239 |
+
r = t[:, :, 4:6] / anchors_i[:, None]
|
| 240 |
+
j = torch.max(r, 1. / r).max(2)[0] < self.threshold
|
| 241 |
+
t = t[j]
|
| 242 |
+
gxy = t[:, 2:4]
|
| 243 |
+
gxi = gain[[2, 3]] - gxy
|
| 244 |
+
j, k = ((gxy % 1. < g) & (gxy > 1.)).T
|
| 245 |
+
l, m = ((gxi % 1. < g) & (gxi > 1.)).T
|
| 246 |
+
j = torch.stack((torch.ones_like(j), j, k, l, m))
|
| 247 |
+
t = t.repeat((5, 1, 1))[j]
|
| 248 |
+
offsets = (torch.zeros_like(gxy)[None] + off[:, None])[j]
|
| 249 |
+
else:
|
| 250 |
+
t = targets[0]
|
| 251 |
+
offsets = 0
|
| 252 |
+
b, c = t[:, :2].long().T
|
| 253 |
+
gxy = t[:, 2:4]
|
| 254 |
+
gwh = t[:, 4:6]
|
| 255 |
+
gij = (gxy - offsets).long()
|
| 256 |
+
gi, gj = gij.T
|
| 257 |
+
a = t[:, 6].long()
|
| 258 |
+
indices.append((b, a, gj.clamp_(0, shape[2] - 1), gi.clamp_(0, shape[3] - 1)))
|
| 259 |
+
anchors.append(anchors_i[a])
|
| 260 |
+
return indices, anchors
|
| 261 |
+
def is_parallel(model):
|
| 262 |
+
return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
|
| 263 |
+
def de_parallel(model):
|
| 264 |
+
return model.module if is_parallel(model) else model
|
| 265 |
+
def copy_attr(a, b, include=(), exclude=()):
|
| 266 |
+
for k, v in b.__dict__.items():
|
| 267 |
+
if (len(include) and k not in include) or k.startswith('_') or k in exclude:
|
| 268 |
+
continue
|
| 269 |
+
else:
|
| 270 |
+
setattr(a, k, v)
|
| 271 |
+
class ModelEMA:
|
| 272 |
+
""" Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models
|
| 273 |
+
Keeps a moving average of everything in the model state_dict (parameters and buffers)
|
| 274 |
+
For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
|
| 275 |
+
"""
|
| 276 |
+
def __init__(self, model, decay=0.9999, tau=2000, updates=0):
|
| 277 |
+
self.ema = deepcopy(de_parallel(model)).eval()
|
| 278 |
+
self.updates = updates
|
| 279 |
+
self.decay = lambda x: decay * (1 - math.exp(-x / tau))
|
| 280 |
+
for p in self.ema.parameters():
|
| 281 |
+
p.requires_grad_(False)
|
| 282 |
+
def update(self, model):
|
| 283 |
+
with torch.no_grad():
|
| 284 |
+
self.updates += 1
|
| 285 |
+
d = self.decay(self.updates)
|
| 286 |
+
msd = de_parallel(model).state_dict()
|
| 287 |
+
for k, v in self.ema.state_dict().items():
|
| 288 |
+
if v.dtype.is_floating_point:
|
| 289 |
+
v *= d
|
| 290 |
+
v += (1 - d) * msd[k].detach()
|
| 291 |
+
def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
|
| 292 |
+
copy_attr(self.ema, model, include, exclude)
|
| 293 |
+
def weights_init(net, init_type='normal', init_gain = 0.02):
|
| 294 |
+
def init_func(m):
|
| 295 |
+
classname = m.__class__.__name__
|
| 296 |
+
if hasattr(m, 'weight') and classname.find('Conv') != -1:
|
| 297 |
+
if init_type == 'normal':
|
| 298 |
+
torch.nn.init.normal_(m.weight.data, 0.0, init_gain)
|
| 299 |
+
elif init_type == 'xavier':
|
| 300 |
+
torch.nn.init.xavier_normal_(m.weight.data, gain=init_gain)
|
| 301 |
+
elif init_type == 'kaiming':
|
| 302 |
+
torch.nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
|
| 303 |
+
elif init_type == 'orthogonal':
|
| 304 |
+
torch.nn.init.orthogonal_(m.weight.data, gain=init_gain)
|
| 305 |
+
else:
|
| 306 |
+
raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
|
| 307 |
+
elif classname.find('BatchNorm2d') != -1:
|
| 308 |
+
torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
|
| 309 |
+
torch.nn.init.constant_(m.bias.data, 0.0)
|
| 310 |
+
print('initialize network with %s type' % init_type)
|
| 311 |
+
net.apply(init_func)
|
| 312 |
+
def get_lr_scheduler(lr_decay_type, lr, min_lr, total_iters, warmup_iters_ratio = 0.05, warmup_lr_ratio = 0.1, no_aug_iter_ratio = 0.05, step_num = 10):
|
| 313 |
+
def yolox_warm_cos_lr(lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter, iters):
|
| 314 |
+
if iters <= warmup_total_iters:
|
| 315 |
+
lr = (lr - warmup_lr_start) * pow(iters / float(warmup_total_iters), 2
|
| 316 |
+
) + warmup_lr_start
|
| 317 |
+
elif iters >= total_iters - no_aug_iter:
|
| 318 |
+
lr = min_lr
|
| 319 |
+
else:
|
| 320 |
+
lr = min_lr + 0.5 * (lr - min_lr) * (
|
| 321 |
+
1.0
|
| 322 |
+
+ math.cos(
|
| 323 |
+
math.pi
|
| 324 |
+
* (iters - warmup_total_iters)
|
| 325 |
+
/ (total_iters - warmup_total_iters - no_aug_iter)
|
| 326 |
+
)
|
| 327 |
+
)
|
| 328 |
+
return lr
|
| 329 |
+
def step_lr(lr, decay_rate, step_size, iters):
|
| 330 |
+
if step_size < 1:
|
| 331 |
+
raise ValueError("step_size must above 1.")
|
| 332 |
+
n = iters // step_size
|
| 333 |
+
out_lr = lr * decay_rate ** n
|
| 334 |
+
return out_lr
|
| 335 |
+
if lr_decay_type == "cos":
|
| 336 |
+
warmup_total_iters = min(max(warmup_iters_ratio * total_iters, 1), 3)
|
| 337 |
+
warmup_lr_start = max(warmup_lr_ratio * lr, 1e-6)
|
| 338 |
+
no_aug_iter = min(max(no_aug_iter_ratio * total_iters, 1), 15)
|
| 339 |
+
func = partial(yolox_warm_cos_lr ,lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter)
|
| 340 |
+
else:
|
| 341 |
+
decay_rate = (min_lr / lr) ** (1 / (step_num - 1))
|
| 342 |
+
step_size = total_iters / step_num
|
| 343 |
+
func = partial(step_lr, lr, decay_rate, step_size)
|
| 344 |
+
return func
|
| 345 |
+
def set_optimizer_lr(optimizer, lr_scheduler_func, epoch):
|
| 346 |
+
lr = lr_scheduler_func(epoch)
|
| 347 |
+
for param_group in optimizer.param_groups:
|
| 348 |
+
param_group['lr'] = lr
|