| | import numpy as np |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from torchsparse.tensor import PointTensor, SparseTensor |
| | import torchsparse.nn as spnn |
| |
|
| | from tsparse.modules import SparseCostRegNet |
| | from tsparse.torchsparse_utils import sparse_to_dense_channel |
| | from ops.grid_sampler import grid_sample_3d, tricubic_sample_3d |
| |
|
| | |
| | from ops.back_project import back_project_sparse_type |
| | from ops.generate_grids import generate_grid |
| |
|
| | from inplace_abn import InPlaceABN |
| |
|
| | from models.embedder import Embedding |
| | from models.featurenet import ConvBnReLU |
| |
|
| | import pdb |
| | import random |
| |
|
# Switch TorchScript to its non-profiling executor before any function below is
# scripted. Both calls change TorchScript's global (process-wide) configuration.
_JIT_LEGACY_EXECUTOR_SWITCHES = (
    torch._C._jit_set_profiling_executor,
    torch._C._jit_set_profiling_mode,
)
for _switch in _JIT_LEGACY_EXECUTOR_SWITCHES:
    _switch(False)
del _switch
| |
|
| |
|
@torch.jit.script
def fused_mean_variance(x, weight):
    """Weighted mean and weighted variance of ``x`` along dim 1.

    The weights are used as given (no normalization happens here), so the
    result is a true mean/variance only if ``weight`` sums to 1 along dim 1 --
    presumably ensured by the caller.

    Returns ``(mean, var)``, both keeping dim 1 with size 1.
    """
    weighted_x = x * weight
    mean = weighted_x.sum(dim=1, keepdim=True)
    squared_dev = (x - mean).pow(2)
    var = (weight * squared_dev).sum(dim=1, keepdim=True)
    return mean, var
