# Source snapshot: commit 405d2b1, file size 4,906 bytes
import cv2
import numpy as np
import torch
from matplotlib import cm
import matplotlib.pyplot as plt
import logging
logger = logging.getLogger('root')
def tensor_to_numpy(tensor_in):
    """ torch tensor to numpy array (channel-first -> channel-last)

    The tensor is detached from the autograd graph and moved to the CPU
    before conversion.

    Args:
        tensor_in: a 3-D tensor treated as (C, H, W), a 4-D tensor treated
            as (B, C, H, W), or None.

    Returns:
        A numpy array laid out as (H, W, C) or (B, H, W, C) respectively,
        or None when ``tensor_in`` is None.

    Raises:
        ValueError: if the tensor is neither 3-D nor 4-D.
    """
    # None is passed through so optional inputs can be converted blindly.
    if tensor_in is None:
        return None

    if tensor_in.ndim == 3:
        # (C, H, W) -> (H, W, C)
        axes = (1, 2, 0)
    elif tensor_in.ndim == 4:
        # (B, C, H, W) -> (B, H, W, C)
        axes = (0, 2, 3, 1)
    else:
        # ValueError is a subclass of Exception, so callers that caught the
        # previous generic Exception keep working.
        raise ValueError(
            'invalid tensor size: expected 3 or 4 dimensions, got %d'
            % tensor_in.ndim
        )

    return tensor_in.detach().cpu().permute(*axes).numpy()
# def unnormalize(img_in, img_stats={'mean': [0.485, 0.456, 0.406],
# 'std': [0.229, 0.224, 0.225]}):
def unnormalize(img_in, img_stats={'mean': [0.5,0.5,0.5], 'std': [0.5,0.5,0.5]}):
    """ convert an image to uint8 in [0, 255], undoing normalization if needed

    Torch tensors are first converted with ``tensor_to_numpy``. The value
    range of the input decides what happens next:

    * min >= -0.1 and max <= 1.1: the image is treated as already being in
      [0, 1] and is only clipped and rescaled.
    * otherwise: the first three channels (last axis) are multiplied by
      ``img_stats['std']`` and shifted by ``img_stats['mean']``; the result
      is then clipped to [0, 1] and rescaled.

    Returns a uint8 array with the same shape as the (converted) input.
    """
    if torch.is_tensor(img_in):
        img_in = tensor_to_numpy(img_in)

    # Inspect the value range to decide whether un-normalization is required.
    lowest, highest = img_in.min(), img_in.max()

    # The small margin tolerates floating-point error around [0, 1].
    looks_like_unit_range = lowest >= -0.1 and highest <= 1.1

    if looks_like_unit_range:
        # Already in [0, 1]: only the final clip / rescale is applied.
        restored = img_in
    else:
        # Values in [-1, 1] or another normalized range: apply the
        # standard per-channel x * std + mean.
        stds, means = img_stats['std'], img_stats['mean']
        restored = np.zeros_like(img_in)
        for ch in range(3):
            restored[..., ch] = img_in[..., ch] * stds[ch]
            restored[..., ch] += means[ch]

    return (np.clip(restored, 0, 1) * 255.0).astype(np.uint8)
def normal_to_rgb(normal, normal_mask=None):
    """ surface normal map to RGB
        (used for visualization)

        NOTE: x, y, z are mapped to R, G, B
        NOTE: [-1, 1] are mapped to [0, 255]

    Args:
        normal: normal map with the vector components on the last axis
            once converted to numpy. A torch tensor is converted with
            ``tensor_to_numpy`` first.
        normal_mask: optional mask multiplied into the RGB output; it must
            broadcast against the RGB array. Converted with
            ``tensor_to_numpy`` when ``normal`` is a tensor and the mask is
            a tensor as well.

    Returns:
        uint8 array with the same shape as the (converted) normal map.
    """
    if torch.is_tensor(normal):
        normal = tensor_to_numpy(normal)
        # Only convert the mask when it really is a tensor, so a numpy mask
        # paired with a tensor normal map no longer crashes the conversion.
        if torch.is_tensor(normal_mask):
            normal_mask = tensor_to_numpy(normal_mask)

    # Re-normalize to unit length; the floor avoids division by zero for
    # degenerate (all-zero) vectors.
    normal_norm = np.linalg.norm(normal, axis=-1, keepdims=True)
    normal_norm = np.maximum(normal_norm, 1e-12)
    normal = normal / normal_norm

    normal_rgb = (((normal + 1) * 0.5) * 255).astype(np.uint8)

    if normal_mask is not None:
        # A float mask promotes the product to float; cast back so the
        # result is always a uint8 image.          # (B, H, W, 3)
        normal_rgb = (normal_rgb * normal_mask).astype(np.uint8)

    return normal_rgb
