PolarisFTL committed on
Commit
9d4e990
·
verified ·
1 Parent(s): c79402e

Add utils modules

Browse files
utils/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ #
utils/callbacks.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime
2
+ import os
3
+ import torch
4
+ import matplotlib
5
+ matplotlib.use('Agg')
6
+ import scipy.signal
7
+ from matplotlib import pyplot as plt
8
+ from torch.utils.tensorboard import SummaryWriter
9
+ import shutil
10
+ import numpy as np
11
+ from PIL import Image
12
+ from tqdm import tqdm
13
+ from .utils import cvtColor, preprocess_input, resize_image
14
+ from .utils_bbox import DecodeBox
15
+ from .utils_map import get_coco_map, get_map
16
class LossHistory():
    """Record the training loss of every epoch and persist it in several forms.

    Everything is written inside ``log_dir``:

    * ``epoch_loss.txt`` -- one loss value per line, appended each epoch;
    * ``epoch_loss.png`` -- the loss curve, redrawn after every epoch;
    * TensorBoard event files produced by ``SummaryWriter``.
    """

    def __init__(self, log_dir, model, input_shape):
        """Create the log directory and the TensorBoard writer.

        Args:
            log_dir: Directory receiving all log artefacts.
            model: Network whose graph is exported to TensorBoard (best effort).
            input_shape: ``(height, width)`` used to build the dummy input for
                the graph export.
        """
        self.log_dir = log_dir
        self.losses = []
        # exist_ok keeps this consistent with append_loss(), which also
        # tolerates an already existing directory (e.g. when resuming a run).
        os.makedirs(self.log_dir, exist_ok=True)
        self.writer = SummaryWriter(self.log_dir)
        try:
            dummy_input = torch.randn(2, 3, input_shape[0], input_shape[1])
            self.writer.add_graph(model, dummy_input)
        except Exception:
            # Graph export is purely informational; a model that cannot be
            # traced must not prevent training from starting.
            pass

    def append_loss(self, epoch, loss):
        """Store ``loss`` for ``epoch`` and refresh the text, TensorBoard and PNG logs."""
        os.makedirs(self.log_dir, exist_ok=True)
        self.losses.append(loss)

        with open(os.path.join(self.log_dir, "epoch_loss.txt"), 'a') as f:
            f.write(str(loss))
            f.write("\n")

        self.writer.add_scalar('loss', loss, epoch)
        self.loss_plot()

    def loss_plot(self):
        """Draw the raw (and, when possible, smoothed) loss curve to ``epoch_loss.png``."""
        iters = range(len(self.losses))
        plt.figure()
        plt.plot(iters, self.losses, 'red', linewidth=2, label='train loss')

        # Savitzky-Golay window: short one for short histories, longer later.
        window = 5 if len(self.losses) < 25 else 15
        # The filter needs at least `window` samples; skip smoothing until then
        # instead of relying on an exception for control flow.
        if len(self.losses) >= window:
            try:
                smoothed = scipy.signal.savgol_filter(self.losses, window, 3)
                plt.plot(iters, smoothed, 'green', linestyle='--', linewidth=2, label='smooth train loss')
            except Exception:
                # The smoothed curve is optional decoration; keep the raw plot.
                pass

        plt.grid(True)
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend(loc="upper right")
        plt.savefig(os.path.join(self.log_dir, "epoch_loss.png"))
        plt.cla()
        plt.close("all")
60
+
61
class EvalCallback():
    """Periodically evaluate detection mAP on validation lines during training.

    For every evaluated epoch the callback writes detection and ground-truth
    text files under ``map_out_path``, computes the mAP, appends it to
    ``epoch_map.txt``, redraws ``epoch_map.png`` in ``log_dir`` and finally
    removes ``map_out_path``.
    """

    def __init__(self, net, input_shape, anchors, anchors_mask, class_names, num_classes, val_lines, log_dir, cuda,
                 map_out_path=".temp_map_out", max_boxes=100, confidence=0.05, nms_iou=0.5, letterbox_image=True,
                 MINOVERLAP=0.5, eval_flag=True, period=1):
        """Store the evaluation settings and initialise the mAP history.

        When ``eval_flag`` is true an initial ``0`` entry is appended to
        ``<log_dir>/epoch_map.txt``.
        """
        super().__init__()
        self.net = net
        self.input_shape = input_shape
        self.anchors = anchors
        self.anchors_mask = anchors_mask
        self.class_names = class_names
        self.num_classes = num_classes
        self.val_lines = val_lines
        self.log_dir = log_dir
        self.cuda = cuda
        self.map_out_path = map_out_path
        self.max_boxes = max_boxes
        self.confidence = confidence
        self.nms_iou = nms_iou
        self.letterbox_image = letterbox_image
        self.MINOVERLAP = MINOVERLAP
        self.eval_flag = eval_flag
        self.period = period
        self.bbox_util = DecodeBox(self.anchors, self.num_classes, (self.input_shape[0], self.input_shape[1]), self.anchors_mask)
        # Epoch 0 is recorded with a mAP of 0 so the curve starts at the origin.
        self.maps = [0]
        self.epoches = [0]
        if self.eval_flag:
            with open(os.path.join(self.log_dir, "epoch_map.txt"), 'a') as f:
                f.write(str(0))
                f.write("\n")

    def get_map_txt(self, image_id, image, class_names, map_out_path):
        """Run the detector on ``image`` and write its detections to a text file.

        The file ``<map_out_path>/detection-results/<image_id>.txt`` is always
        created; it stays empty when nothing survives NMS. Each line is
        ``class score left top right bottom``.
        """
        result_path = os.path.join(map_out_path, "detection-results/" + image_id + ".txt")
        # The context manager guarantees the handle is closed even on the
        # early return below (the original leaked it in that case).
        with open(result_path, "w", encoding='utf-8') as f:
            image_shape = np.array(np.shape(image)[0:2])
            image = cvtColor(image)
            image_data = resize_image(image, (self.input_shape[1], self.input_shape[0]), self.letterbox_image)
            # HWC -> CHW and add the batch axis.
            image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, dtype='float32')), (2, 0, 1)), 0)

            with torch.no_grad():
                images = torch.from_numpy(image_data)
                if self.cuda:
                    images = images.cuda()
                outputs = self.net(images)
                outputs = self.bbox_util.decode_box(outputs)
                results = self.bbox_util.non_max_suppression(
                    torch.cat(outputs, 1), self.num_classes, self.input_shape,
                    image_shape, self.letterbox_image,
                    conf_thres=self.confidence, nms_thres=self.nms_iou)

            if results[0] is None:
                return

            top_label = np.array(results[0][:, 6], dtype='int32')
            top_conf = results[0][:, 4] * results[0][:, 5]
            top_boxes = results[0][:, :4]

            # Keep only the `max_boxes` highest scoring detections.
            best = np.argsort(top_conf)[::-1][:self.max_boxes]
            top_boxes = top_boxes[best]
            top_conf = top_conf[best]
            top_label = top_label[best]

            for label, conf, box in zip(top_label, top_conf, top_boxes):
                predicted_class = self.class_names[int(label)]
                if predicted_class not in class_names:
                    continue
                score = str(conf)
                # Boxes come back as (top, left, bottom, right).
                top, left, bottom, right = box
                f.write("%s %s %s %s %s %s\n" % (
                    predicted_class, score[:6],
                    str(int(left)), str(int(top)), str(int(right)), str(int(bottom))))
        return

    def on_epoch_end(self, epoch, model_eval):
        """Evaluate ``model_eval`` when ``epoch`` is a multiple of ``period``.

        Does nothing when evaluation is disabled through ``eval_flag``.
        """
        if epoch % self.period != 0 or not self.eval_flag:
            return

        self.net = model_eval
        # makedirs creates map_out_path itself as a parent of both folders.
        os.makedirs(os.path.join(self.map_out_path, "ground-truth"), exist_ok=True)
        os.makedirs(os.path.join(self.map_out_path, "detection-results"), exist_ok=True)

        print("Get map.")
        for annotation_line in tqdm(self.val_lines):
            line = annotation_line.split()
            image_id = os.path.basename(line[0]).split('.')[0]
            gt_boxes = np.array([np.array(list(map(int, box.split(',')))) for box in line[1:]])

            # Close the image file as soon as its detections are written.
            with Image.open(line[0]) as image:
                self.get_map_txt(image_id, image, self.class_names, self.map_out_path)

            with open(os.path.join(self.map_out_path, "ground-truth/" + image_id + ".txt"), "w") as new_f:
                for box in gt_boxes:
                    left, top, right, bottom, obj = box
                    obj_name = self.class_names[obj]
                    new_f.write("%s %s %s %s %s\n" % (obj_name, left, top, right, bottom))

        print("Calculate Map.")
        try:
            temp_map = get_coco_map(class_names=self.class_names, path=self.map_out_path)[1]
        except Exception:
            # Fall back to the local mAP implementation whenever the COCO
            # style evaluation is unavailable or fails.
            temp_map = get_map(self.MINOVERLAP, False, path=self.map_out_path)
        self.maps.append(temp_map)
        self.epoches.append(epoch)

        with open(os.path.join(self.log_dir, "epoch_map.txt"), 'a') as f:
            f.write(str(temp_map))
            f.write("\n")

        plt.figure()
        plt.plot(self.epoches, self.maps, 'red', linewidth=2, label='train map')
        plt.grid(True)
        plt.xlabel('Epoch')
        plt.ylabel('Map %s' % str(self.MINOVERLAP))
        plt.title('A Map Curve')
        plt.legend(loc="upper right")
        plt.savefig(os.path.join(self.log_dir, "epoch_map.png"))
        plt.cla()
        plt.close("all")

        print("Get map done.")
        shutil.rmtree(self.map_out_path)