| |
|
| |
|
| | class LatentSDFLayer(nn.Module): |
| | def __init__(self, |
| | d_in=3, |
| | d_out=129, |
| | d_hidden=128, |
| | n_layers=4, |
| | skip_in=(4,), |
| | multires=0, |
| | bias=0.5, |
| | geometric_init=True, |
| | weight_norm=True, |
| | activation='softplus', |
| | d_conditional_feature=16): |
| | super(LatentSDFLayer, self).__init__() |
| |
|
| | self.d_conditional_feature = d_conditional_feature |
| |
|
| | |
| | dims_in = [d_in] + [d_hidden + d_conditional_feature for _ in range(n_layers - 2)] + [d_hidden] |
| | dims_out = [d_hidden for _ in range(n_layers - 1)] + [d_out] |
| |
|
| | self.embed_fn_fine = None |
| |
|
| | if multires > 0: |
| | embed_fn = Embedding(in_channels=d_in, N_freqs=multires) |
| | self.embed_fn_fine = embed_fn |
| | dims_in[0] = embed_fn.out_channels |
| |
|
| | self.num_layers = n_layers |
| | self.skip_in = skip_in |
| |
|
| | for l in range(0, self.num_layers - 1): |
| | if l in self.skip_in: |
| | in_dim = dims_in[l] + dims_in[0] |
| | else: |
| | in_dim = dims_in[l] |
| |
|
| | out_dim = dims_out[l] |
| | lin = nn.Linear(in_dim, out_dim) |
| |
|
| | if geometric_init: |
| | if l == self.num_layers - 2: |
| | torch.nn.init.normal_(lin.weight, mean=np.sqrt(np.pi) / np.sqrt(in_dim), std=0.0001) |
| | torch.nn.init.constant_(lin.bias, -bias) |
| | |
| | torch.nn.init.constant_(lin.weight[:, -d_conditional_feature:], 0.0) |
| | torch.nn.init.constant_(lin.bias[-d_conditional_feature:], 0.0) |
| |
|
| | elif multires > 0 and l == 0: |
| | torch.nn.init.constant_(lin.bias, 0.0) |
| | |
| | torch.nn.init.constant_(lin.weight[:, 3:], 0.0) |
| | |
| | torch.nn.init.normal_(lin.weight[:, :3], 0.0, np.sqrt(2) / np.sqrt(out_dim)) |
| | elif multires > 0 and l in self.skip_in: |
| | torch.nn.init.constant_(lin.bias, 0.0) |
| | torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim)) |
| | |
| | torch.nn.init.constant_(lin.weight[:, -(dims_in[0] - 3 + d_conditional_feature):], 0.0) |
| | else: |
| | torch.nn.init.constant_(lin.bias, 0.0) |
| | torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim)) |
| | |
| | torch.nn.init.constant_(lin.weight[:, -d_conditional_feature:], 0.0) |
| |
|
| | if weight_norm: |
| | lin = nn.utils.weight_norm(lin) |
| |
|
| | setattr(self, "lin" + str(l), lin) |
| |
|
| | if activation == 'softplus': |
| | self.activation = nn.Softplus(beta=100) |
| | else: |
| | assert activation == 'relu' |
| | self.activation = nn.ReLU() |
| |
|
| | def forward(self, inputs, latent): |
| | inputs = inputs |
| | if self.embed_fn_fine is not None: |
| | inputs = self.embed_fn_fine(inputs) |
| |
|
| | |
| | if latent.shape[1] != self.d_conditional_feature: |
| | latent = torch.cat([latent, latent], dim=1) |
| |
|
| | x = inputs |
| | for l in range(0, self.num_layers - 1): |
| | lin = getattr(self, "lin" + str(l)) |
| |
|
| | |
| | if l in self.skip_in: |
| | x = torch.cat([x, inputs], 1) / np.sqrt(2) |
| |
|
| | if 0 < l < self.num_layers - 1: |
| | x = torch.cat([x, latent], 1) |
| |
|
| | x = lin(x) |
| |
|
| | if l < self.num_layers - 2: |
| | x = self.activation(x) |
| |
|
| | return x |
| |
|
| |
|
class SparseSdfNetwork(nn.Module):
    '''
    Coarse-to-fine sparse cost regularization network.

    Builds a sparse cost volume from multi-view image features, regularizes it
    with a sparse 3D network, and returns a dense conditional volume that the
    latent SDF MLP (``self.sdf_layer``) samples from.
    '''

    def __init__(self, lod, ch_in, voxel_size, vol_dims,
                 hidden_dim=128, activation='softplus',
                 cost_type='variance_mean',
                 d_pyramid_feature_compress=16,
                 regnet_d_out=8, num_sdf_layers=4,
                 multires=6,
                 ):
        super(SparseSdfNetwork, self).__init__()

        # Level index. lod 0 starts from the full grid; lod > 0 upsamples the prior level.
        self.lod = lod
        self.ch_in = ch_in
        self.voxel_size = voxel_size
        self.vol_dims = torch.tensor(vol_dims)

        self.selected_views_num = 2
        self.hidden_dim = hidden_dim
        self.activation = activation
        self.cost_type = cost_type
        self.d_pyramid_feature_compress = d_pyramid_feature_compress
        self.gru_fusion = None

        self.regnet_d_out = regnet_d_out
        self.multires = multires

        self.pos_embedder = Embedding(3, self.multires)

        self.compress_layer = ConvBnReLU(
            self.ch_in, self.d_pyramid_feature_compress, 3, 1, 1,
            norm_act=InPlaceABN)

        # aggregate_multiview_features emits [variance, mean], i.e. twice the compressed channels.
        sparse_ch_in = self.d_pyramid_feature_compress * 2
        # Finer levels also concatenate the previous level's features
        # (assumed to have 16 channels -- TODO confirm against the caller).
        sparse_ch_in = sparse_ch_in + 16 if self.lod > 0 else sparse_ch_in
        self.sparse_costreg_net = SparseCostRegNet(
            d_in=sparse_ch_in, d_out=self.regnet_d_out)

        if activation == 'softplus':
            self.activation = nn.Softplus(beta=100)
        else:
            assert activation == 'relu'
            self.activation = nn.ReLU()

        self.sdf_layer = LatentSDFLayer(d_in=3,
                                        d_out=self.hidden_dim + 1,
                                        d_hidden=self.hidden_dim,
                                        n_layers=num_sdf_layers,
                                        multires=multires,
                                        geometric_init=True,
                                        weight_norm=True,
                                        activation=activation,
                                        d_conditional_feature=16
                                        )

    def upsample(self, pre_feat, pre_coords, interval, num=8):
        '''
        Split every voxel of the previous level into ``num`` child voxels.

        :param pre_feat: (Tensor), features from last level, (N, C)
        :param pre_coords: (Tensor), coordinates from last level, (N, 4) (4 : Batch ind, x, y, z)
        :param interval: offset added to child coordinates along the chosen axes
        :param num: children per voxel (offsets are defined for at most 8)
        :return: up_feat : (Tensor), upsampled features, (N*num, C); children share the parent feature
        :return: up_coords: (N*num, 4), upsampled coordinates, (4 : Batch ind, x, y, z)
        '''
        with torch.no_grad():
            # Coordinate columns (1=x, 2=y, 3=z) shifted for children 1..7; child 0 keeps the parent coords.
            pos_list = [1, 2, 3, [1, 2], [1, 3], [2, 3], [1, 2, 3]]
            n, c = pre_feat.shape
            up_feat = pre_feat.unsqueeze(1).expand(-1, num, -1).contiguous()
            up_coords = pre_coords.unsqueeze(1).repeat(1, num, 1).contiguous()
            for i in range(num - 1):
                up_coords[:, i + 1, pos_list[i]] += interval

            up_feat = up_feat.view(-1, c)
            up_coords = up_coords.view(-1, 4)

        return up_feat, up_coords

    def aggregate_multiview_features(self, multiview_features, multiview_masks):
        """
        Aggregate multi-view features into per-voxel [variance, mean].

        Features are summed over all views, while the divisor is the number of
        views whose mask is set.

        :param multiview_features: (num of voxels, num_of_views, c)
        :param multiview_masks: (num of voxels, num_of_views)
        :return: (num of voxels, 2 * c): variance channels followed by mean channels
        :raises RuntimeError: if the summed features contain NaN
        """
        counts = torch.sum(multiview_masks, dim=1, keepdim=False)

        # Every voxel is expected to be visible in at least one view.
        assert torch.all(counts > 0)

        volume_sum = torch.sum(multiview_features, dim=1, keepdim=False)
        volume_sq_sum = torch.sum(multiview_features ** 2, dim=1, keepdim=False)

        # Fail loudly instead of dropping into an interactive debugger.
        if torch.isnan(volume_sum).any():
            raise RuntimeError(
                "aggregate_multiview_features: NaN found in the back-projected multi-view features")

        del multiview_features

        # Var = E[x^2] - E[x]^2; the 1e-5 keeps the division finite.
        inv_counts = (1. / (counts + 1e-5))[:, None]
        mean = volume_sum * inv_counts
        costvar = volume_sq_sum * inv_counts - mean ** 2

        costvar_mean = torch.cat([costvar, mean], dim=1)
        del volume_sum, volume_sq_sum, counts

        return costvar_mean

    def sparse_to_dense_volume(self, coords, feature, vol_dims, interval, device=None):
        """
        Convert the sparse volume into a dense volume to enable trilinear sampling.

        Thin wrapper around the module-level ``sparse_to_dense_volume`` helper,
        kept as a method for existing callers.

        :param coords: [num_pts, 3]
        :param feature: [num_pts, C]
        :param vol_dims: [3] dX, dY, dZ
        :param interval: spacing between voxels; coords and vol_dims are divided by it
        :param device: target device (defaults to ``feature.device``)
        :return: (dense_volume, valid_mask_volume)
        """
        return sparse_to_dense_volume(coords, feature, vol_dims, interval, device=device)

    def get_conditional_volume(self, feature_maps, partial_vol_origin, proj_mats, sizeH=None, sizeW=None, lod=0,
                               pre_coords=None, pre_feats=None,
                               ):
        """
        Build the regularized conditional volume for this level.

        :param feature_maps: pyramid features (B,V,C0+C1+C2,H,W) fused pyramid features
        :param partial_vol_origin: [B, 3] the world coordinates of the volume origin (0,0,0)
        :param proj_mats: projection matrix transform world pts into image space [B,V,4,4] suitable for original image size
        :param sizeH: the H of original image size
        :param sizeW: the W of original image size
        :param lod: unused; ``self.lod`` selects the level
        :param pre_coords: the coordinates of sparse volume from the prior lod (required when self.lod > 0)
        :param pre_feats: the features of sparse volume from the prior lod (required when self.lod > 0)
        :return: dict with 'dense_volume_scale{lod}', 'valid_mask_volume_scale{lod}',
                 'visible_mask_scale{lod}' and 'coords_scale{lod}'
        """
        device = proj_mats.device
        bs = feature_maps.shape[0]
        N_views = feature_maps.shape[1]
        # At lod 0, a voxel must be visible in more than this many views to be kept.
        minimum_visible_views = min(1, N_views - 1)

        outputs = {}

        # NOTE(review): only batch element 0's feature maps are used, although
        # coordinates are generated for every batch index below.
        if self.compress_layer is not None:
            feats = self.compress_layer(feature_maps[0])
        else:
            feats = feature_maps[0]
        feats = feats[:, None, :, :, :]
        KRcam = proj_mats.permute(1, 0, 2, 3).contiguous()
        interval = 1

        if self.lod == 0:
            # Start from the full voxel grid; each column is prefixed with its batch index.
            coords = generate_grid(self.vol_dims, 1)[0]
            coords = coords.view(3, -1).to(device)
            up_coords = []
            for b in range(bs):
                up_coords.append(torch.cat([torch.ones(1, coords.shape[-1]).to(coords.device) * b, coords]))
            up_coords = torch.cat(up_coords, dim=1).permute(1, 0).contiguous()

            # Keep voxels whose per-view visibility mask sums above minimum_visible_views.
            frustum_mask = back_project_sparse_type(
                up_coords, partial_vol_origin, self.voxel_size,
                feats, KRcam, sizeH=sizeH, sizeW=sizeW, only_mask=True)
            frustum_mask = torch.sum(frustum_mask, dim=-1) > minimum_visible_views
            up_coords = up_coords[frustum_mask]

        else:
            assert pre_feats is not None
            assert pre_coords is not None
            up_feat, up_coords = self.upsample(pre_feats, pre_coords, 1)

        multiview_features, multiview_masks = back_project_sparse_type(
            up_coords, partial_vol_origin, self.voxel_size, feats,
            KRcam, sizeH=sizeH, sizeW=sizeW)

        if self.lod > 0:
            # Finer levels keep only voxels visible in more than one view.
            frustum_mask = torch.sum(multiview_masks, dim=-1) > 1
            up_feat = up_feat[frustum_mask]
            up_coords = up_coords[frustum_mask]
            multiview_features = multiview_features[frustum_mask]
            multiview_masks = multiview_masks[frustum_mask]

        volume = self.aggregate_multiview_features(multiview_features, multiview_masks)

        del multiview_features, multiview_masks

        if self.lod != 0:
            feat = torch.cat([volume, up_feat], dim=1)
        else:
            feat = volume

        # Move the batch index last: (x, y, z, batch), presumably the layout torchsparse expects.
        r_coords = up_coords[:, [1, 2, 3, 0]]

        sparse_feat = SparseTensor(feat, r_coords.to(torch.int32))
        feat = self.sparse_costreg_net(sparse_feat)

        dense_volume, valid_mask_volume = self.sparse_to_dense_volume(up_coords[:, 1:], feat, self.vol_dims, interval,
                                                                      device=None)

        outputs['dense_volume_scale%d' % self.lod] = dense_volume
        outputs['valid_mask_volume_scale%d' % self.lod] = valid_mask_volume
        outputs['visible_mask_scale%d' % self.lod] = valid_mask_volume
        outputs['coords_scale%d' % self.lod] = generate_grid(self.vol_dims, interval).to(device)

        return outputs

    def sdf(self, pts, conditional_volume, lod):
        """
        Query the SDF MLP at points using latents sampled from the conditional volume.

        :param pts: [num_pts, 3] query points, in the coordinate frame grid_sample_3d expects (TODO confirm)
        :param conditional_volume: [1, C, dX, dY, dZ] dense volume
        :param lod: level index used only to name the output keys
        :return: dict with 'sdf_pts_scale{lod}' [num_pts, 1], 'sdf_features_pts_scale{lod}'
                 and 'sampled_latent_scale{lod}'
        """
        num_pts = pts.shape[0]
        device = pts.device
        pts_ = pts.clone()
        pts = pts.view(1, 1, 1, num_pts, 3)

        # Reverse xyz -> zyx for grid sampling (assumed F.grid_sample-style coordinate order).
        pts = torch.flip(pts, dims=[-1])

        sampled_feature = grid_sample_3d(conditional_volume, pts)
        sampled_feature = sampled_feature.view(-1, num_pts).permute(1, 0).contiguous().to(device)

        sdf_pts = self.sdf_layer(pts_, sampled_feature)

        outputs = {}
        outputs['sdf_pts_scale%d' % lod] = sdf_pts[:, :1]
        outputs['sdf_features_pts_scale%d' % lod] = sdf_pts[:, 1:]
        outputs['sampled_latent_scale%d' % lod] = sampled_feature

        return outputs

    @torch.no_grad()
    def sdf_from_sdfvolume(self, pts, sdf_volume, lod=0):
        """
        Trilinearly sample a precomputed SDF volume at ``pts``.

        :param pts: [num_pts, 3] query points in the grid_sample normalized frame (align_corners=True)
        :param sdf_volume: [1, 1, dX, dY, dZ]
        :return: dict with 'sdf_pts_scale{lod}' [num_pts, 1]
        """
        num_pts = pts.shape[0]
        device = pts.device
        pts = pts.view(1, 1, 1, num_pts, 3)

        # Reverse xyz -> zyx to match F.grid_sample's (W, H, D) coordinate order.
        pts = torch.flip(pts, dims=[-1])

        sdf = torch.nn.functional.grid_sample(sdf_volume, pts, mode='bilinear', align_corners=True,
                                              padding_mode='border')
        sdf = sdf.view(-1, num_pts).permute(1, 0).contiguous().to(device)

        outputs = {}
        outputs['sdf_pts_scale%d' % lod] = sdf

        return outputs

    @torch.no_grad()
    def get_sdf_volume(self, conditional_volume, mask_volume, coords_volume, partial_origin):
        """
        Evaluate the SDF MLP at every masked voxel center, in chunks.

        :param conditional_volume: [1,C, dX,dY,dZ]
        :param mask_volume: [1,1, dX,dY,dZ]; voxels with mask <= 0 keep the default SDF of 1
        :param coords_volume: [1,3, dX,dY,dZ] voxel indices, scaled by voxel_size and offset by partial_origin
        :param partial_origin: world-space origin added to the scaled coordinates
        :return: [1, 1, dX, dY, dZ] SDF volume
        """
        device = conditional_volume.device
        chunk_size = 10240

        _, C, dX, dY, dZ = conditional_volume.shape
        conditional_volume = conditional_volume.view(C, dX * dY * dZ).permute(1, 0).contiguous()
        mask_volume = mask_volume.view(-1)
        coords_volume = coords_volume.view(3, dX * dY * dZ).permute(1, 0).contiguous()

        pts = coords_volume * self.voxel_size + partial_origin

        sdf_volume = torch.ones([dX * dY * dZ, 1]).float().to(device)

        conditional_volume = conditional_volume[mask_volume > 0]
        pts = pts[mask_volume > 0]
        conditional_volume = conditional_volume.split(chunk_size)
        pts = pts.split(chunk_size)

        sdf_all = []
        for pts_part, feature_part in zip(pts, conditional_volume):
            sdf_part = self.sdf_layer(pts_part, feature_part)[:, :1]
            sdf_all.append(sdf_part)

        sdf_all = torch.cat(sdf_all, dim=0)
        sdf_volume[mask_volume > 0] = sdf_all
        sdf_volume = sdf_volume.view(1, 1, dX, dY, dZ)
        return sdf_volume

    def gradient(self, x, conditional_volume, lod):
        """
        Return d(sdf)/d(x) for the given lod, keeping the graph for higher-order use.

        :param x: [num_pts, 3] query points; requires_grad is enabled on it in place
        :param conditional_volume: volume passed through to ``sdf``
        :param lod: level index used to select the output key
        :return: [num_pts, 1, 3] gradients
        """
        x.requires_grad_(True)
        # enable_grad so this also works when called under torch.no_grad().
        with torch.enable_grad():
            output = self.sdf(x, conditional_volume, lod)
            y = output['sdf_pts_scale%d' % lod]

            d_output = torch.ones_like(y, requires_grad=False, device=y.device)
            gradients = torch.autograd.grad(
                outputs=y,
                inputs=x,
                grad_outputs=d_output,
                create_graph=True,
                retain_graph=True,
                only_inputs=True)[0]
        return gradients.unsqueeze(1)
