# File size: 3,597 Bytes
# c1651d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import numpy as np
import argparse
import glob
import os
import sys
import torch
import cv2
import random
import time
import multiprocessing.pool as mpp
import multiprocessing as mp
# Seed value intended for seed_everything(); in the visible part of this file
# it is only referenced by the commented-out script entry point.
SEED = 66

def seed_everything(seed):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Integer seed applied to every generator.

    Side effects:
        Sets ``PYTHONHASHSEED`` in ``os.environ`` and configures cuDNN for
        deterministic behaviour (deterministic kernels, auto-tuning off).
    """
    random.seed(seed)
    # NOTE: hash randomisation of the running interpreter is decided at
    # start-up, so this assignment can only influence processes that inherit
    # the environment afterwards.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # Cover every visible GPU, not just the current device.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode lets cuDNN choose kernels by timing them, which may pick
    # different algorithms from run to run; it must be off for reproducibility.
    torch.backends.cudnn.benchmark = False
def label2rgb(mask, mask_pred):
    """Colour-code how a predicted binary mask agrees with a reference mask.

    Each pixel is compared between ``mask`` and ``mask_pred`` and painted:

    * ``[255, 255, 255]`` where both are 1 (TP)
    * ``[0, 0, 0]``       where both are 0 (TN)
    * ``[0, 255, 0]``     where ``mask`` is 1 and ``mask_pred`` is 0 (FN)
    * ``[255, 0, 0]``     where ``mask`` is 0 and ``mask_pred`` is 1 (FP)

    Pixels holding any value other than 0 or 1 keep the initial ``[0, 0, 0]``.

    Args:
        mask: Reference label array; its first two dimensions give the
            output height and width.
        mask_pred: Predicted label array compared element-wise with ``mask``.

    Returns:
        A ``uint8`` array of shape ``(h, w, 3)``.
    """
    truth_pos, truth_neg = (mask == 1), (mask == 0)
    guess_pos, guess_neg = (mask_pred == 1), (mask_pred == 0)

    # (selection mask, colour) pairs, applied in this order.
    palette = (
        (np.logical_and(truth_pos, guess_pos), [255, 255, 255]),  # TP
        (np.logical_and(truth_neg, guess_neg), [0, 0, 0]),        # TN
        (np.logical_and(truth_pos, guess_neg), [0, 255, 0]),      # FN
        (np.logical_and(truth_neg, guess_pos), [255, 0, 0]),      # FP
    )

    height, width = mask.shape[0], mask.shape[1]
    canvas = np.zeros(shape=(height, width, 3), dtype=np.uint8)
    for region, colour in palette:
        canvas[region] = colour

    return canvas

def parse_args():
    """Parse the command-line options of this script.

    Recognised options (all optional, string-valued):
        --dataset           default "Vaihingen"
        --mask-dir          default "data/Test/masks"
        --output-mask-dir   default "data/Test/masks_rgb"

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    cli = argparse.ArgumentParser()
    option_defaults = (
        ("--dataset", "Vaihingen"),
        ("--mask-dir", "data/Test/masks"),
        ("--output-mask-dir", "data/Test/masks_rgb"),
    )
    for flag, fallback in option_defaults:
        cli.add_argument(flag, default=fallback)
    return cli.parse_args()

def mask_save(inp):
    """Render a mask/prediction comparison image and write it as a PNG.

    Args:
        inp: A 4-tuple ``(mask, mask_pred, masks_output_dir, file_name)``.
            ``mask`` and ``mask_pred`` are passed (as copies) to
            ``label2rgb``; the image is written to
            ``<masks_output_dir>/<file_name>.png``.
    """
    mask, mask_pred, masks_output_dir, file_name = inp
    png_name = "{}.png".format(file_name)
    destination = os.path.join(masks_output_dir, png_name)

    # Work on copies so the caller's arrays are never touched.
    comparison = label2rgb(mask.copy(), mask_pred.copy())

    # Apply the BGR<->RGB channel conversion before handing the image to imwrite.
    cv2.imwrite(destination, cv2.cvtColor(comparison, cv2.COLOR_BGR2RGB))

# def get_rgb(inp):
#     (mask_path, masks_output_dir,dataset) = inp
#     mask_filename = os.path.splitext(os.path.basename(mask_path))[0]
#     mask_bgr = cv2.imread(mask_path, cv2.IMREAD_UNCHANGED)
#     mask = cv2.cvtColor(mask_bgr, cv2.COLOR_BGR2RGB)
#     if dataset == "LoveDA":
#         rgb_label = loveda_label2rgb(mask.copy())
#     elif dataset == "Vaihingen":
#         rgb_label = vaihingen_label2rgb(mask.copy())
#     elif dataset == "Potsdam":
#         rgb_label = potsdam_label2rgb(mask.copy())
#     elif dataset == "uavid":
#         rgb_label = uavid_label2rgb(mask.copy())
#     else: return
#     #rgb_label = cv2.cvtColor(rgb_label, cv2.COLOR_RGB2BGR)

#     out_mask_path_rgb = os.path.join(masks_output_dir, "{}.png".format(mask_filename))
#     rgb_label = cv2.cvtColor(rgb_label, cv2.COLOR_BGR2RGB)
#     cv2.imwrite(out_mask_path_rgb, rgb_label)

# if __name__ == '__main__':
#     base_path = "/home/xwma/lrr/rssegmentation/"
#     args = parse_args()
#     dataset = args.dataset

#     seed_everything(SEED)
#     masks_dir = args.mask_dir
#     masks_output_dir = args.output_mask_dir
#     masks_dir = base_path + masks_dir
#     masks_output_dir = base_path + masks_output_dir

#     mask_paths = glob.glob(os.path.join(masks_dir, "*.png"))
#     inp = [(mask_path, masks_output_dir, dataset) for mask_path in mask_paths]
#     if not os.path.exists(masks_output_dir):
#         os.makedirs(masks_output_dir)

#     t0 = time.time()
#     mpp.Pool(processes=mp.cpu_count()).map(get_rgb, inp)
#     t1 = time.time()
#     split_time = t1 - t0
#     print('images spliting spends: {} s'.format(split_time))