utils/dataloader.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from random import sample, shuffle
2
+ import cv2
3
+ import numpy as np
4
+ import torch
5
+ from PIL import Image
6
+ from torch.utils.data.dataset import Dataset
7
+ from utils.utils import cvtColor, preprocess_input
8
+
9
class YoloDataset(Dataset):
    """Dataset yielding (hazy image, YOLO labels, clean image) triples.

    ``annotation_lines[i]`` and ``clean_lines[i]`` are whitespace separated
    strings whose first token is an image path; the remaining tokens of the
    annotation line are ``x1,y1,x2,y2,class`` boxes. The hazy and the clean
    image always receive exactly the same geometric and colour augmentation.
    """

    def __init__(self, annotation_lines, clean_lines, input_shape, num_classes, anchors, anchors_mask, epoch_length, train):
        super(YoloDataset, self).__init__()
        self.annotation_lines = annotation_lines
        self.clean_lines = clean_lines
        self.input_shape = input_shape
        self.num_classes = num_classes
        self.anchors = anchors
        self.anchors_mask = anchors_mask
        self.epoch_length = epoch_length
        self.train = train
        self.epoch_now = -1
        self.length = len(self.annotation_lines)
        self.bbox_attrs = 5 + num_classes

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        """Return ``(image, labels, clean_image)`` for ``index`` (wrapped modulo the length).

        Images are CHW float32 arrays scaled by ``preprocess_input``. ``labels``
        has one row per box: column 0 is left at zero (the collate function
        writes the batch index there), column 1 is the class id and columns
        2-5 are the normalised ``cx, cy, w, h``.
        """
        index = index % self.length
        image, box, clearimg = self.get_random_data(
            self.annotation_lines[index], self.clean_lines[index],
            self.input_shape, random=self.train)
        image = np.transpose(preprocess_input(np.array(image, dtype=np.float32)), (2, 0, 1))
        box = np.array(box, dtype=np.float32)
        clearimg = np.transpose(preprocess_input(np.array(clearimg, dtype=np.float32)), (2, 0, 1))

        nL = len(box)
        labels_out = np.zeros((nL, 6))
        if nL:
            # Normalise corners by width/height, then convert to centre/size.
            box[:, [0, 2]] = box[:, [0, 2]] / self.input_shape[1]
            box[:, [1, 3]] = box[:, [1, 3]] / self.input_shape[0]
            box[:, 2:4] = box[:, 2:4] - box[:, 0:2]
            box[:, 0:2] = box[:, 0:2] + box[:, 2:4] / 2
            labels_out[:, 1] = box[:, -1]
            labels_out[:, 2:] = box[:, :4]
        return image, labels_out, clearimg

    def rand(self, a=0, b=1):
        """Return a uniform random float in ``[a, b)`` drawn from NumPy's global RNG."""
        return np.random.rand()*(b-a) + a

    @staticmethod
    def _paste_on_gray(image, canvas_size, offset):
        """Paste ``image`` at ``offset`` onto a gray (128) RGB canvas of ``canvas_size``."""
        canvas = Image.new('RGB', canvas_size, (128, 128, 128))
        canvas.paste(image, offset)
        return canvas

    @staticmethod
    def _fit_boxes(box, iw, ih, nw, nh, dx, dy, w, h, flip):
        """Shuffle boxes, map them into the augmented canvas and drop degenerate ones."""
        if len(box) > 0:
            np.random.shuffle(box)
            box[:, [0, 2]] = box[:, [0, 2]]*nw/iw + dx
            box[:, [1, 3]] = box[:, [1, 3]]*nh/ih + dy
            if flip:
                box[:, [0, 2]] = w - box[:, [2, 0]]
            # Clip to the canvas and keep only boxes larger than one pixel.
            box[:, 0:2][box[:, 0:2] < 0] = 0
            box[:, 2][box[:, 2] > w] = w
            box[:, 3][box[:, 3] > h] = h
            box_w = box[:, 2] - box[:, 0]
            box_h = box[:, 3] - box[:, 1]
            box = box[np.logical_and(box_w > 1, box_h > 1)]
        return box

    @staticmethod
    def _apply_hsv_luts(rgb, lut_hue, lut_sat, lut_val):
        """Apply per-channel lookup tables in HSV space and return the RGB result."""
        hue_ch, sat_ch, val_ch = cv2.split(cv2.cvtColor(rgb, cv2.COLOR_RGB2HSV))
        jittered = cv2.merge((cv2.LUT(hue_ch, lut_hue), cv2.LUT(sat_ch, lut_sat), cv2.LUT(val_ch, lut_val)))
        return cv2.cvtColor(jittered, cv2.COLOR_HSV2RGB)

    def get_random_data(self, annotation_line, clean_line, input_shape, jitter=.3, hue=.1, sat=0.7, val=0.4, random=True):
        """Load one hazy/clean image pair and its boxes, optionally augmented.

        With ``random=False`` both images are letterboxed onto a gray canvas.
        Otherwise a random aspect-ratio/scale jitter, random placement, a
        random horizontal flip and an HSV colour jitter are applied, identical
        for both images.

        Returns:
            ``(image_data, box, clear_image_data)``.
        """
        line = annotation_line.split()
        clearline = clean_line.split()

        image = cvtColor(Image.open(line[0]))
        clearimg = cvtColor(Image.open(clearline[0]))

        iw, ih = image.size
        h, w = input_shape
        box = np.array([np.array(list(map(int, box.split(',')))) for box in line[1:]])

        if not random:
            scale = min(w/iw, h/ih)
            nw = int(iw*scale)
            nh = int(ih*scale)
            dx = (w-nw)//2
            dy = (h-nh)//2

            new_image = self._paste_on_gray(image.resize((nw, nh), Image.BICUBIC), (w, h), (dx, dy))
            image_data = np.array(new_image, np.float32)
            new_clearimg = self._paste_on_gray(clearimg.resize((nw, nh), Image.BICUBIC), (w, h), (dx, dy))
            clear_image_data = np.array(new_clearimg, np.float32)

            box = self._fit_boxes(box, iw, ih, nw, nh, dx, dy, w, h, flip=False)
            return image_data, box, clear_image_data

        # NOTE: the order of the random draws below is kept exactly as in the
        # original implementation so seeded runs stay reproducible.
        new_ar = iw/ih * self.rand(1-jitter, 1+jitter) / self.rand(1-jitter, 1+jitter)
        scale = self.rand(.25, 2)
        if new_ar < 1:
            nh = int(scale*h)
            nw = int(nh*new_ar)
        else:
            nw = int(scale*w)
            nh = int(nw/new_ar)
        image = image.resize((nw, nh), Image.BICUBIC)
        clearimg = clearimg.resize((nw, nh), Image.BICUBIC)

        dx = int(self.rand(0, w-nw))
        dy = int(self.rand(0, h-nh))
        image = self._paste_on_gray(image, (w, h), (dx, dy))
        clearimg = self._paste_on_gray(clearimg, (w, h), (dx, dy))

        flip = self.rand() < .5
        if flip:
            image = image.transpose(Image.FLIP_LEFT_RIGHT)
            clearimg = clearimg.transpose(Image.FLIP_LEFT_RIGHT)

        image_data = np.array(image, np.uint8)
        clear_image_data = np.array(clearimg, np.uint8)

        # One set of gains / lookup tables shared by both images. The original
        # rebuilt identical tables a second time for the clean image.
        r = np.random.uniform(-1, 1, 3) * [hue, sat, val] + 1
        dtype = image_data.dtype
        x = np.arange(0, 256, dtype=r.dtype)
        lut_hue = ((x * r[0]) % 180).astype(dtype)
        lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
        lut_val = np.clip(x * r[2], 0, 255).astype(dtype)

        image_data = self._apply_hsv_luts(image_data, lut_hue, lut_sat, lut_val)
        clear_image_data = self._apply_hsv_luts(clear_image_data, lut_hue, lut_sat, lut_val)

        box = self._fit_boxes(box, iw, ih, nw, nh, dx, dy, w, h, flip=flip)
        return image_data, box, clear_image_data

    def merge_bboxes(self, bboxes, cutx, cuty):
        """Clip the boxes of four mosaic tiles against the cut lines and merge them.

        ``bboxes[i]`` holds the boxes of tile ``i``. For the tiles 0-3 a box is
        dropped when it lies entirely on the far side of a cut line, and
        clamped to the cut line when it straddles it; boxes of any further
        tile are copied unchanged.

        Returns:
            A flat list of ``[x1, y1, x2, y2, class]`` lists.
        """
        merge_bbox = []
        for tile, tile_boxes in enumerate(bboxes):
            for box in tile_boxes:
                x1, y1, x2, y2 = box[0], box[1], box[2], box[3]
                # Evaluated on the unclamped coordinates, as in the original.
                crosses_y = y2 >= cuty and y1 <= cuty
                crosses_x = x2 >= cutx and x1 <= cutx

                if tile == 0:
                    if y1 > cuty or x1 > cutx:
                        continue
                    if crosses_y:
                        y2 = cuty
                    if crosses_x:
                        x2 = cutx
                elif tile == 1:
                    if y2 < cuty or x1 > cutx:
                        continue
                    if crosses_y:
                        y1 = cuty
                    if crosses_x:
                        x2 = cutx
                elif tile == 2:
                    if y2 < cuty or x2 < cutx:
                        continue
                    if crosses_y:
                        y1 = cuty
                    if crosses_x:
                        x1 = cutx
                elif tile == 3:
                    if y1 > cuty or x2 < cutx:
                        continue
                    if crosses_y:
                        y2 = cuty
                    if crosses_x:
                        x1 = cutx

                merge_bbox.append([x1, y1, x2, y2, box[-1]])
        return merge_bbox
177
+
178
def yolo_dataset_collate(batch):
    """Collate ``(image, labels, clean_image)`` samples into batch tensors.

    Column 0 of every label array is overwritten in place with the index of
    its sample inside the batch, then all label arrays are concatenated along
    axis 0.

    Returns:
        ``(images, bboxes, clearimg)`` as float32 tensors.
    """
    images, bboxes, clearimg = [], [], []
    for sample_idx, sample in enumerate(batch):
        img, box, clear = sample
        # Tag each label row with the sample it belongs to.
        box[:, 0] = sample_idx
        images.append(img)
        bboxes.append(box)
        clearimg.append(clear)

    stacked_images = np.array(images)
    stacked_boxes = np.concatenate(bboxes, 0)
    stacked_clear = np.array(clearimg)

    images = torch.from_numpy(stacked_images).type(torch.FloatTensor)
    bboxes = torch.from_numpy(stacked_boxes).type(torch.FloatTensor)
    clearimg = torch.from_numpy(stacked_clear).type(torch.FloatTensor)
    return images, bboxes, clearimg
utils/utils.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import numpy as np
3
+ import torch
4
+ from PIL import Image
5
+
6
def cvtColor(image):
    """Return ``image`` unchanged if it already has three channels, else convert it.

    An input whose ``np.shape`` is ``(H, W, 3)`` is returned as is; anything
    else is passed through ``image.convert('RGB')``.
    """
    shape = np.shape(image)
    is_three_channel = len(shape) == 3 and shape[2] == 3
    return image if is_three_channel else image.convert('RGB')