| |
|
| |
|
def sparse_to_dense_volume(coords, feature, vol_dims, interval, device=None):
    """
    Convert a sparse volume into a dense volume so it can be sampled trilinearly.

    :param coords: [num_pts, 3] voxel coordinates
    :param feature: [num_pts, C] per-voxel features
    :param vol_dims: [3] tensor dX, dY, dZ
    :param interval: spacing between voxels; coords and vol_dims are divided by it
    :param device: device to build the dense volumes on (defaults to ``feature.device``)
    :return: (dense_volume, valid_mask_volume). Both get a channel-first permute plus a
             leading batch dim, i.e. [1, C, ...] and [1, 1, ...], assuming
             sparse_to_dense_channel returns a channel-last grid.
    """
    if device is None:
        device = feature.device

    coords_int = (coords / interval).to(torch.int64).to(device)
    vol_dims = (vol_dims / interval).to(torch.int64).to(device)

    dense_volume = sparse_to_dense_channel(
        coords_int, feature.to(device), vol_dims,
        feature.shape[1], 0, device)

    # Build the occupancy values on the target device. Previously they were placed on
    # feature.device, which could differ from ``device`` when one was passed explicitly.
    valid_mask_volume = sparse_to_dense_channel(
        coords_int,
        torch.ones([feature.shape[0], 1], device=device),
        vol_dims,
        1, 0, device)

    dense_volume = dense_volume.permute(3, 0, 1, 2).contiguous().unsqueeze(0)
    valid_mask_volume = valid_mask_volume.permute(3, 0, 1, 2).contiguous().unsqueeze(0)

    return dense_volume, valid_mask_volume