def kappa_to_alpha(pred_kappa, to_numpy=True):
    """ Confidence kappa to uncertainty alpha (returned in degrees)
        Assuming AngMF distribution (introduced in https://arxiv.org/abs/2109.09881)

        alpha = 2k / (k^2 + 1) + pi * exp(-k * pi) / (1 + exp(-k * pi))

        If ``pred_kappa`` is a torch tensor and ``to_numpy`` is True it is
        first converted with ``tensor_to_numpy``. Input that is still a
        tensor afterwards is evaluated with torch ops and stays a tensor;
        anything else is evaluated with numpy.
    """
    if to_numpy and torch.is_tensor(pred_kappa):
        pred_kappa = tensor_to_numpy(pred_kappa)

    rational_term = (2 * pred_kappa) / ((pred_kappa ** 2.0) + 1)

    if torch.is_tensor(pred_kappa):
        decay = torch.exp(- pred_kappa * np.pi)
        alpha_rad = rational_term + ((decay * np.pi) / (1 + decay))
        return torch.rad2deg(alpha_rad)

    decay = np.exp(- pred_kappa * np.pi)
    alpha_rad = rational_term + ((decay * np.pi) / (1 + decay))
    return np.degrees(alpha_rad)
def visualize_normal(target_dir, prefixs, img, pred_norm, pred_kappa,
                     gt_norm, gt_norm_mask, pred_error, num_vis=-1):
    """ save normal predictions (and masked error maps) as PNG files

    For each of the first ``num_vis`` samples (all of ``prefixs`` when
    ``num_vis`` is -1) this writes:

    * ``<target_dir>/<prefix>_norm.png``: predicted normals as RGB
    * ``<target_dir>/<prefix>_pred_error.png``: ``pred_error`` multiplied
      by ``gt_norm_mask``, 'jet' colormap over [0, 60]; only written when
      ``gt_norm`` is not None

    ``pred_norm``, ``gt_norm``, ``gt_norm_mask`` and ``pred_error`` are
    converted with ``tensor_to_numpy`` before use.

    NOTE: ``img`` and ``pred_kappa`` are accepted but not used at present;
    the input image, predicted-uncertainty and ground-truth normal renders
    are disabled.
    """
    error_max = 60.0

    # Channel-first tensors -> channel-last numpy arrays (None passes through).
    pred_norm, gt_norm, gt_norm_mask, pred_error = (
        tensor_to_numpy(item)
        for item in (pred_norm, gt_norm, gt_norm_mask, pred_error)
    )

    if num_vis == -1:
        num_vis = len(prefixs)

    for idx in range(num_vis):
        prefix = prefixs[idx]

        # predicted normals
        norm_path = '%s/%s_norm.png' % (target_dir, prefix)
        plt.imsave(norm_path, normal_to_rgb(pred_norm[idx, ...]))

        # everything below needs ground truth
        if gt_norm is None:
            continue

        # Path of the ground-truth render; its save call is disabled, so
        # nothing is written to it.
        gt_path = '%s/%s_gt.png' % (target_dir, prefix)

        # masked error map
        masked_error = pred_error[idx, :, :, 0] * gt_norm_mask[idx, :, :, 0]
        error_path = '%s/%s_pred_error.png' % (target_dir, prefix)
        plt.imsave(error_path, masked_error, vmin=0, vmax=error_max, cmap='jet')
|