12
+
13
def resize_image(image, size, letterbox_image):
    """Resize ``image`` to ``size`` = ``(w, h)`` with bicubic interpolation.

    When ``letterbox_image`` is truthy the aspect ratio is preserved and the
    resized image is centred on a gray (128, 128, 128) canvas of ``size``;
    otherwise the image is stretched directly to ``(w, h)``.
    """
    iw, ih = image.size
    w, h = size

    if not letterbox_image:
        return image.resize((w, h), Image.BICUBIC)

    # Largest scale at which the whole image still fits inside the target.
    scale = min(w/iw, h/ih)
    nw = int(iw*scale)
    nh = int(ih*scale)

    resized = image.resize((nw, nh), Image.BICUBIC)
    canvas = Image.new('RGB', size, (128, 128, 128))
    canvas.paste(resized, ((w-nw)//2, (h-nh)//2))
    return canvas
26
+
27
def get_classes(classes_path):
    """Read one class name per line from ``classes_path`` (UTF-8).

    Returns:
        ``(class_names, count)`` where every name has surrounding whitespace
        stripped.
    """
    with open(classes_path, encoding='utf-8') as f:
        class_names = [raw_line.strip() for raw_line in f]
    return class_names, len(class_names)
32
+
33
def get_anchors(anchors_path):
    """Load anchors from the first line of ``anchors_path``.

    The line must hold comma separated numbers; they are grouped pairwise.

    Returns:
        ``(anchors, count)`` where ``anchors`` is a float array of shape
        ``(count, 2)``.
    """
    with open(anchors_path, encoding='utf-8') as f:
        first_line = f.readline()
    values = [float(token) for token in first_line.split(',')]
    anchors = np.array(values).reshape(-1, 2)
    return anchors, len(anchors)
40
+
41
def get_lr(optimizer):
    """Return the ``'lr'`` of the optimizer's first parameter group.

    Returns ``None`` when ``optimizer.param_groups`` is empty.
    """
    first_group = next(iter(optimizer.param_groups), None)
    if first_group is None:
        return None
    return first_group['lr']
44
+
45
def seed_everything(seed=11):
    """Seed Python, NumPy and PyTorch (CPU and CUDA) RNGs with ``seed``.

    Also switches cuDNN to deterministic mode and disables its benchmark
    auto-tuner.
    """
    # Python / NumPy generators.
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch generators: CPU, current CUDA device, all CUDA devices.
    for torch_seeder in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        torch_seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
53
+
54
def worker_init_fn(worker_id, rank, seed):
    """Seed Python, NumPy and PyTorch RNGs with ``rank + seed``.

    NOTE(review): ``worker_id`` is accepted but not used, so every caller
    passing the same ``rank`` and ``seed`` gets identical random streams --
    confirm this is intended.
    """
    worker_seed = rank + seed
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(worker_seed)
59
+
60
def preprocess_input(image):
    """Divide ``image`` by 255 in place and return the same object."""
    scale = 255.0
    image /= scale
    return image
63
+
64
def show_config(**kwargs):
    """Print the given keyword arguments as a two-column ``keys | values`` table."""
    rule = '-' * 70
    row = '|%25s | %40s|'

    table = ['Configurations:', rule, row % ('keys', 'values'), rule]
    table.extend(row % (str(key), str(value)) for key, value in kwargs.items())
    table.append(rule)

    for text in table:
        print(text)
utils/utils_bbox.py ADDED
@@ -0,0 +1,310 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ from torchvision.ops import nms
4
+
5
class DecodeBox():
    """Decode raw YOLO head outputs (torch tensors) into boxes and run NMS."""

    def __init__(self, anchors, num_classes, input_shape, anchors_mask=None):
        """Store the decoding configuration.

        Args:
            anchors: Array of ``(width, height)`` anchors, indexed with the
                index lists of ``anchors_mask``.
            num_classes: Number of object classes.
            input_shape: ``(height, width)`` of the network input.
            anchors_mask: One list of anchor indices per detection head;
                defaults to ``[[6, 7, 8], [3, 4, 5], [0, 1, 2]]``.
        """
        super().__init__()
        self.anchors = anchors
        self.num_classes = num_classes
        self.bbox_attrs = 5 + num_classes
        self.input_shape = input_shape
        # A fresh list per instance avoids sharing one mutable default.
        self.anchors_mask = [[6, 7, 8], [3, 4, 5], [0, 1, 2]] if anchors_mask is None else anchors_mask

    def decode_box(self, inputs):
        """Decode the first ``len(anchors_mask)`` entries of ``inputs``.

        Each entry is reshaped from ``(B, A * (5 + num_classes), H, W)``.

        Returns:
            A list with one detached tensor per head of shape
            ``(B, A * H * W, 5 + num_classes)`` holding ``cx, cy, w, h``
            divided by the feature-map size, objectness and class scores.
        """
        outputs = []
        # Extra entries of `inputs` beyond the detection heads are ignored.
        heads = [inputs[k] for k in range(len(self.anchors_mask))]
        for i, feat in enumerate(heads):
            batch_size = feat.size(0)
            input_height = feat.size(2)
            input_width = feat.size(3)
            mask = self.anchors_mask[i]
            num_anchors = len(mask)

            # Anchors expressed in feature-map cells instead of input pixels.
            stride_h = self.input_shape[0] / input_height
            stride_w = self.input_shape[1] / input_width
            scaled_anchors = [(anchor_width / stride_w, anchor_height / stride_h)
                              for anchor_width, anchor_height in self.anchors[mask]]

            prediction = feat.view(batch_size, num_anchors, self.bbox_attrs,
                                   input_height, input_width).permute(0, 1, 3, 4, 2).contiguous()
            x = torch.sigmoid(prediction[..., 0])
            y = torch.sigmoid(prediction[..., 1])
            w = torch.sigmoid(prediction[..., 2])
            h = torch.sigmoid(prediction[..., 3])
            conf = torch.sigmoid(prediction[..., 4])
            pred_cls = torch.sigmoid(prediction[..., 5:])

            # Build helper tensors directly on the device of the input; the
            # legacy torch.cuda.FloatTensor constructors always used the
            # *current* CUDA device, which may differ from the input's.
            device = prediction.device
            grid_x = torch.arange(input_width, dtype=torch.float32, device=device).view(1, 1, 1, input_width)
            grid_y = torch.arange(input_height, dtype=torch.float32, device=device).view(1, 1, input_height, 1)
            anchors_t = torch.tensor(np.asarray(scaled_anchors, dtype=np.float32), device=device)
            anchor_w = anchors_t[:, 0].view(1, num_anchors, 1, 1)
            anchor_h = anchors_t[:, 1].view(1, num_anchors, 1, 1)

            # Every element is assigned below, so uninitialised memory is fine.
            pred_boxes = torch.empty(prediction[..., :4].shape, dtype=torch.float32, device=device)
            pred_boxes[..., 0] = x.detach() * 2. - 0.5 + grid_x
            pred_boxes[..., 1] = y.detach() * 2. - 0.5 + grid_y
            pred_boxes[..., 2] = (w.detach() * 2) ** 2 * anchor_w
            pred_boxes[..., 3] = (h.detach() * 2) ** 2 * anchor_h

            _scale = torch.tensor([input_width, input_height, input_width, input_height],
                                  dtype=torch.float32, device=device)
            output = torch.cat((pred_boxes.view(batch_size, -1, 4) / _scale,
                                conf.view(batch_size, -1, 1),
                                pred_cls.view(batch_size, -1, self.num_classes)), -1)
            outputs.append(output.detach())
        return outputs

    def yolo_correct_boxes(self, box_xy, box_wh, input_shape, image_shape, letterbox_image):
        """Map normalised centre/size boxes back onto the original image.

        With ``letterbox_image`` the gray padding added by letterboxing is
        removed first. The inputs are not modified.

        Returns:
            Array whose last axis is ``(y_min, x_min, y_max, x_max)`` scaled
            by ``image_shape``.
        """
        box_yx = box_xy[..., ::-1]
        box_hw = box_wh[..., ::-1]
        input_shape = np.array(input_shape)
        image_shape = np.array(image_shape)

        if letterbox_image:
            new_shape = np.round(image_shape * np.min(input_shape/image_shape))
            offset = (input_shape - new_shape)/2./input_shape
            scale = input_shape/new_shape
            box_yx = (box_yx - offset) * scale
            # Out-of-place on purpose: box_hw is a view of the caller's box_wh,
            # an in-place multiply would silently modify the argument.
            box_hw = box_hw * scale

        box_mins = box_yx - (box_hw / 2.)
        box_maxes = box_yx + (box_hw / 2.)
        boxes = np.concatenate([box_mins[..., 0:1], box_mins[..., 1:2],
                                box_maxes[..., 0:1], box_maxes[..., 1:2]], axis=-1)
        boxes *= np.concatenate([image_shape, image_shape], axis=-1)
        return boxes

    def non_max_suppression(self, prediction, num_classes, input_shape, image_shape, letterbox_image, conf_thres=0.5, nms_thres=0.4):
        """Filter decoded predictions by confidence and apply per-class NMS.

        ``prediction`` has shape ``(B, N, 5 + num_classes)`` with centre/size
        boxes; its first four channels are overwritten in place with corner
        coordinates.

        Returns:
            A list with one entry per image: ``None`` when nothing passes the
            confidence threshold, otherwise a NumPy array whose columns are
            the corrected box (4), objectness, class confidence and class id.
        """
        # Centre/size -> corners; computed fully before writing back.
        centers = prediction[:, :, 0:2]
        half_sizes = prediction[:, :, 2:4] / 2
        corners = torch.cat((centers - half_sizes, centers + half_sizes), dim=-1)
        prediction[:, :, :4] = corners

        output = [None for _ in range(len(prediction))]
        for i, image_pred in enumerate(prediction):
            class_conf, class_pred = torch.max(image_pred[:, 5:5 + num_classes], 1, keepdim=True)

            # Kept 1-D on purpose: squeezing a single-row mask would turn it
            # into a 0-d tensor and break the boolean indexing below.
            conf_mask = image_pred[:, 4] * class_conf[:, 0] >= conf_thres

            image_pred = image_pred[conf_mask]
            class_conf = class_conf[conf_mask]
            class_pred = class_pred[conf_mask]
            if not image_pred.size(0):
                continue

            detections = torch.cat((image_pred[:, :5], class_conf.float(), class_pred.float()), 1)
            unique_labels = detections[:, -1].unique()

            kept = []
            for c in unique_labels:
                detections_class = detections[detections[:, -1] == c]
                keep = nms(
                    detections_class[:, :4],
                    detections_class[:, 4] * detections_class[:, 5],
                    nms_thres
                )
                kept.append(detections_class[keep])

            output[i] = torch.cat(kept).cpu().numpy()
            box_xy = (output[i][:, 0:2] + output[i][:, 2:4])/2
            box_wh = output[i][:, 2:4] - output[i][:, 0:2]
            output[i][:, :4] = self.yolo_correct_boxes(box_xy, box_wh, input_shape, image_shape, letterbox_image)
        return output
102
+
103
class DecodeBoxNP():
    """NumPy-only counterpart of ``DecodeBox``: decode YOLO outputs and run NMS."""

    def __init__(self, anchors, num_classes, input_shape, anchors_mask=None):
        """Store the decoding configuration.

        Args:
            anchors: Array of ``(width, height)`` anchors, indexed with the
                index lists of ``anchors_mask``.
            num_classes: Number of object classes.
            input_shape: ``(height, width)`` of the network input.
            anchors_mask: One list of anchor indices per detection head;
                defaults to ``[[6, 7, 8], [3, 4, 5], [0, 1, 2]]``.
        """
        super().__init__()
        self.anchors = anchors
        self.num_classes = num_classes
        self.bbox_attrs = 5 + num_classes
        self.input_shape = input_shape
        # A fresh list per instance avoids sharing one mutable default.
        self.anchors_mask = [[6, 7, 8], [3, 4, 5], [0, 1, 2]] if anchors_mask is None else anchors_mask

    def sigmoid(self, x):
        """Element-wise logistic function ``1 / (1 + exp(-x))``."""
        return 1 / (1 + np.exp(-x))

    def decode_box(self, inputs):
        """Decode every array in ``inputs`` (one per detection head).

        Each array is reshaped from ``(B, A * (5 + num_classes), H, W)``.

        Returns:
            A list with one array per head of shape
            ``(B, A * H * W, 5 + num_classes)`` holding ``cx, cy, w, h``
            divided by the feature-map size, objectness and class scores.
        """
        outputs = []
        for i, feat in enumerate(inputs):
            batch_size = np.shape(feat)[0]
            input_height = np.shape(feat)[2]
            input_width = np.shape(feat)[3]
            mask = self.anchors_mask[i]
            num_anchors = len(mask)

            # Anchors expressed in feature-map cells instead of input pixels.
            stride_h = self.input_shape[0] / input_height
            stride_w = self.input_shape[1] / input_width
            scaled_anchors = np.array([(anchor_width / stride_w, anchor_height / stride_h)
                                       for anchor_width, anchor_height in self.anchors[mask]])

            prediction = np.transpose(
                np.reshape(feat, (batch_size, num_anchors, self.bbox_attrs, input_height, input_width)),
                (0, 1, 3, 4, 2))
            x = self.sigmoid(prediction[..., 0])
            y = self.sigmoid(prediction[..., 1])
            w = self.sigmoid(prediction[..., 2])
            h = self.sigmoid(prediction[..., 3])
            conf = self.sigmoid(prediction[..., 4])
            pred_cls = self.sigmoid(prediction[..., 5:])

            # Broadcastable cell offsets and anchor sizes replace the
            # explicitly repeated (B, A, H, W) arrays of the original.
            grid_x = np.arange(input_width, dtype=np.float64).reshape(1, 1, 1, input_width)
            grid_y = np.arange(input_height, dtype=np.float64).reshape(1, 1, input_height, 1)
            anchor_w = scaled_anchors[:, 0].reshape(1, num_anchors, 1, 1)
            anchor_h = scaled_anchors[:, 1].reshape(1, num_anchors, 1, 1)

            pred_boxes = np.zeros(np.shape(prediction[..., :4]))
            pred_boxes[..., 0] = x * 2. - 0.5 + grid_x
            pred_boxes[..., 1] = y * 2. - 0.5 + grid_y
            pred_boxes[..., 2] = (w * 2) ** 2 * anchor_w
            pred_boxes[..., 3] = (h * 2) ** 2 * anchor_h

            _scale = np.array([input_width, input_height, input_width, input_height])
            output = np.concatenate([np.reshape(pred_boxes, (batch_size, -1, 4)) / _scale,
                                     np.reshape(conf, (batch_size, -1, 1)),
                                     np.reshape(pred_cls, (batch_size, -1, self.num_classes))], -1)
            outputs.append(output)
        return outputs

    def bbox_iou(self, box1, box2, x1y1x2y2=True):
        """Return the IoU of ``box1`` rows against ``box2`` rows (broadcast).

        Boxes are corner encoded when ``x1y1x2y2`` is true, otherwise
        ``cx, cy, w, h``.
        """
        if not x1y1x2y2:
            b1_x1, b1_x2 = box1[:, 0] - box1[:, 2] / 2, box1[:, 0] + box1[:, 2] / 2
            b1_y1, b1_y2 = box1[:, 1] - box1[:, 3] / 2, box1[:, 1] + box1[:, 3] / 2
            b2_x1, b2_x2 = box2[:, 0] - box2[:, 2] / 2, box2[:, 0] + box2[:, 2] / 2
            b2_y1, b2_y2 = box2[:, 1] - box2[:, 3] / 2, box2[:, 1] + box2[:, 3] / 2
        else:
            b1_x1, b1_y1, b1_x2, b1_y2 = box1[:, 0], box1[:, 1], box1[:, 2], box1[:, 3]
            b2_x1, b2_y1, b2_x2, b2_y2 = box2[:, 0], box2[:, 1], box2[:, 2], box2[:, 3]

        inter_rect_x1 = np.maximum(b1_x1, b2_x1)
        inter_rect_y1 = np.maximum(b1_y1, b2_y1)
        inter_rect_x2 = np.minimum(b1_x2, b2_x2)
        inter_rect_y2 = np.minimum(b1_y2, b2_y2)

        inter_area = np.maximum(inter_rect_x2 - inter_rect_x1, 0) * \
                     np.maximum(inter_rect_y2 - inter_rect_y1, 0)
        b1_area = (b1_x2 - b1_x1) * (b1_y2 - b1_y1)
        b2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1)
        # The floor on the union guards against division by zero.
        iou = inter_area / np.maximum(b1_area + b2_area - inter_area, 1e-6)
        return iou

    def yolo_correct_boxes(self, box_xy, box_wh, input_shape, image_shape, letterbox_image):
        """Map normalised centre/size boxes back onto the original image.

        With ``letterbox_image`` the gray padding added by letterboxing is
        removed first. The inputs are not modified.

        Returns:
            Array whose last axis is ``(y_min, x_min, y_max, x_max)`` scaled
            by ``image_shape``.
        """
        box_yx = box_xy[..., ::-1]
        box_hw = box_wh[..., ::-1]
        input_shape = np.array(input_shape)
        image_shape = np.array(image_shape)

        if letterbox_image:
            new_shape = np.round(image_shape * np.min(input_shape/image_shape))
            offset = (input_shape - new_shape)/2./input_shape
            scale = input_shape/new_shape
            box_yx = (box_yx - offset) * scale
            # Out-of-place on purpose: box_hw is a view of the caller's box_wh,
            # an in-place multiply would silently modify the argument.
            box_hw = box_hw * scale

        box_mins = box_yx - (box_hw / 2.)
        box_maxes = box_yx + (box_hw / 2.)
        boxes = np.concatenate([box_mins[..., 0:1], box_mins[..., 1:2],
                                box_maxes[..., 0:1], box_maxes[..., 1:2]], axis=-1)
        boxes *= np.concatenate([image_shape, image_shape], axis=-1)
        return boxes

    def non_max_suppression(self, prediction, num_classes, input_shape, image_shape, letterbox_image, conf_thres=0.5, nms_thres=0.4):
        """Filter decoded predictions by confidence and apply greedy per-class NMS.

        ``prediction`` has shape ``(B, N, 5 + num_classes)`` with centre/size
        boxes; its first four channels are overwritten in place with corner
        coordinates.

        Returns:
            A list with one entry per image: ``None`` when nothing passes the
            confidence threshold, otherwise an array whose columns are the
            corrected box (4), objectness, class confidence and class id.
        """
        # Centre/size -> corners; computed fully before writing back.
        centers = prediction[:, :, 0:2]
        half_sizes = prediction[:, :, 2:4] / 2
        corners = np.concatenate((centers - half_sizes, centers + half_sizes), axis=-1)
        prediction[:, :, :4] = corners

        output = [None for _ in range(len(prediction))]
        for i, image_pred in enumerate(prediction):
            class_scores = image_pred[:, 5:5 + num_classes]
            class_conf = np.max(class_scores, 1, keepdims=True)
            class_pred = np.expand_dims(np.argmax(class_scores, 1), -1)

            # Kept 1-D on purpose: squeezing a single-row mask would turn it
            # into a 0-d array and break the boolean indexing below.
            conf_mask = image_pred[:, 4] * class_conf[:, 0] >= conf_thres

            image_pred = image_pred[conf_mask]
            class_conf = class_conf[conf_mask]
            class_pred = class_pred[conf_mask]
            if not np.shape(image_pred)[0]:
                continue

            detections = np.concatenate((image_pred[:, :5], class_conf, class_pred), 1)

            per_class = []
            for c in np.unique(detections[:, -1]):
                detections_class = detections[detections[:, -1] == c]
                # Highest score first.
                order = np.argsort(detections_class[:, 4] * detections_class[:, 5])[::-1]
                detections_class = detections_class[order]

                kept = []
                while np.shape(detections_class)[0]:
                    # Keep the best remaining box, drop everything overlapping it.
                    kept.append(detections_class[0:1])
                    if len(detections_class) == 1:
                        break
                    ious = self.bbox_iou(kept[-1], detections_class[1:])
                    detections_class = detections_class[1:][ious < nms_thres]
                per_class.append(np.concatenate(kept, 0))

            output[i] = np.concatenate(per_class)
            box_xy = (output[i][:, 0:2] + output[i][:, 2:4])/2
            box_wh = output[i][:, 2:4] - output[i][:, 0:2]
            output[i][:, :4] = self.yolo_correct_boxes(box_xy, box_wh, input_shape, image_shape, letterbox_image)
        return output
222
+
223
if __name__ == "__main__":
    # Stand-alone visual demo: decodes one random feature map and plots, for a
    # single grid cell, the three prior anchors (left panel) next to the three
    # decoded boxes (right panel) on top of "img/street.jpg".
    import matplotlib.pyplot as plt
    import numpy as np

    def get_anchors_and_decode(input, input_shape, anchors, anchors_mask, num_classes):
        """Decode ``input`` with the anchors of ``anchors_mask[2]`` and plot the result."""
        batch_size = input.size(0)
        input_height = input.size(2)
        input_width = input.size(3)
        head_mask = anchors_mask[2]
        num_anchors = len(head_mask)

        # Anchors expressed in feature-map cells instead of input pixels.
        stride_h = input_shape[0] / input_height
        stride_w = input_shape[1] / input_width
        scaled_anchors = [(anchor_width / stride_w, anchor_height / stride_h)
                          for anchor_width, anchor_height in anchors[head_mask]]

        prediction = input.view(batch_size, num_anchors, num_classes + 5,
                                input_height, input_width).permute(0, 1, 3, 4, 2).contiguous()
        x = torch.sigmoid(prediction[..., 0])
        y = torch.sigmoid(prediction[..., 1])
        w = torch.sigmoid(prediction[..., 2])
        h = torch.sigmoid(prediction[..., 3])
        # Objectness and class scores are decoded but not plotted.
        conf = torch.sigmoid(prediction[..., 4])
        pred_cls = torch.sigmoid(prediction[..., 5:])

        FloatTensor = torch.cuda.FloatTensor if x.is_cuda else torch.FloatTensor
        LongTensor = torch.cuda.LongTensor if x.is_cuda else torch.LongTensor

        cells = batch_size * num_anchors
        grid_x = torch.linspace(0, input_width - 1, input_width).repeat(input_height, 1).repeat(
            cells, 1, 1).view(x.shape).type(FloatTensor)
        grid_y = torch.linspace(0, input_height - 1, input_height).repeat(input_width, 1).t().repeat(
            cells, 1, 1).view(y.shape).type(FloatTensor)

        anchor_w = FloatTensor(scaled_anchors).index_select(1, LongTensor([0]))
        anchor_h = FloatTensor(scaled_anchors).index_select(1, LongTensor([1]))
        anchor_w = anchor_w.repeat(batch_size, 1).repeat(1, 1, input_height * input_width).view(w.shape)
        anchor_h = anchor_h.repeat(batch_size, 1).repeat(1, 1, input_height * input_width).view(h.shape)

        pred_boxes = FloatTensor(prediction[..., :4].shape)
        pred_boxes[..., 0] = x.data * 2. - 0.5 + grid_x
        pred_boxes[..., 1] = y.data * 2. - 0.5 + grid_y
        pred_boxes[..., 2] = (w.data * 2) ** 2 * anchor_w
        pred_boxes[..., 3] = (h.data * 2) ** 2 * anchor_h

        # Grid cell that gets highlighted in both panels.
        point_h = 5
        point_w = 5

        # Everything below is scaled by the hard-coded factor 32 used by the demo.
        box_xy = pred_boxes[..., 0:2].cpu().numpy() * 32
        box_wh = pred_boxes[..., 2:4].cpu().numpy() * 32
        grid_x = grid_x.cpu().numpy() * 32
        grid_y = grid_y.cpu().numpy() * 32
        anchor_w = anchor_w.cpu().numpy() * 32
        anchor_h = anchor_h.cpu().numpy() * 32

        def _add_cell_rectangles(axis, lefts, tops, widths, heights):
            """Outline, for the highlighted cell, the boxes of anchors 0, 1 and 2."""
            rectangles = [
                plt.Rectangle([lefts[0, k, point_h, point_w], tops[0, k, point_h, point_w]],
                              widths[0, k, point_h, point_w], heights[0, k, point_h, point_w],
                              color="r", fill=False)
                for k in range(3)
            ]
            for rectangle in rectangles:
                axis.add_patch(rectangle)

        fig = plt.figure()

        # Left panel: the prior anchors centred on the grid point.
        ax = fig.add_subplot(121)
        from PIL import Image
        img = Image.open("img/street.jpg").resize([640, 640])
        plt.imshow(img, alpha=0.5)
        plt.ylim(-30, 650)
        plt.xlim(-30, 650)
        plt.scatter(grid_x, grid_y)
        plt.scatter(point_h * 32, point_w * 32, c='black')
        plt.gca().invert_yaxis()

        anchor_left = grid_x - anchor_w / 2
        anchor_top = grid_y - anchor_h / 2
        _add_cell_rectangles(ax, anchor_left, anchor_top, anchor_w, anchor_h)

        # Right panel: the decoded boxes and their centres.
        ax = fig.add_subplot(122)
        plt.imshow(img, alpha=0.5)
        plt.ylim(-30, 650)
        plt.xlim(-30, 650)
        plt.scatter(grid_x, grid_y)
        plt.scatter(point_h * 32, point_w * 32, c='black')
        plt.scatter(box_xy[0, :, point_h, point_w, 0], box_xy[0, :, point_h, point_w, 1], c='r')
        plt.gca().invert_yaxis()

        pre_left = box_xy[..., 0] - box_wh[..., 0] / 2
        pre_top = box_xy[..., 1] - box_wh[..., 1] / 2
        _add_cell_rectangles(ax, pre_left, pre_top, box_wh[..., 0], box_wh[..., 1])

        plt.show()

    feat = torch.from_numpy(np.random.normal(0.2, 0.5, [4, 255, 20, 20])).float()
    anchors = np.array([[116, 90], [156, 198], [373, 326], [30, 61], [62, 45], [59, 119], [10, 13], [16, 30], [33, 23]])
    anchors_mask = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
    get_anchors_and_decode(feat, [640, 640], anchors, anchors_mask, 80)
utils/utils_fit.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torch.nn as nn
4
+ from tqdm import tqdm
5
+ from utils.utils import get_lr
6
+
7
def fit_one_epoch(model_train, model, ema, yolo_loss, loss_history, eval_callback, optimizer, epoch, epoch_step, gen, Epoch, cuda, fp16, scaler, save_period, save_dir, local_rank=0):
    """Train the joint detection + dehazing model for one epoch.

    Each batch from ``gen`` is expected to provide ``(images, targets, clean)``.
    The hazy ``images`` and the ``clean`` images are concatenated along the
    batch axis and passed to ``model_train``; the first three outputs go to
    ``yolo_loss`` and the fourth is compared with ``clean`` using MSE. The
    optimised loss is ``detection + 0.1 * dehazing``.

    Args:
        model_train: Model used for the forward/backward pass.
        model: Model whose ``state_dict`` is saved when ``ema`` is falsy.
        ema: Optional EMA wrapper exposing ``update(model)`` and ``.ema``.
        yolo_loss: Callable ``(outputs, targets, images) -> loss``.
        loss_history: Object with ``append_loss(epoch, loss)`` and ``losses``.
        eval_callback: Object with ``on_epoch_end(epoch, model)``.
        optimizer: Optimizer to step.
        epoch: Zero-based index of the current epoch.
        epoch_step: Number of batches to consume from ``gen``.
        gen: Iterable of batches.
        Epoch: Total number of epochs (used for logging and final save).
        cuda: Move the batch to ``cuda:local_rank`` when true.
        fp16: Use autocast + ``scaler`` for mixed-precision training.
        scaler: Gradient scaler, required when ``fp16`` is true.
        save_period: Save a checkpoint every ``save_period`` epochs.
        save_dir: Directory receiving the checkpoints.
        local_rank: Process rank; logging and saving only happen on rank 0.
    """
    dehazy_weight = 0.1
    loss = 0
    Dehazy_loss = 0
    criterion = nn.MSELoss()

    def joint_loss(hazy_and_clear, images, targets, clean):
        # Shared by the fp32 and fp16 paths so both optimise the same objective.
        outputs = model_train(hazy_and_clear)
        detect_outputs = [outputs[0], outputs[1], outputs[2]]
        detection = yolo_loss(detect_outputs, targets, images)
        dehazy = criterion(outputs[3], clean)
        return detection + dehazy_weight * dehazy, dehazy

    if local_rank == 0:
        print('Start Train')
        pbar = tqdm(total=epoch_step, desc=f'Epoch {epoch + 1}/{Epoch}', postfix=dict, mininterval=0.3)
    model_train.train()

    for iteration, batch in enumerate(gen):
        if iteration >= epoch_step:
            break
        images, targets, clean = batch[0], batch[1], batch[2]
        with torch.no_grad():
            if cuda:
                images = images.cuda(local_rank)
                targets = targets.cuda(local_rank)
                clean = clean.cuda(local_rank)
            # Built after the device move so it lives on the same device as
            # its parts (and also exists on CPU-only runs).
            hazy_and_clear = torch.cat([images, clean], dim=0)
        optimizer.zero_grad()

        if not fp16:
            loss_value, loss_dehazy = joint_loss(hazy_and_clear, images, targets, clean)
            loss_value.backward()
            optimizer.step()
        else:
            from torch.cuda.amp import autocast
            with autocast():
                loss_value, loss_dehazy = joint_loss(hazy_and_clear, images, targets, clean)
            scaler.scale(loss_value).backward()
            scaler.step(optimizer)
            scaler.update()
        if ema:
            ema.update(model_train)

        Dehazy_loss += loss_dehazy.item()
        loss += loss_value.item()
        # Running detection loss is recovered from the two accumulated totals.
        loss_detection = loss - dehazy_weight * Dehazy_loss
        if local_rank == 0:
            pbar.set_postfix(**{'loss': loss / (iteration + 1),
                                'loss_detection': loss_detection / (iteration + 1),
                                'Dehazy_loss': Dehazy_loss / (iteration + 1),
                                'lr': get_lr(optimizer)})
            pbar.update(1)

    if ema:
        model_train_eval = ema.ema
    else:
        model_train_eval = model_train.eval()

    if local_rank == 0:
        pbar.close()
        mean_loss = loss / epoch_step
        loss_history.append_loss(epoch + 1, mean_loss)
        eval_callback.on_epoch_end(epoch + 1, model_train_eval)
        print('Epoch:' + str(epoch + 1) + '/' + str(Epoch))
        print('Total Loss: %.3f' % mean_loss)
        if ema:
            save_state_dict = ema.ema.state_dict()
        else:
            save_state_dict = model.state_dict()
        if (epoch + 1) % save_period == 0 or epoch + 1 == Epoch:
            torch.save(save_state_dict, os.path.join(save_dir, "ep%03d-loss%.3f.pth" % (epoch + 1, mean_loss)))
        if mean_loss <= min(loss_history.losses):
            print('Save best model to best_epoch_weights.pth')
            torch.save(save_state_dict, os.path.join(save_dir, "best_epoch_weights.pth"))
        torch.save(save_state_dict, os.path.join(save_dir, "last_epoch_weights.pth"))
utils/utils_map.py ADDED
@@ -0,0 +1,723 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import json
3
+ import math
4
+ import operator
5
+ import os
6
+ import shutil
7
+ import sys
8
+ try:
9
+ from pycocotools.coco import COCO
10
+ from pycocotools.cocoeval import COCOeval
11
+ except:
12
+ pass
13
+ import cv2
14
+ import matplotlib
15
+ matplotlib.use('Agg')
16
+ from matplotlib import pyplot as plt
17
+ import numpy as np
18
+
19
def log_average_miss_rate(precision, fp_cumsum, num_images):
    """Compute the log-average miss rate.

    Args:
        precision: NumPy array; the miss rate is taken as ``1 - precision``.
        fp_cumsum: NumPy array of cumulative false-positive counts.
        num_images: Divisor used to turn ``fp_cumsum`` into false positives
            per image.

    Returns:
        ``(lamr, mr, fppi)``. For an empty ``precision`` array the constants
        ``(0, 1, 0)`` are returned.
    """
    if precision.size == 0:
        return 0, 1, 0

    fppi = fp_cumsum / float(num_images)
    mr = 1 - precision

    # Sentinel entries guarantee every threshold finds at least one sample.
    padded_fppi = np.insert(fppi, 0, -1.0)
    padded_mr = np.insert(mr, 0, 1.0)

    # Nine thresholds evenly spaced in log space between 1e-2 and 1e0.
    thresholds = np.logspace(-2.0, 0.0, num=9)
    sampled = np.empty_like(thresholds)
    for k, limit in enumerate(thresholds):
        last = np.flatnonzero(padded_fppi <= limit)[-1]
        sampled[k] = padded_mr[last]

    # Clamp before the log so a zero miss rate does not produce -inf.
    clamped = np.maximum(1e-10, sampled)
    lamr = math.exp(np.mean(np.log(clamped)))
    return lamr, mr, fppi
35
+ """
36
+ throw error and exit
37
+ """
38
def error(msg):
    """Print ``msg`` and terminate the interpreter with exit status 0."""
    print(msg)
    raise SystemExit(0)
41
+ """
42
+ check if the number is a float between 0.0 and 1.0
43
+ """
44
def is_float_between_0_and_1(value):
    """Return True when ``value`` converts to a float strictly inside (0, 1).

    Values that ``float()`` rejects with ``ValueError`` yield False.
    """
    try:
        number = float(value)
    except ValueError:
        return False
    return 0.0 < number < 1.0
53
+
54
def voc_ap(rec, prec):
    """Compute average precision as the area under the precision/recall curve.

    The inputs are padded with sentinel points, precision is made
    monotonically non-increasing from right to left, and the area is summed
    over the points where recall changes.

    Unlike the previous implementation, the caller's sequences are not
    modified, and any sequence type (list, tuple, ...) is accepted.

    Args:
        rec: Sequence of recall values.
        prec: Sequence of precision values, same length as ``rec``.

    Returns:
        ``(ap, mrec, mpre)`` where ``mrec`` and ``mpre`` are new lists holding
        the padded recall and the padded, monotonic precision.
    """
    mrec = [0.0, *rec, 1.0]
    mpre = [0.0, *prec, 0.0]

    # Right-to-left running maximum gives the precision envelope.
    for i in range(len(mpre) - 2, -1, -1):
        mpre[i] = max(mpre[i], mpre[i + 1])

    # Only segments where recall actually changes contribute area.
    ap = 0.0
    for i in range(1, len(mrec)):
        if mrec[i] != mrec[i - 1]:
            ap += (mrec[i] - mrec[i - 1]) * mpre[i]
    return ap, mrec, mpre
74
+
75
def file_lines_to_list(path):
    """Read the text file at ``path`` and return its lines, each stripped of
    leading and trailing whitespace."""
    with open(path) as handle:
        return [raw.strip() for raw in handle]
80
+
81
def draw_text_in_image(img, text, pos, color, line_width):
    """Draw ``text`` on ``img`` with its bottom-left corner at ``pos``.

    Returns:
        ``(img, line_width + text_width)`` where ``text_width`` is the width
        reported by ``cv2.getTextSize`` for the same font settings.
    """
    font_face = cv2.FONT_HERSHEY_PLAIN
    scale = 1
    # Passed in the positional slot after ``color`` for both calls below.
    stroke = 1
    cv2.putText(img, text, pos, font_face, scale, color, stroke)
    (text_width, _height), _baseline = cv2.getTextSize(text, font_face, scale, stroke)
    return img, (line_width + text_width)
94
+
95
def adjust_axes(r, t, fig, axes):
    """Stretch the x-axis upper limit so the text ``t`` has room.

    The width of ``t`` (measured with renderer ``r``) is converted to inches
    using ``fig.dpi``; the upper x-limit of ``axes`` is then scaled by the
    ratio of (figure width + text width) to the figure width.
    """
    extent = t.get_window_extent(renderer=r)
    extra_inches = extent.width / fig.dpi
    width_now = fig.get_figwidth()
    scale = (width_now + extra_inches) / width_now
    lower, upper = axes.get_xlim()[0], axes.get_xlim()[1]
    axes.set_xlim([lower, upper * scale])
103
+
104
def draw_plot_func(dictionary, n_classes, window_title, plot_title, x_label, output_path, to_show, plot_color, true_p_bar):
    """Draw a horizontal bar chart of per-class values and save it.

    Args:
        dictionary: Mapping ``class name -> value``; bars are sorted by value.
        n_classes: Number of bars (expected to equal ``len(dictionary)``).
        window_title: Unused; kept for interface compatibility.
        plot_title: Title of the plot.
        x_label: Label of the x-axis.
        output_path: File the figure is saved to.
        to_show: Call ``plt.show()`` after saving when true.
        plot_color: Bar/text colour for the single-series chart.
        true_p_bar: ``""`` for a single-series chart, otherwise a mapping
            ``class name -> true positives`` used to draw stacked
            false-positive / true-positive bars.
    """
    sorted_dic_by_value = sorted(dictionary.items(), key=operator.itemgetter(1))
    sorted_keys, sorted_values = zip(*sorted_dic_by_value)
    last_index = len(sorted_values) - 1

    if true_p_bar != "":
        fp_sorted = [dictionary[key] - true_p_bar[key] for key in sorted_keys]
        tp_sorted = [true_p_bar[key] for key in sorted_keys]
        plt.barh(range(n_classes), fp_sorted, align='center', color='crimson', label='False Positive')
        plt.barh(range(n_classes), tp_sorted, align='center', color='forestgreen', label='True Positive', left=fp_sorted)
        plt.legend(loc='lower right')
        # Write the numbers at the side of each bar.
        fig = plt.gcf()
        axes = plt.gca()
        # The renderer was previously never created in this branch, so the
        # adjust_axes call below raised NameError.
        r = fig.canvas.get_renderer()
        for i, val in enumerate(sorted_values):
            fp_str_val = " " + str(fp_sorted[i])
            tp_str_val = fp_str_val + " " + str(tp_sorted[i])
            # The green text is drawn first and the red one on top of it, so
            # only the true-positive suffix remains visible in green.
            t = plt.text(val, i, tp_str_val, color='forestgreen', va='center', fontweight='bold')
            plt.text(val, i, fp_str_val, color='crimson', va='center', fontweight='bold')
            if i == last_index:
                adjust_axes(r, t, fig, axes)
    else:
        plt.barh(range(n_classes), sorted_values, color=plot_color)
        # Write the number at the side of each bar.
        fig = plt.gcf()
        axes = plt.gca()
        r = fig.canvas.get_renderer()
        for i, val in enumerate(sorted_values):
            str_val = " " + str(val)
            if val < 1.0:
                str_val = " {0:.2f}".format(val)
            t = plt.text(val, i, str_val, color=plot_color, va='center', fontweight='bold')
            if i == last_index:
                adjust_axes(r, t, fig, axes)

    tick_font_size = 12
    plt.yticks(range(n_classes), sorted_keys, fontsize=tick_font_size)

    # Grow the figure height so every class label fits.
    init_height = fig.get_figheight()
    dpi = fig.dpi
    height_pt = n_classes * (tick_font_size * 1.4)
    height_in = height_pt / dpi
    top_margin = 0.15
    bottom_margin = 0.05
    figure_height = height_in / (1 - top_margin - bottom_margin)
    if figure_height > init_height:
        fig.set_figheight(figure_height)

    plt.title(plot_title, fontsize=14)
    plt.xlabel(x_label, fontsize='large')
    fig.tight_layout()
    fig.savefig(output_path)
    if to_show:
        plt.show()
    plt.close()
168
def _split_label_line(line, n_fields):
    """Split an annotation line into ``(class_name, last n_fields tokens)``.

    Everything before the trailing ``n_fields`` whitespace-separated tokens is
    joined with single spaces to form the class name, so class names that
    contain spaces are supported.
    """
    tokens = line.split()
    if len(tokens) < n_fields:
        raise IndexError("Malformed annotation line (expected at least %d fields): %r" % (n_fields, line))
    return " ".join(tokens[:-n_fields]), tokens[-n_fields:]


def get_map(MINOVERLAP, draw_plot, score_threhold=0.5, path = './map_out'):
    """Compute VOC-style mAP from text files under ``path``.

    Reads ``<path>/ground-truth/*.txt`` (``class left top right bottom
    [difficult]``) and ``<path>/detection-results/*.txt`` (``class confidence
    left top right bottom``), writes a report to ``<path>/results/results.txt``
    and, when ``draw_plot`` is true, per-class curves and summary charts. If
    ``<path>/images-optional`` exists and no walked directory in it is empty,
    every detection is additionally rendered with OpenCV.

    Args:
        MINOVERLAP: IoU threshold at or above which a detection may match.
        draw_plot: Whether to produce matplotlib plots.
        score_threhold: Confidence threshold at which F1, recall and
            precision are reported.
        path: Root directory containing the input folders.

    Returns:
        The mean AP over all ground-truth classes, or ``0`` when no
        ground-truth class was counted.
    """
    GT_PATH = os.path.join(path, 'ground-truth')
    DR_PATH = os.path.join(path, 'detection-results')
    IMG_PATH = os.path.join(path, 'images-optional')
    TEMP_FILES_PATH = os.path.join(path, '.temp_files')
    RESULTS_FILES_PATH = os.path.join(path, 'results')

    show_animation = True
    if os.path.exists(IMG_PATH):
        for _dirpath, _dirnames, files in os.walk(IMG_PATH):
            if not files:
                show_animation = False
    else:
        show_animation = False

    if not os.path.exists(TEMP_FILES_PATH):
        os.makedirs(TEMP_FILES_PATH)
    # Always start from an empty results directory. Previously an existing
    # directory was removed but not recreated, so writing results.txt failed
    # whenever draw_plot was False.
    if os.path.exists(RESULTS_FILES_PATH):
        shutil.rmtree(RESULTS_FILES_PATH)
    os.makedirs(RESULTS_FILES_PATH)
    if draw_plot:
        try:
            matplotlib.use('TkAgg')
        except Exception:
            # Best effort: keep the current backend if TkAgg is unavailable.
            pass
        for sub_dir in ("AP", "F1", "Recall", "Precision"):
            os.makedirs(os.path.join(RESULTS_FILES_PATH, sub_dir))
    if show_animation:
        os.makedirs(os.path.join(RESULTS_FILES_PATH, "images", "detections_one_by_one"))

    # ---- Ground truth: parse every file and cache it as JSON ----
    ground_truth_files_list = glob.glob(GT_PATH + '/*.txt')
    if len(ground_truth_files_list) == 0:
        error("Error: No ground-truth files found!")
    ground_truth_files_list.sort()
    gt_counter_per_class = {}
    counter_images_per_class = {}
    for txt_file in ground_truth_files_list:
        file_id = txt_file.split(".txt", 1)[0]
        file_id = os.path.basename(os.path.normpath(file_id))
        temp_path = os.path.join(DR_PATH, (file_id + ".txt"))
        if not os.path.exists(temp_path):
            error_msg = "Error. File not found: {}\n".format(temp_path)
            error(error_msg)
        bounding_boxes = []
        already_seen_classes = []
        for line in file_lines_to_list(txt_file):
            is_difficult = "difficult" in line
            if is_difficult:
                class_name, (left, top, right, bottom, _difficult) = _split_label_line(line, 5)
            else:
                class_name, (left, top, right, bottom) = _split_label_line(line, 4)
            bbox = left + " " + top + " " + right + " " + bottom
            if is_difficult:
                # Difficult objects are stored but not counted.
                bounding_boxes.append({"class_name":class_name, "bbox":bbox, "used":False, "difficult":True})
                continue
            bounding_boxes.append({"class_name":class_name, "bbox":bbox, "used":False})
            gt_counter_per_class[class_name] = gt_counter_per_class.get(class_name, 0) + 1
            if class_name not in already_seen_classes:
                counter_images_per_class[class_name] = counter_images_per_class.get(class_name, 0) + 1
                already_seen_classes.append(class_name)
        with open(TEMP_FILES_PATH + "/" + file_id + "_ground_truth.json", 'w') as outfile:
            json.dump(bounding_boxes, outfile)
    gt_classes = sorted(gt_counter_per_class.keys())
    n_classes = len(gt_classes)

    # ---- Detections: one confidence-sorted JSON file per class ----
    dr_files_list = glob.glob(DR_PATH + '/*.txt')
    dr_files_list.sort()
    for class_index, class_name in enumerate(gt_classes):
        bounding_boxes = []
        for txt_file in dr_files_list:
            file_id = txt_file.split(".txt",1)[0]
            file_id = os.path.basename(os.path.normpath(file_id))
            temp_path = os.path.join(GT_PATH, (file_id + ".txt"))
            if class_index == 0:
                # Checking once (on the first class) is enough.
                if not os.path.exists(temp_path):
                    error_msg = "Error. File not found: {}\n".format(temp_path)
                    error(error_msg)
            for line in file_lines_to_list(txt_file):
                tmp_class_name, (confidence, left, top, right, bottom) = _split_label_line(line, 5)
                if tmp_class_name == class_name:
                    bbox = left + " " + top + " " + right + " " +bottom
                    bounding_boxes.append({"confidence":confidence, "file_id":file_id, "bbox":bbox})
        bounding_boxes.sort(key=lambda x:float(x['confidence']), reverse=True)
        with open(TEMP_FILES_PATH + "/" + class_name + "_dr.json", 'w') as outfile:
            json.dump(bounding_boxes, outfile)

    # ---- Per-class AP ----
    sum_AP = 0.0
    ap_dictionary = {}
    lamr_dictionary = {}
    with open(RESULTS_FILES_PATH + "/results.txt", 'w') as results_file:
        results_file.write("# AP and precision/recall per class\n")
        count_true_positives = {}

        for class_index, class_name in enumerate(gt_classes):
            count_true_positives[class_name] = 0
            dr_file = TEMP_FILES_PATH + "/" + class_name + "_dr.json"
            with open(dr_file) as infile:
                dr_data = json.load(infile)

            nd = len(dr_data)
            tp = [0] * nd
            fp = [0] * nd
            score = [0] * nd
            score_threhold_idx = 0
            for idx, detection in enumerate(dr_data):
                file_id = detection["file_id"]
                score[idx] = float(detection["confidence"])
                # Detections are sorted by descending confidence, so this ends
                # up as the last index still at or above the threshold.
                if score[idx] >= score_threhold:
                    score_threhold_idx = idx
                if show_animation:
                    ground_truth_img = glob.glob1(IMG_PATH, file_id + ".*")
                    if len(ground_truth_img) == 0:
                        error("Error. Image not found with id: " + file_id)
                    elif len(ground_truth_img) > 1:
                        error("Error. Multiple image with id: " + file_id)
                    else:
                        img = cv2.imread(IMG_PATH + "/" + ground_truth_img[0])
                        img_cumulative_path = RESULTS_FILES_PATH + "/images/" + ground_truth_img[0]
                        if os.path.isfile(img_cumulative_path):
                            img_cumulative = cv2.imread(img_cumulative_path)
                        else:
                            img_cumulative = img.copy()
                        bottom_border = 60
                        BLACK = [0, 0, 0]
                        img = cv2.copyMakeBorder(img, 0, bottom_border, 0, 0, cv2.BORDER_CONSTANT, value=BLACK)

                # The ground-truth cache is re-read for every detection because
                # the "used" flags are persisted to disk on each match.
                gt_file = TEMP_FILES_PATH + "/" + file_id + "_ground_truth.json"
                with open(gt_file) as infile:
                    ground_truth_data = json.load(infile)
                ovmax = -1
                gt_match = -1
                bb = [float(x) for x in detection["bbox"].split()]
                for obj in ground_truth_data:
                    if obj["class_name"] == class_name:
                        bbgt = [ float(x) for x in obj["bbox"].split() ]
                        bi = [max(bb[0],bbgt[0]), max(bb[1],bbgt[1]), min(bb[2],bbgt[2]), min(bb[3],bbgt[3])]
                        iw = bi[2] - bi[0] + 1
                        ih = bi[3] - bi[1] + 1
                        if iw > 0 and ih > 0:
                            ua = (bb[2] - bb[0] + 1) * (bb[3] - bb[1] + 1) + (bbgt[2] - bbgt[0]
                                    + 1) * (bbgt[3] - bbgt[1] + 1) - iw * ih
                            ov = iw * ih / ua
                            if ov > ovmax:
                                ovmax = ov
                                gt_match = obj

                status = "NO MATCH FOUND!"
                min_overlap = MINOVERLAP
                if ovmax >= min_overlap:
                    # A match with a difficult object counts as neither TP nor FP.
                    if "difficult" not in gt_match:
                        if not bool(gt_match["used"]):
                            tp[idx] = 1
                            gt_match["used"] = True
                            count_true_positives[class_name] += 1
                            with open(gt_file, 'w') as f:
                                f.write(json.dumps(ground_truth_data))
                            if show_animation:
                                status = "MATCH!"
                        else:
                            fp[idx] = 1
                            if show_animation:
                                status = "REPEATED MATCH!"
                else:
                    fp[idx] = 1
                    if ovmax > 0:
                        status = "INSUFFICIENT OVERLAP"

                # Draw the image to show the animation.
                if show_animation:
                    height = img.shape[0]
                    white = (255,255,255)
                    light_blue = (255,200,100)
                    green = (0,255,0)
                    light_red = (30,30,255)
                    margin = 10
                    v_pos = int(height - margin - (bottom_border / 2.0))
                    text = "Image: " + ground_truth_img[0] + " "
                    img, line_width = draw_text_in_image(img, text, (margin, v_pos), white, 0)
                    text = "Class [" + str(class_index) + "/" + str(n_classes) + "]: " + class_name + " "
                    img, line_width = draw_text_in_image(img, text, (margin + line_width, v_pos), light_blue, line_width)
                    if ovmax != -1:
                        color = light_red
                        if status == "INSUFFICIENT OVERLAP":
                            text = "IoU: {0:.2f}% ".format(ovmax*100) + "< {0:.2f}% ".format(min_overlap*100)
                        else:
                            text = "IoU: {0:.2f}% ".format(ovmax*100) + ">= {0:.2f}% ".format(min_overlap*100)
                            color = green
                        img, _ = draw_text_in_image(img, text, (margin + line_width, v_pos), color, line_width)
                    v_pos += int(bottom_border / 2.0)
                    rank_pos = str(idx+1)
                    text = "Detection #rank: " + rank_pos + " confidence: {0:.2f}% ".format(float(detection["confidence"])*100)
                    img, line_width = draw_text_in_image(img, text, (margin, v_pos), white, 0)
                    color = light_red
                    if status == "MATCH!":
                        color = green
                    text = "Result: " + status + " "
                    img, line_width = draw_text_in_image(img, text, (margin + line_width, v_pos), color, line_width)

                    font = cv2.FONT_HERSHEY_SIMPLEX
                    if ovmax > 0:
                        bbgt = [ int(round(float(x))) for x in gt_match["bbox"].split() ]
                        cv2.rectangle(img,(bbgt[0],bbgt[1]),(bbgt[2],bbgt[3]),light_blue,2)
                        cv2.rectangle(img_cumulative,(bbgt[0],bbgt[1]),(bbgt[2],bbgt[3]),light_blue,2)
                        cv2.putText(img_cumulative, class_name, (bbgt[0],bbgt[1] - 5), font, 0.6, light_blue, 1, cv2.LINE_AA)
                    bb = [int(i) for i in bb]
                    cv2.rectangle(img,(bb[0],bb[1]),(bb[2],bb[3]),color,2)
                    cv2.rectangle(img_cumulative,(bb[0],bb[1]),(bb[2],bb[3]),color,2)
                    cv2.putText(img_cumulative, class_name, (bb[0],bb[1] - 5), font, 0.6, color, 1, cv2.LINE_AA)
                    cv2.imshow("Animation", img)
                    cv2.waitKey(20)
                    output_img_path = RESULTS_FILES_PATH + "/images/detections_one_by_one/" + class_name + "_detection" + str(idx) + ".jpg"
                    cv2.imwrite(output_img_path, img)
                    cv2.imwrite(img_cumulative_path, img_cumulative)

            # Turn the per-detection flags into cumulative counts.
            cumsum = 0
            for idx, val in enumerate(fp):
                fp[idx] += cumsum
                cumsum += val
            cumsum = 0
            for idx, val in enumerate(tp):
                tp[idx] += cumsum
                cumsum += val
            rec = tp[:]
            for idx, val in enumerate(tp):
                rec[idx] = float(tp[idx]) / np.maximum(gt_counter_per_class[class_name], 1)
            prec = tp[:]
            for idx, val in enumerate(tp):
                prec[idx] = float(tp[idx]) / np.maximum((fp[idx] + tp[idx]), 1)

            ap, mrec, mprec = voc_ap(rec[:], prec[:])
            F1 = np.array(rec)*np.array(prec)*2 / np.where((np.array(prec)+np.array(rec))==0, 1, (np.array(prec)+np.array(rec)))
            sum_AP += ap
            text = "{0:.2f}%".format(ap*100) + " = " + class_name + " AP "
            if len(prec)>0:
                F1_text = "{0:.2f}".format(F1[score_threhold_idx]) + " = " + class_name + " F1 "
                Recall_text = "{0:.2f}%".format(rec[score_threhold_idx]*100) + " = " + class_name + " Recall "
                Precision_text = "{0:.2f}%".format(prec[score_threhold_idx]*100) + " = " + class_name + " Precision "
            else:
                F1_text = "0.00" + " = " + class_name + " F1 "
                Recall_text = "0.00%" + " = " + class_name + " Recall "
                Precision_text = "0.00%" + " = " + class_name + " Precision "
            rounded_prec = [ '%.2f' % elem for elem in prec ]
            rounded_rec = [ '%.2f' % elem for elem in rec ]
            results_file.write(text + "\n Precision: " + str(rounded_prec) + "\n Recall :" + str(rounded_rec) + "\n\n")
            if len(prec)>0:
                print(text + "\t||\tscore_threhold=" + str(score_threhold) + " : " + "F1=" + "{0:.2f}".format(F1[score_threhold_idx])\
                    + " ; Recall=" + "{0:.2f}%".format(rec[score_threhold_idx]*100) + " ; Precision=" + "{0:.2f}%".format(prec[score_threhold_idx]*100))
            else:
                print(text + "\t||\tscore_threhold=" + str(score_threhold) + " : " + "F1=0.00% ; Recall=0.00% ; Precision=0.00%")
            ap_dictionary[class_name] = ap
            n_images = counter_images_per_class[class_name]
            lamr, _mr, _fppi = log_average_miss_rate(np.array(rec), np.array(fp), n_images)
            lamr_dictionary[class_name] = lamr

            if draw_plot:
                plt.plot(rec, prec, '-o')
                area_under_curve_x = mrec[:-1] + [mrec[-2]] + [mrec[-1]]
                area_under_curve_y = mprec[:-1] + [0.0] + [mprec[-1]]
                plt.fill_between(area_under_curve_x, 0, area_under_curve_y, alpha=0.2, edgecolor='r')
                fig = plt.gcf()
                plt.title('class: ' + text)
                plt.xlabel('Recall')
                plt.ylabel('Precision')
                axes = plt.gca()
                axes.set_xlim([0.0,1.0])
                axes.set_ylim([0.0,1.05])
                fig.savefig(RESULTS_FILES_PATH + "/AP/" + class_name + ".png")
                plt.cla()

                plt.plot(score, F1, "-", color='orangered')
                plt.title('class: ' + F1_text + "\nscore_threhold=" + str(score_threhold))
                plt.xlabel('Score_Threhold')
                plt.ylabel('F1')
                axes = plt.gca()
                axes.set_xlim([0.0,1.0])
                axes.set_ylim([0.0,1.05])
                fig.savefig(RESULTS_FILES_PATH + "/F1/" + class_name + ".png")
                plt.cla()

                plt.plot(score, rec, "-H", color='gold')
                plt.title('class: ' + Recall_text + "\nscore_threhold=" + str(score_threhold))
                plt.xlabel('Score_Threhold')
                plt.ylabel('Recall')
                axes = plt.gca()
                axes.set_xlim([0.0,1.0])
                axes.set_ylim([0.0,1.05])
                fig.savefig(RESULTS_FILES_PATH + "/Recall/" + class_name + ".png")
                plt.cla()

                plt.plot(score, prec, "-s", color='palevioletred')
                plt.title('class: ' + Precision_text + "\nscore_threhold=" + str(score_threhold))
                plt.xlabel('Score_Threhold')
                plt.ylabel('Precision')
                axes = plt.gca()
                axes.set_xlim([0.0,1.0])
                axes.set_ylim([0.0,1.05])
                fig.savefig(RESULTS_FILES_PATH + "/Precision/" + class_name + ".png")
                plt.cla()

        if show_animation:
            cv2.destroyAllWindows()
        if n_classes == 0:
            print("No species detected, check if the label information is modified with the classes_path in get_map.py.")
            return 0
        results_file.write("\n# mAP of all classes\n")
        mAP = sum_AP / n_classes
        text = "mAP = {0:.2f}%".format(mAP*100)
        results_file.write(text + "\n")
        print(text)

    shutil.rmtree(TEMP_FILES_PATH)

    # Count the total number of detection results per class.
    det_counter_per_class = {}
    for txt_file in dr_files_list:
        for line in file_lines_to_list(txt_file):
            class_name = line.split()[0]
            det_counter_per_class[class_name] = det_counter_per_class.get(class_name, 0) + 1
    dr_classes = list(det_counter_per_class.keys())

    # Write the number of ground-truth objects per class to results.txt.
    with open(RESULTS_FILES_PATH + "/results.txt", 'a') as results_file:
        results_file.write("\n# Number of ground-truth objects per class\n")
        for class_name in sorted(gt_counter_per_class):
            results_file.write(class_name + ": " + str(gt_counter_per_class[class_name]) + "\n")

    # Classes that were detected but have no ground truth get zero TPs.
    for class_name in dr_classes:
        if class_name not in gt_classes:
            count_true_positives[class_name] = 0

    # Write the number of detected objects per class to results.txt.
    with open(RESULTS_FILES_PATH + "/results.txt", 'a') as results_file:
        results_file.write("\n# Number of detected objects per class\n")
        for class_name in sorted(dr_classes):
            n_det = det_counter_per_class[class_name]
            text = class_name + ": " + str(n_det)
            text += " (tp:" + str(count_true_positives[class_name]) + ""
            text += ", fp:" + str(n_det - count_true_positives[class_name]) + ")\n"
            results_file.write(text)

    if draw_plot:
        # Total number of occurrences of each class in the ground truth.
        plot_title = "ground-truth\n"
        plot_title += "(" + str(len(ground_truth_files_list)) + " files and " + str(n_classes) + " classes)"
        draw_plot_func(
            gt_counter_per_class,
            n_classes,
            "ground-truth-info",
            plot_title,
            "Number of objects per class",
            RESULTS_FILES_PATH + "/ground-truth-info.png",
            False,
            'forestgreen',
            '',
        )
        # Log-average miss rate of all classes.
        draw_plot_func(
            lamr_dictionary,
            n_classes,
            "lamr",
            "log-average miss rate",
            "log-average miss rate",
            RESULTS_FILES_PATH + "/lamr.png",
            False,
            'royalblue',
            ""
        )
        # AP of all classes.
        draw_plot_func(
            ap_dictionary,
            n_classes,
            "mAP",
            "mAP = {0:.2f}%".format(mAP*100),
            "Average Precision",
            RESULTS_FILES_PATH + "/mAP.png",
            True,
            'royalblue',
            ""
        )
    return mAP
617
def preprocess_gt(gt_path, class_names):
    """Convert the ground-truth text files in ``gt_path`` to a COCO-style dict.

    Each line is ``class left top right bottom`` with an optional trailing
    token on lines that contain the substring ``difficult``. Boxes whose
    class is not in ``class_names`` are dropped.

    Returns:
        A dict with the keys ``images``, ``categories`` and ``annotations``.
    """
    images = []
    annotations = []
    for file_name in os.listdir(gt_path):
        lines_list = file_lines_to_list(os.path.join(gt_path, file_name))
        image_id = os.path.splitext(file_name)[0]
        # Width/height are written as the placeholder value 1.
        image = {
            'file_name': image_id + '.jpg',
            'width': 1,
            'height': 1,
            'id': str(image_id),
        }
        for line in lines_list:
            tokens = line.split()
            if "difficult" in line:
                left, top, right, bottom, _difficult = tokens[-5:]
                class_name = " ".join(tokens[:-5])
                difficult = 1
            else:
                left, top, right, bottom = tokens[-4:]
                class_name = " ".join(tokens[:-4])
                difficult = 0
            left, top, right, bottom = float(left), float(top), float(right), float(bottom)
            if class_name not in class_names:
                continue
            # Category ids are 1-based; annotation ids run across all images.
            annotations.append({
                'area': (right - left) * (bottom - top) - 10.0,
                'category_id': class_names.index(class_name) + 1,
                'image_id': str(image_id),
                'iscrowd': difficult,
                'bbox': [left, top, right - left, bottom - top],
                'id': len(annotations),
            })
        images.append(image)

    categories = [
        {'supercategory': cls, 'name': cls, 'id': index + 1}
        for index, cls in enumerate(class_names)
    ]
    return {'images': images, 'categories': categories, 'annotations': annotations}
677
def preprocess_dr(dr_path, class_names):
    """Convert the detection text files in ``dr_path`` to COCO result dicts.

    Each line is ``class confidence left top right bottom``; the class name
    may contain spaces. Detections whose class is not in ``class_names`` are
    dropped.

    Returns:
        A list of dicts with the keys ``image_id``, ``category_id``, ``bbox``
        (``[x, y, width, height]``) and ``score``.
    """
    results = []
    for file_name in os.listdir(dr_path):
        lines_list = file_lines_to_list(os.path.join(dr_path, file_name))
        image_id = os.path.splitext(file_name)[0]
        for line in lines_list:
            tokens = line.split()
            confidence, left, top, right, bottom = tokens[-5:]
            class_name = " ".join(tokens[:-5])
            left, top, right, bottom = float(left), float(top), float(right), float(bottom)
            if class_name not in class_names:
                continue
            results.append({
                "image_id": str(image_id),
                "category_id": class_names.index(class_name) + 1,
                "bbox": [left, top, right - left, bottom - top],
                "score": float(confidence),
            })
    return results
700
def get_coco_map(class_names, path):
    """Evaluate detections under ``path`` with the COCO bbox protocol.

    Ground truth and detections are converted to JSON files inside
    ``<path>/coco_eval`` and then scored with ``COCOeval``.

    Returns:
        ``cocoEval.stats``, or a list of twelve zeros when there is no
        detection to evaluate.
    """
    gt_dir = os.path.join(path, 'ground-truth')
    dr_dir = os.path.join(path, 'detection-results')
    coco_dir = os.path.join(path, 'coco_eval')
    if not os.path.exists(coco_dir):
        os.makedirs(coco_dir)

    gt_json = os.path.join(coco_dir, 'instances_gt.json')
    dr_json = os.path.join(coco_dir, 'instances_dr.json')

    with open(gt_json, "w") as gt_file:
        json.dump(preprocess_gt(gt_dir, class_names), gt_file, indent=4)

    with open(dr_json, "w") as dr_file:
        detections = preprocess_dr(dr_dir, class_names)
        json.dump(detections, dr_file, indent=4)
        if not len(detections):
            print("No targets detected.")
            return [0] * 12

    coco_gt = COCO(gt_json)
    coco_dt = coco_gt.loadRes(dr_json)
    evaluator = COCOeval(coco_gt, coco_dt, 'bbox')
    evaluator.evaluate()
    evaluator.accumulate()
    evaluator.summarize()
    return evaluator.stats