| |
|
| |
|
class SdfVolume(nn.Module):
    """Hold a tensor as a trainable parameter; calling the module returns it.

    :param volume: tensor to optimize; wrapped in ``nn.Parameter`` with gradients enabled
    :param coords: optional companion data, stored as-is and not used by this class
    :param type: free-form tag stored on the instance (e.g. 'dense' or 'sparse')
    """

    def __init__(self, volume, coords=None, type='dense'):
        super().__init__()
        trainable = nn.Parameter(volume, requires_grad=True)
        self.volume = trainable
        self.coords = coords
        self.type = type

    def forward(self):
        volume = self.volume
        return volume
| |
|
| |
|
class FinetuneOctreeSdfNetwork(nn.Module):
    '''
    After obtaining the conditional volume from the generalized network,
    directly optimize that conditional volume.
    The conditional volume is still stored sparsely (``self.sparse_volume_lod0``).
    '''

    def __init__(self, voxel_size, vol_dims,
                 origin=(-1., -1., -1.),
                 hidden_dim=128, activation='softplus',
                 regnet_d_out=8,
                 multires=6,
                 if_fitted_rendering=True,
                 num_sdf_layers=4,
                 ):
        super(FinetuneOctreeSdfNetwork, self).__init__()

        self.voxel_size = voxel_size
        self.vol_dims = torch.tensor(vol_dims)

        # Tuple default avoids a shared mutable list; torch.tensor accepts either.
        self.origin = torch.tensor(origin).to(torch.float32)

        self.hidden_dim = hidden_dim
        self.activation = activation

        self.regnet_d_out = regnet_d_out

        self.if_fitted_rendering = if_fitted_rendering
        self.multires = multires

        # Filled in by initialize_conditional_volumes().
        self.sparse_volume_lod0 = None
        self.sparse_coords_lod0 = None

        if activation == 'softplus':
            self.activation = nn.Softplus(beta=100)
        else:
            assert activation == 'relu'
            self.activation = nn.ReLU()

        self.sdf_layer = LatentSDFLayer(d_in=3,
                                        d_out=self.hidden_dim + 1,
                                        d_hidden=self.hidden_dim,
                                        n_layers=num_sdf_layers,
                                        multires=multires,
                                        geometric_init=True,
                                        weight_norm=True,
                                        activation=activation,
                                        d_conditional_feature=16
                                        )

        self.renderer = None

        d_in_renderer = 3 + self.regnet_d_out + 3 + 3
        self.renderer = BlendingRenderingNetwork(
            d_feature=self.hidden_dim - 1,
            mode='idr',
            d_in=d_in_renderer,
            d_out=50,
            d_hidden=self.hidden_dim,
            n_layers=3,
            weight_norm=True,
            multires_view=4,
            squeeze_out=True,
        )

    def initialize_conditional_volumes(self, dense_volume_lod0, dense_volume_mask_lod0,
                                       sparse_volume_lod0=None, sparse_coords_lod0=None):
        """
        Store the lod-0 conditional volume sparsely as a trainable SdfVolume.

        :param dense_volume_lod0: [1,C,dX,dY,dZ]; used only when sparse_volume_lod0 is None
        :param dense_volume_mask_lod0: [1,1,dX,dY,dZ]; voxels with mask > 0 are kept
        :param sparse_volume_lod0: optional pre-sparsified features, [num_valid, C]
        :param sparse_coords_lod0: optional coordinates matching sparse_volume_lod0,
                                   [num_valid, 3], same layout as the grid built below
        :return: None
        """
        if sparse_volume_lod0 is None:
            device = dense_volume_lod0.device
            _, C, dX, dY, dZ = dense_volume_lod0.shape

            dense_volume_lod0 = dense_volume_lod0.view(C, dX * dY * dZ).permute(1, 0).contiguous()
            mask_lod0 = dense_volume_mask_lod0.view(dX * dY * dZ) > 0

            self.sparse_volume_lod0 = SdfVolume(dense_volume_lod0[mask_lod0], type='sparse')

            coords = generate_grid(self.vol_dims, 1)[0]
            coords = coords.view(3, dX * dY * dZ).permute(1, 0).to(device)
            self.sparse_coords_lod0 = torch.nn.Parameter(coords[mask_lod0], requires_grad=False)
        else:
            self.sparse_volume_lod0 = SdfVolume(sparse_volume_lod0, type='sparse')
            self.sparse_coords_lod0 = torch.nn.Parameter(sparse_coords_lod0, requires_grad=False)

    def get_conditional_volume(self):
        """
        Densify the trainable sparse volume.

        :return: dict with 'dense_volume_scale0' and 'valid_mask_volume_scale0'
        """
        dense_volume, valid_mask_volume = sparse_to_dense_volume(
            self.sparse_coords_lod0,
            self.sparse_volume_lod0(), self.vol_dims, interval=1,
            device=None)

        outputs = {}
        outputs['dense_volume_scale%d' % 0] = dense_volume
        outputs['valid_mask_volume_scale%d' % 0] = valid_mask_volume

        return outputs

    def tv_regularizer(self):
        """
        Total-variation penalty on the densified conditional volume.

        Per cell, sqrt(dx^2 + dy^2 + dz^2 + 1e-6) is computed from forward differences,
        averaged over channels, and zeroed unless the voxel and its +x/+y/+z
        neighbours are all valid. The mean is taken over all cells, masked ones included.
        """
        dense_volume, valid_mask_volume = sparse_to_dense_volume(
            self.sparse_coords_lod0,
            self.sparse_volume_lod0(), self.vol_dims, interval=1,
            device=None)

        dx = (dense_volume[:, :, 1:, :, :] - dense_volume[:, :, :-1, :, :]) ** 2
        dy = (dense_volume[:, :, :, 1:, :] - dense_volume[:, :, :, :-1, :]) ** 2
        dz = (dense_volume[:, :, :, :, 1:] - dense_volume[:, :, :, :, :-1]) ** 2

        tv = dx[:, :, :, :-1, :-1] + dy[:, :, :-1, :, :-1] + dz[:, :, :-1, :-1, :]

        mask = valid_mask_volume[:, :, :-1, :-1, :-1] * valid_mask_volume[:, :, 1:, :-1, :-1] * \
            valid_mask_volume[:, :, :-1, 1:, :-1] * valid_mask_volume[:, :, :-1, :-1, 1:]

        tv = torch.sqrt(tv + 1e-6).mean(dim=1, keepdim=True) * mask

        assert torch.all(~torch.isnan(tv))

        return torch.mean(tv)

    def sdf(self, pts, conditional_volume, lod):
        """
        Query the SDF MLP at ``pts`` using latents sampled from ``conditional_volume``.

        :param pts: [num_pts, 3] query points
        :param conditional_volume: [1, C, dX, dY, dZ] dense volume
        :param lod: names the 'sampled_latent_scale{lod}' key only; the SDF keys are
                    always written with scale 0 (this network has a single level)
        :return: dict with 'sampled_latent_scale{lod}', 'sdf_pts_scale0', 'sdf_features_pts_scale0'
        """
        outputs = {}

        num_pts = pts.shape[0]
        pts_ = pts.clone()
        pts = pts.view(1, 1, 1, num_pts, 3)

        # Reverse xyz -> zyx for grid sampling (assumed F.grid_sample-style coordinate order).
        pts = torch.flip(pts, dims=[-1])

        sampled_feature = grid_sample_3d(conditional_volume, pts)
        sampled_feature = sampled_feature.view(-1, num_pts).permute(1, 0).contiguous()
        outputs['sampled_latent_scale%d' % lod] = sampled_feature

        sdf_pts = self.sdf_layer(pts_, sampled_feature)

        lod = 0
        outputs['sdf_pts_scale%d' % lod] = sdf_pts[:, :1]
        outputs['sdf_features_pts_scale%d' % lod] = sdf_pts[:, 1:]

        return outputs

    def color_blend(self, pts, position, normals, view_dirs, feature_vectors, img_index,
                    pts_pixel_color, pts_pixel_mask, pts_patch_color=None, pts_patch_mask=None):
        """Blend per-view colors via the renderer; ``pts`` and ``position`` are concatenated as its position input."""
        return self.renderer(torch.cat([pts, position], dim=-1), normals, view_dirs, feature_vectors,
                             img_index, pts_pixel_color, pts_pixel_mask,
                             pts_patch_color=pts_patch_color, pts_patch_mask=pts_patch_mask)

    def gradient(self, x, conditional_volume, lod):
        """
        Return d(sdf)/d(x), keeping the graph for higher-order use.

        :param x: [num_pts, 3] query points; requires_grad is enabled on it in place
        :param conditional_volume: volume passed through to ``sdf``
        :param lod: passed through to ``sdf``; the SDF is always read from 'sdf_pts_scale0'
        :return: [num_pts, 1, 3] gradients
        """
        x.requires_grad_(True)
        # enable_grad so this also works when called under torch.no_grad(),
        # matching SparseSdfNetwork.gradient.
        with torch.enable_grad():
            output = self.sdf(x, conditional_volume, lod)
            y = output['sdf_pts_scale%d' % 0]

            d_output = torch.ones_like(y, requires_grad=False, device=y.device)

            gradients = torch.autograd.grad(
                outputs=y,
                inputs=x,
                grad_outputs=d_output,
                create_graph=True,
                retain_graph=True,
                only_inputs=True)[0]
        return gradients.unsqueeze(1)

    @torch.no_grad()
    def prune_dense_mask(self, threshold=0.02):
        """
        Gradually prune the dense-volume mask to reduce the number of SDF network evaluations.

        Keeps voxels within a 7x7x7 neighbourhood of any voxel whose |sdf| < threshold.
        The result is stored in ``self.dense_volume_mask_lod0``.

        The optional ``vol_dims_lod0`` / ``voxel_size_lod0`` / ``dense_volume_mask_lod0``
        attributes are used when they have been assigned. Otherwise this falls back to
        ``self.vol_dims``, ``self.voxel_size`` and the sparse volume's valid mask,
        since this class never sets those attributes itself.
        """
        chunk_size = 10240
        vol_dims = getattr(self, 'vol_dims_lod0', self.vol_dims)
        voxel_size = getattr(self, 'voxel_size_lod0', self.voxel_size)

        dense_volume, valid_mask_volume = sparse_to_dense_volume(
            self.sparse_coords_lod0,
            self.sparse_volume_lod0(), vol_dims, interval=1,
            device=None)
        device = dense_volume.device
        n_channels = dense_volume.shape[1]

        coords = generate_grid(vol_dims, 1)[0]
        _, dX, dY, dZ = coords.shape

        # Keep every tensor on the volume's device so boolean masking works.
        pts = coords.view(3, -1).permute(1, 0).contiguous().to(device) * voxel_size \
            + self.origin[None, :].to(device)

        dense_mask = getattr(self, 'dense_volume_mask_lod0', None)
        if dense_mask is None:
            dense_mask = valid_mask_volume
        dense_mask = dense_mask.to(device)

        # Voxels outside the mask keep a large SDF so they are pruned.
        sdf_volume = torch.ones([dX * dY * dZ, 1], device=device).float() * 100

        mask = dense_mask.view(-1) > 0

        pts_valid = pts[mask]
        feature_valid = dense_volume.view(n_channels, -1).permute(1, 0).contiguous()[mask]

        pts_valid = pts_valid.split(chunk_size)
        feature_valid = feature_valid.split(chunk_size)

        sdf_list = []
        for pts_part, feature_part in zip(pts_valid, feature_valid):
            sdf_part = self.sdf_layer(pts_part, feature_part)[:, :1]
            sdf_list.append(sdf_part)

        sdf_list = torch.cat(sdf_list, dim=0)

        sdf_volume[mask] = sdf_list

        occupancy_mask = torch.abs(sdf_volume) < threshold

        # Dilate the near-surface voxels by a 7x7x7 window.
        occupancy_mask = occupancy_mask.float()
        occupancy_mask = occupancy_mask.view(1, 1, dX, dY, dZ)
        occupancy_mask = F.avg_pool3d(occupancy_mask, kernel_size=7, stride=1, padding=3)
        occupancy_mask = occupancy_mask > 0

        self.dense_volume_mask_lod0 = torch.logical_and(dense_mask, occupancy_mask).float()
