PolarisFTL commited on
Commit
c79402e
·
verified ·
1 Parent(s): efb567f

Add nets modules

Browse files
Files changed (5) hide show
  1. nets/Common.py +311 -0
  2. nets/__init__.py +1 -0
  3. nets/backbone.py +105 -0
  4. nets/model.py +121 -0
  5. nets/yolo_training.py +348 -0
nets/Common.py ADDED
@@ -0,0 +1,311 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from thop import profile
4
+
5
+ class SiLU(nn.Module):
6
+ @staticmethod
7
+ def forward(x):
8
+ return x * torch.sigmoid(x)
9
+
10
+ def autopad(k, p=None):
11
+ if p is None:
12
+ p = k // 2 if isinstance(k, int) else [x // 2 for x in k]
13
+ return p
14
+
15
+ class Conv(nn.Module):
16
+ def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=nn.LeakyReLU(0.1, inplace=True)): # ch_in, ch_out, kernel, stride, padding, groups
17
+ super(Conv, self).__init__()
18
+ self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
19
+ self.bn = nn.BatchNorm2d(c2, eps=0.001, momentum=0.03)
20
+ self.act = nn.LeakyReLU(0.1, inplace=True) if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
21
+
22
+
23
+ def forward(self, x):
24
+ return self.act(self.bn(self.conv(x)))
25
+
26
+ def fuseforward(self, x):
27
+ return self.act(self.conv(x))
28
+
29
+ class BasicConv(nn.Module):
30
+ def __init__(
31
+ self,
32
+ in_planes,
33
+ out_planes,
34
+ kernel_size,
35
+ stride=1,
36
+ padding=0,
37
+ dilation=1,
38
+ groups=1,
39
+ relu=True,
40
+ bn=True,
41
+ bias=False,
42
+ ):
43
+ super(BasicConv, self).__init__()
44
+ self.out_channels = out_planes
45
+ self.conv = nn.Conv2d(
46
+ in_planes,
47
+ out_planes,
48
+ kernel_size=kernel_size,
49
+ stride=stride,
50
+ padding=padding,
51
+ dilation=dilation,
52
+ groups=groups,
53
+ bias=bias,
54
+ )
55
+ self.bn = (
56
+ nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True)
57
+ if bn
58
+ else None
59
+ )
60
+ self.relu = nn.ReLU() if relu else None
61
+
62
+ def forward(self, x):
63
+ x = self.conv(x)
64
+ if self.bn is not None:
65
+ x = self.bn(x)
66
+ if self.relu is not None:
67
+ x = self.relu(x)
68
+ return x
69
+
70
+
71
+ class ChannelPool(nn.Module):
72
+ def forward(self, x):
73
+ return torch.cat(
74
+ (torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1
75
+ )
76
+
77
+
78
+ class SpatialGate(nn.Module):
79
+ def __init__(self):
80
+ super(SpatialGate, self).__init__()
81
+ kernel_size = 7
82
+ self.compress = ChannelPool()
83
+ self.spatial = BasicConv(
84
+ 2, 1, kernel_size, stride=1, padding=(kernel_size - 1) // 2, relu=False
85
+ )
86
+
87
+ def forward(self, x):
88
+ x_compress = self.compress(x)
89
+ x_out = self.spatial(x_compress)
90
+ scale = torch.sigmoid_(x_out)
91
+ return x * scale
92
+
93
+ def autopad(k, p=None):
94
+ if p is None:
95
+ p = k // 2 if isinstance(k, int) else [x // 2 for x in k]
96
+ return p
97
+
98
+ class Conv(nn.Module):
99
+ def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=nn.LeakyReLU(0.1, inplace=True)): # ch_in, ch_out, kernel, stride, padding, groups
100
+ super(Conv, self).__init__()
101
+ self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
102
+ self.bn = nn.BatchNorm2d(c2, eps=0.001, momentum=0.03)
103
+ self.act = nn.LeakyReLU(0.1, inplace=True) if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
104
+
105
+
106
+ def forward(self, x):
107
+ return self.act(self.bn(self.conv(x)))
108
+ #lighting dehaze network
109
+ class LMDNet(nn.Module):
110
+
111
+ def __init__(self):
112
+ super(LMDNet, self).__init__()
113
+ # mainNet Architecture
114
+ self.AAM = nn.Sequential(
115
+ nn.Conv2d(64, 3, 1, 1),
116
+ nn.LeakyReLU(inplace=True),
117
+ nn.Dropout(0.5)
118
+ )
119
+ self.AAM_1 = nn.Sequential(
120
+ nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
121
+ nn.Conv2d(128, 32, 1, 1),
122
+ nn.LeakyReLU(inplace=True),
123
+ nn.Dropout(0.5)
124
+ )
125
+ self.AAM_2 = nn.Sequential(
126
+ nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True),
127
+ nn.Conv2d(256, 32, 1, 1),
128
+ nn.LeakyReLU(inplace=True),
129
+ nn.Dropout(0.5)
130
+ )
131
+ self.TA = TripletAttention(64)
132
+
133
+ self.conv = Conv(64, 3, 3, 1)
134
+
135
+ self.up4 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
136
+ self.relu = nn.LeakyReLU(0.1, inplace=True)
137
+
138
+
139
+ def forward(self, f1, f2, f3):
140
+ t = self.AAM(f1)
141
+ f2 = self.AAM_1(f2)
142
+ f3 = self.AAM_2(f3)
143
+ x1 = f1
144
+ x2 = torch.cat([f2, f3], dim=1)
145
+
146
+ x = x1 + x2
147
+ x = self.TA(x)
148
+ x = self.conv(x)
149
+
150
+ dehaze = ((x * t) - x + 1)
151
+
152
+ out = self.up4(dehaze)
153
+ out = self.relu(out)
154
+
155
+ return out
156
+
157
+ class TripletAttention(nn.Module):
158
+ def __init__(
159
+ self,
160
+ in_channels,
161
+ reduction_ratio=16,
162
+ pool_types=["avg", "max"],
163
+ no_spatial=False,
164
+ ):
165
+ super(TripletAttention, self).__init__()
166
+ self.ChannelGateH = SpatialGate()
167
+ self.ChannelGateW = SpatialGate()
168
+ self.no_spatial = no_spatial
169
+ if not no_spatial:
170
+ self.SpatialGate = SpatialGate()
171
+
172
+ def forward(self, x):
173
+ x_perm1 = x.permute(0, 2, 1, 3).contiguous()
174
+ x_out1 = self.ChannelGateH(x_perm1)
175
+ x_out11 = x_out1.permute(0, 2, 1, 3).contiguous()
176
+ x_perm2 = x.permute(0, 3, 2, 1).contiguous()
177
+ x_out2 = self.ChannelGateW(x_perm2)
178
+ x_out21 = x_out2.permute(0, 3, 2, 1).contiguous()
179
+ if not self.no_spatial:
180
+ x_out = self.SpatialGate(x)
181
+ x_out = (1 / 3) * (x_out + x_out11 + x_out21)
182
+ else:
183
+ x_out = (1 / 2) * (x_out11 + x_out21)
184
+ return x_out
185
+
186
+ class SiLU(nn.Module):
187
+ @staticmethod
188
+ def forward(x):
189
+ return x * torch.sigmoid(x)
190
+
191
+ def autopad(k, p=None):
192
+ if p is None:
193
+ p = k // 2 if isinstance(k, int) else [x // 2 for x in k]
194
+ return p
195
+
196
+ class Conv(nn.Module):
197
+ def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=nn.LeakyReLU(0.1, inplace=True)):
198
+ super(Conv, self).__init__()
199
+ self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
200
+ self.bn = nn.BatchNorm2d(c2, eps=0.001, momentum=0.03)
201
+ self.act = nn.LeakyReLU(0.1, inplace=True) if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
202
+
203
+
204
+ def forward(self, x):
205
+ return self.act(self.bn(self.conv(x)))
206
+
207
+ def fuseforward(self, x):
208
+ return self.act(self.conv(x))
209
+
210
+ class GIE(torch.nn.Module):
211
+ def __init__(self, epsilon=1e-8):
212
+ super(GIE, self).__init__()
213
+ self.epsilon = epsilon
214
+
215
+ def forward(self, x):
216
+ # Step 1: Pixel Mean Squared
217
+ x_mean = torch.mean(x, dim=(2, 3), keepdim=True)
218
+ epsilon = (x - x_mean) ** 2
219
+ # Step 2: Global Extraction
220
+ epsilon_mean = torch.mean(epsilon, dim=(2, 3), keepdim=False)
221
+ epsilon_mean += self.epsilon
222
+ # Step 3: Gamma Calculation (Global Extraction)
223
+ gamma_t_c = epsilon / epsilon_mean.unsqueeze(-1).unsqueeze(-1)
224
+ sigmoid_gamma = torch.sigmoid(gamma_t_c)
225
+ output = x * sigmoid_gamma
226
+ return output
227
+
228
+ # Multi-branch Pooling Information Fusion Module
229
+ class MPIF(nn.Module):
230
+ def __init__(self, c1, c2, c3, s=2, n=4, e=1, ids=[0]):
231
+ super(MPIF, self).__init__()
232
+ c_ = int(c2 * e)
233
+
234
+ self.ids = ids
235
+ if s == 1:
236
+ self.m1 = nn.MaxPool2d(kernel_size=3, stride=s, padding=1)
237
+ self.m2 = nn.AvgPool2d(kernel_size=3, stride=s, padding=1)
238
+ else:
239
+ self.m1 = nn.MaxPool2d(kernel_size=2, stride=s)
240
+ self.m2 = nn.AvgPool2d(kernel_size=2, stride=s)
241
+
242
+ self.cv1 = Conv(c1, c_, 1, 1)
243
+ self.cv2 = Conv(c1, c_, 1, 1)
244
+ self.cv3 = nn.ModuleList(
245
+ [Conv(c_ if i ==0 else c2, c2, 3, 1) for i in range(n)]
246
+ )
247
+ self.cv4 = Conv(c_ * 2 + c2 * (len(ids) - 2), c3, 1, 1)
248
+
249
+ self.GIE = GIE(c1)
250
+
251
+ def forward(self, x):
252
+ x1 = self.m1(x)
253
+ x2 = self.m2(x)
254
+ x = x1 + x2
255
+ x_1 = self.cv1(x)
256
+ x_1 = self.GIE(x_1)
257
+ x_2 = self.cv2(x)
258
+
259
+ x_all = [x_1, x_2]
260
+
261
+ for i in range(len(self.cv3)):
262
+ x_2 = self.cv3[i](x_2)
263
+ x_all.append(x_2)
264
+
265
+ out = self.cv4(torch.cat([x_all[id] for id in self.ids], 1))
266
+
267
+ return out
268
+
269
+ class DilatedConvNet(nn.Module):
270
+ def __init__(self, in_channels, out_channels, dilation, padding, kernel_size):
271
+ super(DilatedConvNet, self).__init__()
272
+ self.dilated_conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=1, padding=padding, dilation=dilation)
273
+ self.relu = nn.ReLU(inplace=False)
274
+
275
+ def forward(self, x):
276
+
277
+ x = self.dilated_conv(x)
278
+ x = self.relu(x)
279
+
280
+ return x
281
+
282
+ class SPPELAN(nn.Module):
283
+ def __init__(self, c1, c2, c3=16):
284
+ super().__init__()
285
+ self.c = c3
286
+ self.cv1 = Conv(c1, c3, 1, 1)
287
+ self.cv2 = DilatedConvNet(c3, c3, kernel_size=3, padding=1, dilation=1)
288
+ self.cv3 = DilatedConvNet(c3, c3, kernel_size=3, padding=2, dilation=2)
289
+ self.cv4 = DilatedConvNet(c3, c3, kernel_size=3, padding=3, dilation=3)
290
+ self.cv5 = Conv(4*c3, c2, 1, 1)
291
+
292
+ def forward(self, x):
293
+ y = [self.cv1(x)]
294
+ y.extend(m(y[-1]) for m in [self.cv2, self.cv3, self.cv4])
295
+ return self.cv5(torch.cat(y, 1))
296
+
297
+
298
+
299
+ def print_model_flops_and_params(model, inputs):
300
+ flops, params = profile(model, inputs=inputs)
301
+ print(f"FLOPs: {flops / 1e9:.2f} GFLOPs")
302
+ print(f"Parameters: {params / 1e6:.2f} M")
303
+
304
+
305
+ if __name__ == "__main__":
306
+ feat1 = torch.randn(1, 64, 160, 160)
307
+ feat2 = torch.randn(1, 128, 80, 80)
308
+ feat3 = torch.randn(1, 256, 40, 40)
309
+ encoder = LMDNet()
310
+ print_model_flops_and_params(encoder, (feat1, feat2, feat3))
311
+
nets/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ #
nets/backbone.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from nets.Common import GIE, LMDNet
4
+
5
+ def autopad(k, p=None):
6
+ if p is None:
7
+ p = k // 2 if isinstance(k, int) else [x // 2 for x in k]
8
+ return p
9
+
10
+ class Conv(nn.Module):
11
+ def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=nn.LeakyReLU(0.1, inplace=True)): # ch_in, ch_out, kernel, stride, padding, groups
12
+ super(Conv, self).__init__()
13
+ self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
14
+ self.bn = nn.BatchNorm2d(c2, eps=0.001, momentum=0.03)
15
+ self.act = nn.LeakyReLU(0.1, inplace=True) if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
16
+
17
+ def forward(self, x):
18
+ return self.act(self.bn(self.conv(x)))
19
+
20
+ def fuseforward(self, x):
21
+ return self.act(self.conv(x))
22
+
23
+ # Multi-branch Pooling Information Fusion Module (Multi_Concat_Block + MP)#
24
+ # ------------------------------------------------------------------------- #
25
+ class Multi_Concat_Block(nn.Module):
26
+ def __init__(self, c1, c2, c3, n=4, e=1, ids=[0]):
27
+ super(Multi_Concat_Block, self).__init__()
28
+ c_ = int(c2 * e)
29
+ self.ids = ids
30
+ self.cv1 = Conv(c1, c_, 1, 1)
31
+ self.cv2 = Conv(c1, c_, 1, 1)
32
+ self.cv3 = nn.ModuleList(
33
+ [Conv(c_ if i ==0 else c2, c2, 3, 1) for i in range(n)]
34
+ )
35
+ self.cv4 = Conv(c_ * 2 + c2 * (len(ids) - 2), c3, 1, 1)
36
+ self.GIE = GIE(c1)
37
+
38
+ def forward(self, x):
39
+ x_1 = self.cv1(x)
40
+ x_1 = self.GIE(x_1)
41
+ x_2 = self.cv2(x)
42
+ x_all = [x_1, x_2]
43
+
44
+ for i in range(len(self.cv3)):
45
+ x_2 = self.cv3[i](x_2)
46
+ x_all.append(x_2)
47
+
48
+ out = self.cv4(torch.cat([x_all[id] for id in self.ids], 1))
49
+ return out
50
+
51
+ class MP(nn.Module):
52
+ def __init__(self, k=2):
53
+ super(MP, self).__init__()
54
+ self.m1 = nn.MaxPool2d(kernel_size=k, stride=k)
55
+ self.m2 = nn.AvgPool2d(kernel_size=k, stride=k)
56
+
57
+ def forward(self, x):
58
+ x1 = self.m1(x)
59
+ x2 = self.m2(x)
60
+ return x1 + x2
61
+ # ------------------------------------------------------------------------- #
62
+
63
+ class Backbone(nn.Module):
64
+ def __init__(self, transition_channels, block_channels, n):
65
+ super().__init__()
66
+ ids = [-1, -2, -3, -4]
67
+
68
+ self.stem = Conv(3, transition_channels * 2, 3, 2)
69
+ self.dehze = LMDNet()
70
+ self.dark2 = nn.Sequential(
71
+ Conv(transition_channels * 2, transition_channels * 4, 3, 2),
72
+ Multi_Concat_Block(transition_channels * 4, block_channels * 2, transition_channels * 4, n=n, ids=ids),
73
+ )
74
+ self.dark3 = nn.Sequential(
75
+ MP(),
76
+ Multi_Concat_Block(transition_channels * 4, block_channels * 4, transition_channels * 8, n=n, ids=ids),
77
+ )
78
+ self.dark4 = nn.Sequential(
79
+ MP(),
80
+ Multi_Concat_Block(transition_channels * 8, block_channels * 8, transition_channels * 16, n=n, ids=ids),
81
+ )
82
+ self.dark5 = nn.Sequential(
83
+ MP(),
84
+ Multi_Concat_Block(transition_channels * 16, block_channels * 16, transition_channels * 32, n=n, ids=ids),
85
+ )
86
+
87
+ def forward(self, x):
88
+ if self.training:
89
+ x, clear_x = x.split((8, 8), dim=0)
90
+ x = self.stem(x)
91
+ x = self.dark2(x)
92
+ f1 = x
93
+ x = self.dark3(x)
94
+ feat1 = x
95
+ f2 = x
96
+ x = self.dark4(x)
97
+ feat2 = x
98
+ f3 = x
99
+ x = self.dark5(x)
100
+ feat3 = x
101
+ dehazing = self.dehze(f1, f2, f3)
102
+
103
+ if self.training:
104
+ return feat1, feat2, feat3, dehazing
105
+ return feat1, feat2, feat3
nets/model.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from nets.Common import Conv, SPPELAN
4
+ from nets.backbone import Backbone, Multi_Concat_Block
5
+
6
+ def fuse_conv_and_bn(conv, bn):
7
+ fusedconv = nn.Conv2d(conv.in_channels,
8
+ conv.out_channels,
9
+ kernel_size=conv.kernel_size,
10
+ stride=conv.stride,
11
+ padding=conv.padding,
12
+ groups=conv.groups,
13
+ bias=True).requires_grad_(False).to(conv.weight.device)
14
+
15
+ w_conv = conv.weight.clone().view(conv.out_channels, -1)
16
+ w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
17
+ fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape).detach())
18
+
19
+ b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
20
+ b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
21
+ fusedconv.bias.copy_((torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn).detach())
22
+ return fusedconv
23
+
24
+ class MP(nn.Module):
25
+ def __init__(self, k=2):
26
+ super(MP, self).__init__()
27
+ self.m1 = nn.MaxPool2d(kernel_size=k, stride=k)
28
+ self.m2 = nn.AvgPool2d(kernel_size=k, stride=k)
29
+ self.up = nn.Upsample(scale_factor=2)
30
+
31
+ def forward(self, x):
32
+ x1 = self.m1(x)
33
+ x2 = self.m2(x)
34
+ return self.up(x1 + x2)
35
+
36
+ class YoloBody(nn.Module):
37
+ def __init__(self, anchors_mask, num_classes):
38
+ super(YoloBody, self).__init__()
39
+ transition_channels = 16
40
+ block_channels = 16
41
+ panet_channels = 16
42
+ e = 1
43
+ n = 2
44
+ ids = [-1, -2, -3, -4]
45
+
46
+ self.backbone = Backbone(transition_channels, block_channels, n)
47
+
48
+ self.upsample = nn.Upsample(scale_factor=2, mode="nearest")
49
+
50
+ self.sppelan = SPPELAN(transition_channels * 32, transition_channels * 16)
51
+ self.conv_for_P5 = Conv(transition_channels * 16, transition_channels * 8)
52
+ self.conv_for_feat2 = Conv(transition_channels * 16, transition_channels * 8)
53
+ self.conv3_for_upsample1 = Multi_Concat_Block(transition_channels * 16, panet_channels * 4, transition_channels * 8, e=e, n=n, ids=ids)
54
+
55
+ self.conv_for_P4 = Conv(transition_channels * 8, transition_channels * 4)
56
+ self.conv_for_feat1 = Conv(transition_channels * 8, transition_channels * 4)
57
+ self.conv3_for_upsample2 = Multi_Concat_Block(transition_channels * 8, panet_channels * 2, transition_channels * 4, e=e, n=n, ids=ids)
58
+
59
+ self.down_sample1 = Conv(transition_channels * 4, transition_channels * 8, k=3, s=2)
60
+ self.conv3_for_downsample1 = Multi_Concat_Block(transition_channels * 16, panet_channels * 4, transition_channels * 8, e=e, n=n, ids=ids)
61
+
62
+ self.down_sample2 = Conv(transition_channels * 8, transition_channels * 16, k=3, s=2)
63
+ self.conv3_for_downsample2 = Multi_Concat_Block(transition_channels * 32, panet_channels * 8, transition_channels * 16, e=e, n=n, ids=ids)
64
+
65
+ self.pf = MP()
66
+
67
+ self.rep_conv_1 = Conv(transition_channels * 4, transition_channels * 8, 3, 1)
68
+ self.rep_conv_2 = Conv(transition_channels * 8, transition_channels * 16, 3, 1)
69
+ self.rep_conv_3 = Conv(transition_channels * 16, transition_channels * 32, 3, 1)
70
+
71
+ self.yolo_head_P3 = nn.Conv2d(transition_channels * 8, len(anchors_mask[2]) * (5 + num_classes), 1)
72
+ self.yolo_head_P4 = nn.Conv2d(transition_channels * 16, len(anchors_mask[1]) * (5 + num_classes), 1)
73
+ self.yolo_head_P5 = nn.Conv2d(transition_channels * 32, len(anchors_mask[0]) * (5 + num_classes), 1)
74
+
75
+ def fuse(self):
76
+ print('Fusing layers... ')
77
+ for m in self.modules():
78
+ if type(m) is Conv and hasattr(m, 'bn'):
79
+ m.conv = fuse_conv_and_bn(m.conv, m.bn)
80
+ delattr(m, 'bn')
81
+ m.forward = m.fuseforward
82
+ return self
83
+
84
+ def forward(self, x):
85
+ if self.training:
86
+ feat1, feat2, feat3, dehazing = self.backbone.forward(x)
87
+ else:
88
+ feat1, feat2, feat3 = self.backbone.forward(x)
89
+
90
+ P5 = self.sppelan(feat3)
91
+ P5_conv = self.conv_for_P5(P5)
92
+ P5_upsample = self.upsample(P5_conv)
93
+ P4 = torch.cat([self.conv_for_feat2(feat2), P5_upsample], 1)
94
+ P4 = self.conv3_for_upsample1(P4)
95
+
96
+ P4_conv = self.conv_for_P4(P4)
97
+ P4_upsample = self.upsample(P4_conv)
98
+ P3 = torch.cat([self.conv_for_feat1(feat1), P4_upsample], 1)
99
+ P3 = self.conv3_for_upsample2(P3)
100
+
101
+ P3_downsample = self.down_sample1(P3)
102
+ P4 = torch.cat([P3_downsample, P4], 1)
103
+ P4 = self.conv3_for_downsample1(P4)
104
+ P4 = self.pf(P4)
105
+
106
+ P4_downsample = self.down_sample2(P4)
107
+ P5 = torch.cat([P4_downsample, P5], 1)
108
+ P5 = self.conv3_for_downsample2(P5)
109
+
110
+ P3 = self.rep_conv_1(P3)
111
+ P4 = self.rep_conv_2(P4)
112
+ P5 = self.rep_conv_3(P5)
113
+
114
+ out2 = self.yolo_head_P3(P3)
115
+ out1 = self.yolo_head_P4(P4)
116
+ out0 = self.yolo_head_P5(P5)
117
+
118
+ if self.training:
119
+ return [out0, out1, out2, dehazing]
120
+ else:
121
+ return [out0, out1, out2]
nets/yolo_training.py ADDED
@@ -0,0 +1,348 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from copy import deepcopy
3
+ from functools import partial
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ def smooth_BCE(eps=0.1):
9
+ return 1.0 - 0.5 * eps, 0.5 * eps
10
+ class YOLOLoss(nn.Module):
11
+ def __init__(self, anchors, num_classes, input_shape, anchors_mask = [[6,7,8], [3,4,5], [0,1,2]], label_smoothing = 0):
12
+ super(YOLOLoss, self).__init__()
13
+ self.anchors = [anchors[mask] for mask in anchors_mask]
14
+ self.num_classes = num_classes
15
+ self.input_shape = input_shape
16
+ self.anchors_mask = anchors_mask
17
+ self.balance = [0.4, 1.0, 4]
18
+ self.stride = [32, 16, 8]
19
+ self.box_ratio = 0.05
20
+ self.obj_ratio = 1 * (input_shape[0] * input_shape[1]) / (640 ** 2)
21
+ self.cls_ratio = 0.5 * (num_classes / 80)
22
+ self.threshold = 4
23
+ self.cp, self.cn = smooth_BCE(eps=label_smoothing)
24
+ self.BCEcls, self.BCEobj, self.gr = nn.BCEWithLogitsLoss(), nn.BCEWithLogitsLoss(), 1
25
+ def bbox_iou(self, box1, box2, x1y1x2y2=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7):
26
+ box2 = box2.T
27
+ if x1y1x2y2:
28
+ b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
29
+ b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]
30
+ else:
31
+ b1_x1, b1_x2 = box1[0] - box1[2] / 2, box1[0] + box1[2] / 2
32
+ b1_y1, b1_y2 = box1[1] - box1[3] / 2, box1[1] + box1[3] / 2
33
+ b2_x1, b2_x2 = box2[0] - box2[2] / 2, box2[0] + box2[2] / 2
34
+ b2_y1, b2_y2 = box2[1] - box2[3] / 2, box2[1] + box2[3] / 2
35
+ inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * \
36
+ (torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)).clamp(0)
37
+ w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
38
+ w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
39
+ union = w1 * h1 + w2 * h2 - inter + eps
40
+ iou = inter / union
41
+ if GIoU or DIoU or CIoU:
42
+ cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1)
43
+ ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1)
44
+ if CIoU or DIoU:
45
+ c2 = cw ** 2 + ch ** 2 + eps
46
+ rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 +
47
+ (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4
48
+ if DIoU:
49
+ return iou - rho2 / c2
50
+ elif CIoU:
51
+ v = (4 / math.pi ** 2) * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
52
+ with torch.no_grad():
53
+ alpha = v / (v - iou + (1 + eps))
54
+ return iou - (rho2 / c2 + v * alpha)
55
+ else:
56
+ c_area = cw * ch + eps
57
+ return iou - (c_area - union) / c_area
58
+ else:
59
+ return iou
60
+ def __call__(self, predictions, targets, imgs):
61
+ for i in range(len(predictions)):
62
+ bs, _, h, w = predictions[i].size()
63
+ predictions[i] = predictions[i].view(bs, len(self.anchors_mask[i]), -1, h, w).permute(0, 1, 3, 4, 2).contiguous()
64
+ device = targets.device
65
+ cls_loss, box_loss, obj_loss = torch.zeros(1, device = device), torch.zeros(1, device = device), torch.zeros(1, device = device)
66
+ bs, as_, gjs, gis, targets, anchors = self.build_targets(predictions, targets, imgs)
67
+ feature_map_sizes = [torch.tensor(prediction.shape, device=device)[[3, 2, 3, 2]].type_as(prediction) for prediction in predictions]
68
+ for i, prediction in enumerate(predictions):
69
+ b, a, gj, gi = bs[i], as_[i], gjs[i], gis[i]
70
+ tobj = torch.zeros_like(prediction[..., 0], device=device)
71
+ n = b.shape[0]
72
+ if n:
73
+ prediction_pos = prediction[b, a, gj, gi]
74
+ grid = torch.stack([gi, gj], dim=1)
75
+ xy = prediction_pos[:, :2].sigmoid() * 2. - 0.5
76
+ wh = (prediction_pos[:, 2:4].sigmoid() * 2) ** 2 * anchors[i]
77
+ box = torch.cat((xy, wh), 1)
78
+ selected_tbox = targets[i][:, 2:6] * feature_map_sizes[i]
79
+ selected_tbox[:, :2] -= grid.type_as(prediction)
80
+ iou = self.bbox_iou(box.T, selected_tbox, x1y1x2y2=False, CIoU=True)
81
+ box_loss += (1.0 - iou).mean()
82
+ tobj[b, a, gj, gi] = (1.0 - self.gr) + self.gr * iou.detach().clamp(0).type(tobj.dtype)
83
+ selected_tcls = targets[i][:, 1].long()
84
+ t = torch.full_like(prediction_pos[:, 5:], self.cn, device=device)
85
+ t[range(n), selected_tcls] = self.cp
86
+ cls_loss += self.BCEcls(prediction_pos[:, 5:], t)
87
+ obj_loss += self.BCEobj(prediction[..., 4], tobj) * self.balance[i]
88
+ box_loss *= self.box_ratio
89
+ obj_loss *= self.obj_ratio
90
+ cls_loss *= self.cls_ratio
91
+ bs = tobj.shape[0]
92
+ loss = box_loss + obj_loss + cls_loss
93
+ return loss
94
+ def xywh2xyxy(self, x):
95
+ y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
96
+ y[:, 0] = x[:, 0] - x[:, 2] / 2
97
+ y[:, 1] = x[:, 1] - x[:, 3] / 2
98
+ y[:, 2] = x[:, 0] + x[:, 2] / 2
99
+ y[:, 3] = x[:, 1] + x[:, 3] / 2
100
+ return y
101
+ def box_iou(self, box1, box2):
102
+ """
103
+ Return intersection-over-union (Jaccard index) of boxes.
104
+ Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
105
+ Arguments:
106
+ box1 (Tensor[N, 4])
107
+ box2 (Tensor[M, 4])
108
+ Returns:
109
+ iou (Tensor[N, M]): the NxM matrix containing the pairwise
110
+ IoU values for every element in boxes1 and boxes2
111
+ """
112
+ def box_area(box):
113
+ return (box[2] - box[0]) * (box[3] - box[1])
114
+ area1 = box_area(box1.T)
115
+ area2 = box_area(box2.T)
116
+ inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) - torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2)
117
+ return inter / (area1[:, None] + area2 - inter)
118
+ def build_targets(self, predictions, targets, imgs):
119
+ indices, anch = self.find_3_positive(predictions, targets)
120
+ matching_bs = [[] for _ in predictions]
121
+ matching_as = [[] for _ in predictions]
122
+ matching_gjs = [[] for _ in predictions]
123
+ matching_gis = [[] for _ in predictions]
124
+ matching_targets = [[] for _ in predictions]
125
+ matching_anchs = [[] for _ in predictions]
126
+ num_layer = len(predictions)
127
+ for batch_idx in range(predictions[0].shape[0]):
128
+ b_idx = targets[:, 0]==batch_idx
129
+ this_target = targets[b_idx]
130
+ if this_target.shape[0] == 0:
131
+ continue
132
+ txywh = this_target[:, 2:6] * imgs[batch_idx].shape[1]
133
+ txyxy = self.xywh2xyxy(txywh)
134
+ pxyxys = []
135
+ p_cls = []
136
+ p_obj = []
137
+ from_which_layer = []
138
+ all_b = []
139
+ all_a = []
140
+ all_gj = []
141
+ all_gi = []
142
+ all_anch = []
143
+ for i, prediction in enumerate(predictions):
144
+ b, a, gj, gi = indices[i]
145
+ idx = (b == batch_idx)
146
+ b, a, gj, gi = b[idx], a[idx], gj[idx], gi[idx]
147
+ all_b.append(b)
148
+ all_a.append(a)
149
+ all_gj.append(gj)
150
+ all_gi.append(gi)
151
+ all_anch.append(anch[i][idx])
152
+ from_which_layer.append(torch.ones(size=(len(b),)) * i)
153
+ fg_pred = prediction[b, a, gj, gi]
154
+ p_obj.append(fg_pred[:, 4:5])
155
+ p_cls.append(fg_pred[:, 5:])
156
+ grid = torch.stack([gi, gj], dim=1).type_as(fg_pred)
157
+ pxy = (fg_pred[:, :2].sigmoid() * 2. - 0.5 + grid) * self.stride[i]
158
+ pwh = (fg_pred[:, 2:4].sigmoid() * 2) ** 2 * anch[i][idx] * self.stride[i]
159
+ pxywh = torch.cat([pxy, pwh], dim=-1)
160
+ pxyxy = self.xywh2xyxy(pxywh)
161
+ pxyxys.append(pxyxy)
162
+ pxyxys = torch.cat(pxyxys, dim=0)
163
+ if pxyxys.shape[0] == 0:
164
+ continue
165
+ p_obj = torch.cat(p_obj, dim=0)
166
+ p_cls = torch.cat(p_cls, dim=0)
167
+ from_which_layer = torch.cat(from_which_layer, dim=0)
168
+ all_b = torch.cat(all_b, dim=0)
169
+ all_a = torch.cat(all_a, dim=0)
170
+ all_gj = torch.cat(all_gj, dim=0)
171
+ all_gi = torch.cat(all_gi, dim=0)
172
+ all_anch = torch.cat(all_anch, dim=0)
173
+ pair_wise_iou = self.box_iou(txyxy, pxyxys)
174
+ pair_wise_iou_loss = -torch.log(pair_wise_iou + 1e-8)
175
+ top_k, _ = torch.topk(pair_wise_iou, min(20, pair_wise_iou.shape[1]), dim=1)
176
+ dynamic_ks = torch.clamp(top_k.sum(1).int(), min=1)
177
+ gt_cls_per_image = F.one_hot(this_target[:, 1].to(torch.int64), self.num_classes).float().unsqueeze(1).repeat(1, pxyxys.shape[0], 1)
178
+ num_gt = this_target.shape[0]
179
+ cls_preds_ = p_cls.float().unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_() * p_obj.unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_()
180
+ y = cls_preds_.sqrt_()
181
+ pair_wise_cls_loss = F.binary_cross_entropy_with_logits(torch.log(y / (1 - y)), gt_cls_per_image, reduction="none").sum(-1)
182
+ del cls_preds_
183
+ cost = (
184
+ pair_wise_cls_loss
185
+ + 3.0 * pair_wise_iou_loss
186
+ )
187
+ matching_matrix = torch.zeros_like(cost)
188
+ for gt_idx in range(num_gt):
189
+ _, pos_idx = torch.topk(cost[gt_idx], k=dynamic_ks[gt_idx].item(), largest=False)
190
+ matching_matrix[gt_idx][pos_idx] = 1.0
191
+ del top_k, dynamic_ks
192
+ anchor_matching_gt = matching_matrix.sum(0)
193
+ if (anchor_matching_gt > 1).sum() > 0:
194
+ _, cost_argmin = torch.min(cost[:, anchor_matching_gt > 1], dim=0)
195
+ matching_matrix[:, anchor_matching_gt > 1] *= 0.0
196
+ matching_matrix[cost_argmin, anchor_matching_gt > 1] = 1.0
197
+ fg_mask_inboxes = matching_matrix.sum(0) > 0.0
198
+ matched_gt_inds = matching_matrix[:, fg_mask_inboxes].argmax(0)
199
+ from_which_layer = from_which_layer.to(fg_mask_inboxes.device)[fg_mask_inboxes]
200
+ all_b = all_b[fg_mask_inboxes]
201
+ all_a = all_a[fg_mask_inboxes]
202
+ all_gj = all_gj[fg_mask_inboxes]
203
+ all_gi = all_gi[fg_mask_inboxes]
204
+ all_anch = all_anch[fg_mask_inboxes]
205
+ this_target = this_target[matched_gt_inds]
206
+ for i in range(num_layer):
207
+ layer_idx = from_which_layer == i
208
+ matching_bs[i].append(all_b[layer_idx])
209
+ matching_as[i].append(all_a[layer_idx])
210
+ matching_gjs[i].append(all_gj[layer_idx])
211
+ matching_gis[i].append(all_gi[layer_idx])
212
+ matching_targets[i].append(this_target[layer_idx])
213
+ matching_anchs[i].append(all_anch[layer_idx])
214
+ for i in range(num_layer):
215
+ matching_bs[i] = torch.cat(matching_bs[i], dim=0) if len(matching_bs[i]) != 0 else torch.Tensor(matching_bs[i])
216
+ matching_as[i] = torch.cat(matching_as[i], dim=0) if len(matching_as[i]) != 0 else torch.Tensor(matching_as[i])
217
+ matching_gjs[i] = torch.cat(matching_gjs[i], dim=0) if len(matching_gjs[i]) != 0 else torch.Tensor(matching_gjs[i])
218
+ matching_gis[i] = torch.cat(matching_gis[i], dim=0) if len(matching_gis[i]) != 0 else torch.Tensor(matching_gis[i])
219
+ matching_targets[i] = torch.cat(matching_targets[i], dim=0) if len(matching_targets[i]) != 0 else torch.Tensor(matching_targets[i])
220
+ matching_anchs[i] = torch.cat(matching_anchs[i], dim=0) if len(matching_anchs[i]) != 0 else torch.Tensor(matching_anchs[i])
221
+ return matching_bs, matching_as, matching_gjs, matching_gis, matching_targets, matching_anchs
222
+ def find_3_positive(self, predictions, targets):
223
+ num_anchor, num_gt = len(self.anchors_mask[0]), targets.shape[0]
224
+ indices, anchors = [], []
225
+ gain = torch.ones(7, device=targets.device)
226
+ ai = torch.arange(num_anchor, device=targets.device).float().view(num_anchor, 1).repeat(1, num_gt)
227
+ targets = torch.cat((targets.repeat(num_anchor, 1, 1), ai[:, :, None]), 2)
228
+ g = 0.5
229
+ off = torch.tensor([
230
+ [0, 0],
231
+ [1, 0], [0, 1], [-1, 0], [0, -1],
232
+ ], device=targets.device).float() * g
233
+ for i in range(len(predictions)):
234
+ anchors_i = torch.from_numpy(self.anchors[i] / self.stride[i]).type_as(predictions[i])
235
+ anchors_i, shape = torch.from_numpy(self.anchors[i] / self.stride[i]).type_as(predictions[i]), predictions[i].shape
236
+ gain[2:6] = torch.tensor(predictions[i].shape)[[3, 2, 3, 2]]
237
+ t = targets * gain
238
+ if num_gt:
239
+ r = t[:, :, 4:6] / anchors_i[:, None]
240
+ j = torch.max(r, 1. / r).max(2)[0] < self.threshold
241
+ t = t[j]
242
+ gxy = t[:, 2:4]
243
+ gxi = gain[[2, 3]] - gxy
244
+ j, k = ((gxy % 1. < g) & (gxy > 1.)).T
245
+ l, m = ((gxi % 1. < g) & (gxi > 1.)).T
246
+ j = torch.stack((torch.ones_like(j), j, k, l, m))
247
+ t = t.repeat((5, 1, 1))[j]
248
+ offsets = (torch.zeros_like(gxy)[None] + off[:, None])[j]
249
+ else:
250
+ t = targets[0]
251
+ offsets = 0
252
+ b, c = t[:, :2].long().T
253
+ gxy = t[:, 2:4]
254
+ gwh = t[:, 4:6]
255
+ gij = (gxy - offsets).long()
256
+ gi, gj = gij.T
257
+ a = t[:, 6].long()
258
+ indices.append((b, a, gj.clamp_(0, shape[2] - 1), gi.clamp_(0, shape[3] - 1)))
259
+ anchors.append(anchors_i[a])
260
+ return indices, anchors
261
+ def is_parallel(model):
262
+ return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
263
+ def de_parallel(model):
264
+ return model.module if is_parallel(model) else model
265
+ def copy_attr(a, b, include=(), exclude=()):
266
+ for k, v in b.__dict__.items():
267
+ if (len(include) and k not in include) or k.startswith('_') or k in exclude:
268
+ continue
269
+ else:
270
+ setattr(a, k, v)
271
+ class ModelEMA:
272
+ """ Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models
273
+ Keeps a moving average of everything in the model state_dict (parameters and buffers)
274
+ For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
275
+ """
276
+ def __init__(self, model, decay=0.9999, tau=2000, updates=0):
277
+ self.ema = deepcopy(de_parallel(model)).eval()
278
+ self.updates = updates
279
+ self.decay = lambda x: decay * (1 - math.exp(-x / tau))
280
+ for p in self.ema.parameters():
281
+ p.requires_grad_(False)
282
+ def update(self, model):
283
+ with torch.no_grad():
284
+ self.updates += 1
285
+ d = self.decay(self.updates)
286
+ msd = de_parallel(model).state_dict()
287
+ for k, v in self.ema.state_dict().items():
288
+ if v.dtype.is_floating_point:
289
+ v *= d
290
+ v += (1 - d) * msd[k].detach()
291
+ def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
292
+ copy_attr(self.ema, model, include, exclude)
293
+ def weights_init(net, init_type='normal', init_gain = 0.02):
294
+ def init_func(m):
295
+ classname = m.__class__.__name__
296
+ if hasattr(m, 'weight') and classname.find('Conv') != -1:
297
+ if init_type == 'normal':
298
+ torch.nn.init.normal_(m.weight.data, 0.0, init_gain)
299
+ elif init_type == 'xavier':
300
+ torch.nn.init.xavier_normal_(m.weight.data, gain=init_gain)
301
+ elif init_type == 'kaiming':
302
+ torch.nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
303
+ elif init_type == 'orthogonal':
304
+ torch.nn.init.orthogonal_(m.weight.data, gain=init_gain)
305
+ else:
306
+ raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
307
+ elif classname.find('BatchNorm2d') != -1:
308
+ torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
309
+ torch.nn.init.constant_(m.bias.data, 0.0)
310
+ print('initialize network with %s type' % init_type)
311
+ net.apply(init_func)
312
+ def get_lr_scheduler(lr_decay_type, lr, min_lr, total_iters, warmup_iters_ratio = 0.05, warmup_lr_ratio = 0.1, no_aug_iter_ratio = 0.05, step_num = 10):
313
+ def yolox_warm_cos_lr(lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter, iters):
314
+ if iters <= warmup_total_iters:
315
+ lr = (lr - warmup_lr_start) * pow(iters / float(warmup_total_iters), 2
316
+ ) + warmup_lr_start
317
+ elif iters >= total_iters - no_aug_iter:
318
+ lr = min_lr
319
+ else:
320
+ lr = min_lr + 0.5 * (lr - min_lr) * (
321
+ 1.0
322
+ + math.cos(
323
+ math.pi
324
+ * (iters - warmup_total_iters)
325
+ / (total_iters - warmup_total_iters - no_aug_iter)
326
+ )
327
+ )
328
+ return lr
329
+ def step_lr(lr, decay_rate, step_size, iters):
330
+ if step_size < 1:
331
+ raise ValueError("step_size must above 1.")
332
+ n = iters // step_size
333
+ out_lr = lr * decay_rate ** n
334
+ return out_lr
335
+ if lr_decay_type == "cos":
336
+ warmup_total_iters = min(max(warmup_iters_ratio * total_iters, 1), 3)
337
+ warmup_lr_start = max(warmup_lr_ratio * lr, 1e-6)
338
+ no_aug_iter = min(max(no_aug_iter_ratio * total_iters, 1), 15)
339
+ func = partial(yolox_warm_cos_lr ,lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter)
340
+ else:
341
+ decay_rate = (min_lr / lr) ** (1 / (step_num - 1))
342
+ step_size = total_iters / step_num
343
+ func = partial(step_lr, lr, decay_rate, step_size)
344
+ return func
345
+ def set_optimizer_lr(optimizer, lr_scheduler_func, epoch):
346
+ lr = lr_scheduler_func(epoch)
347
+ for param_group in optimizer.param_groups:
348
+ param_group['lr'] = lr