InPeerReview's picture
Upload 9 files
c1651d2 verified
import numpy as np
import argparse
import glob
import os
import sys
import torch
import cv2
import random
import time
import multiprocessing.pool as mpp
import multiprocessing as mp
# Default seed value for seed_everything(); the commented-out entry point
# further down in this file passes it as seed_everything(SEED).
SEED = 66
def seed_everything(seed):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy and PyTorch (CPU and all CUDA
    devices), exports ``PYTHONHASHSEED`` and configures cuDNN for
    deterministic execution.

    Args:
        seed: Integer seed applied to every generator.
    """
    random.seed(seed)
    # PYTHONHASHSEED is read at interpreter start-up, so this only influences
    # interpreters launched after this point (e.g. spawned worker processes).
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not just the current one; silently ignored
    # when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode lets cuDNN auto-tune and pick different convolution
    # kernels from run to run, which defeats the deterministic setting above,
    # so it has to be disabled for reproducible results.
    torch.backends.cudnn.benchmark = False
def label2rgb(mask, mask_pred):
    """Colour-code a binary prediction against its reference mask.

    Every pixel is classified by comparing ``mask`` (reference) with
    ``mask_pred`` (prediction), where 1 is the positive class and 0 the
    negative class, and painted as follows:

    * true positive  -> [255, 255, 255]
    * true negative  -> [0, 0, 0]
    * false negative -> [0, 255, 0]
    * false positive -> [255, 0, 0]

    Pixels whose value in either input is neither 0 nor 1 match none of the
    four categories and keep the initial [0, 0, 0].

    Args:
        mask: Reference label array; its first two dimensions give the
            output height and width.
        mask_pred: Predicted label array, compared element-wise with ``mask``.

    Returns:
        A ``uint8`` array of shape ``(h, w, 3)`` holding the colour map.
    """
    gt_pos, gt_neg = (mask == 1), (mask == 0)
    pred_pos, pred_neg = (mask_pred == 1), (mask_pred == 0)

    # Pair each outcome's boolean selector with the colour it is drawn in.
    outcomes = (
        (np.logical_and(gt_pos, pred_pos), [255, 255, 255]),  # TP
        (np.logical_and(gt_neg, pred_neg), [0, 0, 0]),  # TN
        (np.logical_and(gt_pos, pred_neg), [0, 255, 0]),  # FN
        (np.logical_and(gt_neg, pred_pos), [255, 0, 0]),  # FP
    )

    rows, cols = mask.shape[0], mask.shape[1]
    canvas = np.zeros(shape=(rows, cols, 3), dtype=np.uint8)
    for selector, colour in outcomes:
        canvas[selector] = colour
    return canvas
def parse_args():
    """Parse the script's command-line options.

    Recognised options and their defaults:

    * ``--dataset``          (default ``"Vaihingen"``)
    * ``--mask-dir``         (default ``"data/Test/masks"``)
    * ``--output-mask-dir``  (default ``"data/Test/masks_rgb"``)

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``, with the
        attributes ``dataset``, ``mask_dir`` and ``output_mask_dir``.
    """
    option_defaults = (
        ("--dataset", "Vaihingen"),
        ("--mask-dir", "data/Test/masks"),
        ("--output-mask-dir", "data/Test/masks_rgb"),
    )
    cli = argparse.ArgumentParser()
    for flag, fallback in option_defaults:
        cli.add_argument(flag, default=fallback)
    return cli.parse_args()
def mask_save(inp):
    """Render one prediction-vs-reference colour map and write it as a PNG.

    Args:
        inp: A 4-tuple ``(mask, mask_pred, masks_output_dir, file_name)``.
            ``mask`` and ``mask_pred`` must provide ``.copy()`` and are
            passed to ``label2rgb``; the image is written to
            ``<masks_output_dir>/<file_name>.png``.
    """
    reference, prediction, output_dir, stem = inp
    target_path = os.path.join(output_dir, "{}.png".format(stem))
    colour_map = label2rgb(reference.copy(), prediction.copy())
    # COLOR_BGR2RGB swaps the first and third channels. cv2.imwrite interprets
    # arrays as BGR, so this presumably treats label2rgb's output as RGB.
    cv2.imwrite(target_path, cv2.cvtColor(colour_map, cv2.COLOR_BGR2RGB))
# def get_rgb(inp):
# (mask_path, masks_output_dir,dataset) = inp
# mask_filename = os.path.splitext(os.path.basename(mask_path))[0]
# mask_bgr = cv2.imread(mask_path, cv2.IMREAD_UNCHANGED)
# mask = cv2.cvtColor(mask_bgr, cv2.COLOR_BGR2RGB)
# if dataset == "LoveDA":
# rgb_label = loveda_label2rgb(mask.copy())
# elif dataset == "Vaihingen":
# rgb_label = vaihingen_label2rgb(mask.copy())
# elif dataset == "Potsdam":
# rgb_label = potsdam_label2rgb(mask.copy())
# elif dataset == "uavid":
# rgb_label = uavid_label2rgb(mask.copy())
# else: return
# #rgb_label = cv2.cvtColor(rgb_label, cv2.COLOR_RGB2BGR)
# out_mask_path_rgb = os.path.join(masks_output_dir, "{}.png".format(mask_filename))
# rgb_label = cv2.cvtColor(rgb_label, cv2.COLOR_BGR2RGB)
# cv2.imwrite(out_mask_path_rgb, rgb_label)
# if __name__ == '__main__':
# base_path = "/home/xwma/lrr/rssegmentation/"
# args = parse_args()
# dataset = args.dataset
# seed_everything(SEED)
# masks_dir = args.mask_dir
# masks_output_dir = args.output_mask_dir
# masks_dir = base_path + masks_dir
# masks_output_dir = base_path + masks_output_dir
# mask_paths = glob.glob(os.path.join(masks_dir, "*.png"))
# inp = [(mask_path, masks_output_dir, dataset) for mask_path in mask_paths]
# if not os.path.exists(masks_output_dir):
# os.makedirs(masks_output_dir)
# t0 = time.time()
# mpp.Pool(processes=mp.cpu_count()).map(get_rgb, inp)
# t1 = time.time()
# split_time = t1 - t0
# print('images spliting spends: {} s'.format(split_time))