| |
|
| |
|
class BlendingRenderingNetwork(nn.Module):
    """Predict per-view blending weights and blend per-view pixel/patch colors.

    An MLP maps the rendering input to ``d_out`` logits. The logits at
    ``img_index`` are softmax-normalized across views, multiplied by the
    visibility masks, renormalized, and used to average the per-view colors.
    """

    def __init__(
            self,
            d_feature,
            mode,
            d_in,
            d_out,
            d_hidden,
            n_layers,
            weight_norm=True,
            multires_view=0,
            squeeze_out=True,
    ):
        super(BlendingRenderingNetwork, self).__init__()

        self.mode = mode
        self.squeeze_out = squeeze_out  # stored; not used by this class
        dims = [d_in + d_feature] + [d_hidden for _ in range(n_layers)] + [d_out]

        self.embedder = None
        if multires_view > 0:
            self.embedder = Embedding(3, multires_view)
            # View directions are replaced by their embedding, so widen the input accordingly.
            dims[0] += (self.embedder.out_channels - 3)

        self.num_layers = len(dims)

        for l in range(0, self.num_layers - 1):
            out_dim = dims[l + 1]
            lin = nn.Linear(dims[l], out_dim)

            if weight_norm:
                lin = nn.utils.weight_norm(lin)

            setattr(self, "lin" + str(l), lin)

        self.relu = nn.ReLU()

        self.color_volume = None

        self.softmax = nn.Softmax(dim=1)

        self.type = 'blending'

    def sample_pts_from_colorVolume(self, pts):
        """
        Sample ``self.color_volume`` at ``pts``.

        :param pts: [num_pts, 3]
        :return: [num_pts, C] sampled values; ``self.color_volume`` must have been set by the caller
        """
        device = pts.device
        num_pts = pts.shape[0]
        pts = pts.view(1, 1, 1, num_pts, 3)

        # Reverse xyz -> zyx for grid sampling (assumed F.grid_sample-style coordinate order).
        pts = torch.flip(pts, dims=[-1])

        sampled_color = grid_sample_3d(self.color_volume, pts)
        sampled_color = sampled_color.view(-1, num_pts).permute(1, 0).contiguous().to(device)

        return sampled_color

    def forward(self, position, normals, view_dirs, feature_vectors, img_index,
                pts_pixel_color, pts_pixel_mask, pts_patch_color=None, pts_patch_mask=None):
        """
        :param position: can be 3d coord or interpolated volume latent
        :param normals:
        :param view_dirs: embedded first when an embedder is configured
        :param feature_vectors:
        :param img_index: [N_views], used to extract corresponding weights
        :param pts_pixel_color: [N_pts, N_views, 3]
        :param pts_pixel_mask: [N_pts, N_views]
        :param pts_patch_color: [N_pts, N_views, Npx, 3]
        :param pts_patch_mask: [N_pts, N_views, Npx]; a patch counts only if all Npx entries are set
        :return: (final_pixel_color [N_pts, 3], final_pixel_mask [N_pts, 1],
                  final_patch_color [N_pts, Npx, 3] or None, final_patch_mask [N_pts, 1] or None)
        :raises ValueError: if ``self.mode`` is not a supported rendering mode
        """
        if self.embedder is not None:
            view_dirs = self.embedder(view_dirs)

        if self.mode == 'idr':
            rendering_input = torch.cat([position, view_dirs, normals, feature_vectors], dim=-1)
        elif self.mode == 'no_view_dir':
            rendering_input = torch.cat([position, normals, feature_vectors], dim=-1)
        elif self.mode == 'no_normal':
            rendering_input = torch.cat([position, view_dirs, feature_vectors], dim=-1)
        elif self.mode == 'no_points':
            rendering_input = torch.cat([view_dirs, normals, feature_vectors], dim=-1)
        elif self.mode == 'no_points_no_view_dir':
            rendering_input = torch.cat([normals, feature_vectors], dim=-1)
        else:
            raise ValueError("Unknown rendering mode: %s" % self.mode)

        x = rendering_input

        for l in range(0, self.num_layers - 1):
            lin = getattr(self, "lin" + str(l))

            x = lin(x)

            if l < self.num_layers - 2:
                x = self.relu(x)

        # Pick the logits of the views actually supplied, then softmax across them once;
        # the pixel and patch branches both start from these weights.
        x_extracted = torch.index_select(x, 1, img_index.long())
        view_weights = self.softmax(x_extracted)

        weights_pixel = view_weights * pts_pixel_mask
        weights_pixel = weights_pixel / (
                torch.sum(weights_pixel.float(), dim=1, keepdim=True) + 1e-8)
        final_pixel_color = torch.sum(pts_pixel_color * weights_pixel[:, :, None], dim=1,
                                      keepdim=False)

        final_pixel_mask = torch.sum(pts_pixel_mask.float(), dim=1, keepdim=True) > 0

        final_patch_color, final_patch_mask = None, None

        if pts_patch_color is not None:
            N_pts, N_views, Npx, _ = pts_patch_color.shape
            # A view's patch is usable only when every one of its Npx pixels is valid.
            patch_mask = torch.sum(pts_patch_mask, dim=-1, keepdim=False) > Npx - 1

            weights_patch = view_weights * patch_mask
            weights_patch = weights_patch / (
                    torch.sum(weights_patch.float(), dim=1, keepdim=True) + 1e-8)

            final_patch_color = torch.sum(pts_patch_color * weights_patch[:, :, None, None], dim=1,
                                          keepdim=False)
            final_patch_mask = torch.sum(patch_mask, dim=1, keepdim=True) > 0

        return final_pixel_color, final_pixel_mask, final_patch_color, final_patch_mask
| |
|