| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torchsparse.tensor import PointTensor, SparseTensor |
| import torchsparse.nn as spnn |
|
|
| from tsparse.modules import SparseCostRegNet |
| from tsparse.torchsparse_utils import sparse_to_dense_channel |
| from ops.grid_sampler import grid_sample_3d, tricubic_sample_3d |
|
|
| |
| from ops.back_project import back_project_sparse_type |
| from ops.generate_grids import generate_grid |
|
|
| from inplace_abn import InPlaceABN |
|
|
| from models.embedder import Embedding |
| from models.featurenet import ConvBnReLU |
|
|
| import pdb |
| import random |
|
|
# Switch off the TorchScript profiling executor and profiling mode for the
# whole process. These are private ``torch._C`` toggles and run at import
# time, before the scripted helper below is defined.
# NOTE(review): presumably this avoids the profiling warm-up /
# re-specialisation of scripted functions -- confirm for the torch version
# this project targets.
torch._C._jit_set_profiling_executor(False)
torch._C._jit_set_profiling_mode(False)
|
|
|
|
@torch.jit.script
def fused_mean_variance(x, weight):
    """
    Weighted mean and weighted variance of ``x`` along dimension 1.

    :param x: values to aggregate; must broadcast against ``weight``
    :param weight: per-element weights (assumed to sum to 1 along dim 1 --
                   TODO confirm with callers; they are not normalised here)
    :return: ``(mean, var)``, both reduced over dim 1 with ``keepdim=True``
    """
    weighted_mean = (x * weight).sum(dim=1, keepdim=True)
    centered = x - weighted_mean
    weighted_var = (weight * centered ** 2).sum(dim=1, keepdim=True)
    return weighted_mean, weighted_var
|
|
|
|
class LatentSDFLayer(nn.Module):
    """
    MLP that decodes a 3D point together with a per-point latent code.

    The point (optionally passed through a frequency ``Embedding``) feeds the
    first linear layer; the latent code is concatenated to the input of every
    later layer.

    NOTE(review): only ``n_layers - 1`` linear layers are created, so the last
    entry of the output-width list (``d_out``) is never used and the tensor
    returned by ``forward`` is ``d_hidden`` wide.
    """

    def __init__(self,
                 d_in=3,
                 d_out=129,
                 d_hidden=128,
                 n_layers=4,
                 skip_in=(4,),
                 multires=0,
                 bias=0.5,
                 geometric_init=True,
                 weight_norm=True,
                 activation='softplus',
                 d_conditional_feature=16):
        """
        :param d_in: width of the raw point input
        :param d_out: nominal output width (see the class note: not reached)
        :param d_hidden: width of the hidden layers
        :param n_layers: length of the width lists; ``n_layers - 1`` linears are built
        :param skip_in: layer indices whose input is concatenated with the (embedded) point
        :param multires: number of embedding frequencies; 0 disables the embedding
        :param bias: magnitude of the (negative) bias used for the last layer's geometric init
        :param geometric_init: apply the geometric initialisation scheme below
        :param weight_norm: wrap every linear layer in ``nn.utils.weight_norm``
        :param activation: ``'softplus'`` or ``'relu'``
        :param d_conditional_feature: width of the latent code fed to the layers
        """
        super().__init__()

        self.d_conditional_feature = d_conditional_feature

        # Input/output width of every layer slot; hidden layers also receive the latent code.
        hidden_with_latent = d_hidden + d_conditional_feature
        dims_in = [d_in] + [hidden_with_latent] * (n_layers - 2) + [d_hidden]
        dims_out = [d_hidden] * (n_layers - 1) + [d_out]

        self.embed_fn_fine = None
        if multires > 0:
            positional_embedding = Embedding(in_channels=d_in, N_freqs=multires)
            self.embed_fn_fine = positional_embedding
            dims_in[0] = positional_embedding.out_channels

        self.num_layers = n_layers
        self.skip_in = skip_in

        last_layer = self.num_layers - 2
        for layer_idx in range(self.num_layers - 1):
            fan_in = dims_in[layer_idx]
            if layer_idx in self.skip_in:
                # Skip layers additionally see the (embedded) point.
                fan_in += dims_in[0]
            fan_out = dims_out[layer_idx]

            linear = nn.Linear(fan_in, fan_out)

            if geometric_init:
                if layer_idx == last_layer:
                    nn.init.normal_(linear.weight, mean=np.sqrt(np.pi) / np.sqrt(fan_in), std=0.0001)
                    nn.init.constant_(linear.bias, -bias)
                    # The latent columns start at zero so the latent has no initial influence.
                    nn.init.constant_(linear.weight[:, -d_conditional_feature:], 0.0)
                    nn.init.constant_(linear.bias[-d_conditional_feature:], 0.0)
                else:
                    hidden_std = np.sqrt(2) / np.sqrt(fan_out)
                    nn.init.constant_(linear.bias, 0.0)
                    if multires > 0 and layer_idx == 0:
                        # Only the first three input columns start non-zero;
                        # every column from index 3 on is zeroed.
                        nn.init.constant_(linear.weight[:, 3:], 0.0)
                        nn.init.normal_(linear.weight[:, :3], 0.0, hidden_std)
                    else:
                        nn.init.normal_(linear.weight, 0.0, hidden_std)
                        if multires > 0 and layer_idx in self.skip_in:
                            # Zero the non-xyz embedding columns plus the latent columns.
                            zeroed_cols = dims_in[0] - 3 + d_conditional_feature
                        else:
                            # NOTE(review): with multires == 0 layer 0 lands here; if
                            # d_in <= d_conditional_feature this slice covers every
                            # column and the whole first-layer weight becomes zero.
                            zeroed_cols = d_conditional_feature
                        nn.init.constant_(linear.weight[:, -zeroed_cols:], 0.0)

            if weight_norm:
                linear = nn.utils.weight_norm(linear)

            setattr(self, "lin" + str(layer_idx), linear)

        if activation != 'softplus':
            assert activation == 'relu'
            self.activation = nn.ReLU()
        else:
            self.activation = nn.Softplus(beta=100)

    def forward(self, inputs, latent):
        """
        :param inputs: query points, one row per point
        :param latent: per-point latent code, one row per point; when its width
                       differs from ``d_conditional_feature`` it is concatenated
                       with itself once along dim 1
        :return: output of the last linear layer (no activation applied to it)
        """
        if self.embed_fn_fine is not None:
            inputs = self.embed_fn_fine(inputs)

        if latent.shape[1] != self.d_conditional_feature:
            latent = torch.cat([latent, latent], dim=1)

        n_linear = self.num_layers - 1
        hidden = inputs
        for layer_idx in range(n_linear):
            if layer_idx in self.skip_in:
                hidden = torch.cat([hidden, inputs], 1) / np.sqrt(2)

            # Every layer after the first is conditioned on the latent code.
            if layer_idx > 0:
                hidden = torch.cat([hidden, latent], 1)

            hidden = getattr(self, "lin" + str(layer_idx))(hidden)

            if layer_idx < n_linear - 1:
                hidden = self.activation(hidden)

        return hidden
|
|
|
|
class SparseSdfNetwork(nn.Module):
    '''
    Coarse-to-fine sparse cost regularization network.

    Builds a sparse cost volume from multi-view image features, regularizes it
    with ``SparseCostRegNet`` and returns a dense feature volume from which
    ``LatentSDFLayer`` decodes SDF values at query points.
    '''

    def __init__(self, lod, ch_in, voxel_size, vol_dims,
                 hidden_dim=128, activation='softplus',
                 cost_type='variance_mean',
                 d_pyramid_feature_compress=16,
                 regnet_d_out=8, num_sdf_layers=4,
                 multires=6,
                 ):
        """
        :param lod: level of detail handled by this network (0 is the coarsest)
        :param ch_in: channel count of the incoming image features
        :param voxel_size: edge length of one voxel
        :param vol_dims: number of voxels per axis
        :param hidden_dim: hidden width of the SDF decoder
        :param activation: ``'softplus'`` or ``'relu'``
        :param cost_type: stored on the instance; not read inside this class
        :param d_pyramid_feature_compress: channels after the feature compression layer
        :param regnet_d_out: output channels of the sparse regularization network
        :param num_sdf_layers: ``n_layers`` of the SDF decoder
        :param multires: number of embedding frequencies
        """
        super(SparseSdfNetwork, self).__init__()

        self.lod = lod
        self.ch_in = ch_in
        self.voxel_size = voxel_size
        self.vol_dims = torch.tensor(vol_dims)

        self.selected_views_num = 2
        self.hidden_dim = hidden_dim
        self.cost_type = cost_type
        self.d_pyramid_feature_compress = d_pyramid_feature_compress
        self.gru_fusion = None

        self.regnet_d_out = regnet_d_out
        self.multires = multires

        self.pos_embedder = Embedding(3, self.multires)

        self.compress_layer = ConvBnReLU(
            self.ch_in, self.d_pyramid_feature_compress, 3, 1, 1,
            norm_act=InPlaceABN)

        # The aggregated cost holds variance and mean, hence twice the feature width.
        sparse_ch_in = self.d_pyramid_feature_compress * 2
        # Finer levels also receive the features of the previous level.
        # NOTE(review): 16 is presumably the width of those upsampled features -- confirm.
        sparse_ch_in = sparse_ch_in + 16 if self.lod > 0 else sparse_ch_in
        self.sparse_costreg_net = SparseCostRegNet(
            d_in=sparse_ch_in, d_out=self.regnet_d_out)

        if activation == 'softplus':
            self.activation = nn.Softplus(beta=100)
        else:
            assert activation == 'relu'
            self.activation = nn.ReLU()

        self.sdf_layer = LatentSDFLayer(d_in=3,
                                        d_out=self.hidden_dim + 1,
                                        d_hidden=self.hidden_dim,
                                        n_layers=num_sdf_layers,
                                        multires=multires,
                                        geometric_init=True,
                                        weight_norm=True,
                                        activation=activation,
                                        d_conditional_feature=16
                                        )

    def upsample(self, pre_feat, pre_coords, interval, num=8):
        '''
        Replicate every voxel ``num`` times and offset the copies along the axes.

        Runs under ``torch.no_grad()``, so the returned features carry no
        gradient back to ``pre_feat``.

        :param pre_feat: (Tensor), features from last level, (N, C)
        :param pre_coords: (Tensor), coordinates from last level, (N, 4) (4 : Batch ind, x, y, z)
        :param interval: offset added to the shifted coordinate columns
        :param num: 1 -> 8
        :return: up_feat : (Tensor), upsampled features, (N*8, C)
        :return: up_coords: (N*8, 4), upsampled coordinates, (4 : Batch ind, x, y, z)
        '''
        with torch.no_grad():
            # Coordinate columns shifted for copies 1..7; copy 0 keeps the original position.
            pos_list = [1, 2, 3, [1, 2], [1, 3], [2, 3], [1, 2, 3]]
            _, c = pre_feat.shape
            up_feat = pre_feat.unsqueeze(1).expand(-1, num, -1).contiguous()
            up_coords = pre_coords.unsqueeze(1).repeat(1, num, 1).contiguous()
            for i in range(num - 1):
                up_coords[:, i + 1, pos_list[i]] += interval

            up_feat = up_feat.view(-1, c)
            up_coords = up_coords.view(-1, 4)

        return up_feat, up_coords

    def aggregate_multiview_features(self, multiview_features, multiview_masks):
        """
        Aggregate multi-view features by computing their cost variance and mean.

        :param multiview_features: (num of voxels, num_of_views, c)
        :param multiview_masks: (num of voxels, num_of_views)
        :return: (num of voxels, 2 * c) tensor, variance first and mean second
        :raises RuntimeError: if the summed features contain NaN
        """
        counts = torch.sum(multiview_masks, dim=1, keepdim=False)
        assert torch.all(counts > 0), "every voxel must be visible in at least one view"

        volume_sum = torch.sum(multiview_features, dim=1, keepdim=False)
        volume_sq_sum = torch.sum(multiview_features ** 2, dim=1, keepdim=False)

        # Fail loudly instead of opening an interactive debugger, which would
        # hang any unattended run.
        if torch.isnan(volume_sum).any():
            raise RuntimeError("NaN encountered while aggregating multi-view features")

        # The epsilon keeps the reciprocal finite; Var[x] = E[x^2] - E[x]^2.
        inv_counts = (1. / (counts + 1e-5))[:, None]
        mean = volume_sum * inv_counts
        costvar = volume_sq_sum * inv_counts - mean ** 2

        return torch.cat([costvar, mean], dim=1)

    def sparse_to_dense_volume(self, coords, feature, vol_dims, interval, device=None):
        """
        convert the sparse volume into dense volume to enable trilinear sampling
        to save GPU memory;
        :param coords: [num_pts, 3]
        :param feature: [num_pts, C]
        :param vol_dims: [3] dX, dY, dZ
        :param interval: voxel stride; coordinates and dimensions are divided by it
        :param device: target device of the dense tensors; defaults to ``feature.device``
        :return: ``(dense_volume, valid_mask_volume)``, each with a leading
                 batch axis and the channel axis moved to position 1
        """
        if device is None:
            device = feature.device

        coords_int = (coords / interval).to(torch.int64)
        vol_dims = (vol_dims / interval).to(torch.int64)

        dense_volume = sparse_to_dense_channel(
            coords_int.to(device), feature.to(device), vol_dims.to(device),
            feature.shape[1], 0, device)

        # The occupancy values must live on the same device as the dense buffer.
        occupancy = torch.ones([feature.shape[0], 1], device=device)
        valid_mask_volume = sparse_to_dense_channel(
            coords_int.to(device),
            occupancy,
            vol_dims.to(device),
            1, 0, device)

        dense_volume = dense_volume.permute(3, 0, 1, 2).contiguous().unsqueeze(0)
        valid_mask_volume = valid_mask_volume.permute(3, 0, 1, 2).contiguous().unsqueeze(0)

        return dense_volume, valid_mask_volume

    def get_conditional_volume(self, feature_maps, partial_vol_origin, proj_mats, sizeH=None, sizeW=None, lod=0,
                               pre_coords=None, pre_feats=None,
                               ):
        """
        Build the regularized dense feature volume for this network's level.

        :param feature_maps: pyramid features (B,V,C0+C1+C2,H,W) fused pyramid features
        :param partial_vol_origin: [B, 3] the world coordinates of the volume origin (0,0,0)
        :param proj_mats: projection matrix transform world pts into image space [B,V,4,4] suitable for original image size
        :param sizeH: the H of original image size
        :param sizeW: the W of original image size
        :param lod: unused; the level is taken from ``self.lod``
        :param pre_coords: the coordinates of sparse volume from the prior lod
        :param pre_feats: the features of sparse volume from the prior lod
        :return: dict with ``dense_volume_scale{lod}``, ``valid_mask_volume_scale{lod}``,
                 ``visible_mask_scale{lod}`` and ``coords_scale{lod}``
        """
        device = proj_mats.device
        bs = feature_maps.shape[0]
        N_views = feature_maps.shape[1]
        minimum_visible_views = np.min([1, N_views - 1])

        outputs = {}

        # NOTE(review): only batch element 0 of ``feature_maps`` is used here,
        # while the voxel grid below is replicated ``bs`` times -- confirm that
        # callers always pass a batch of one.
        if self.compress_layer is not None:
            feats = self.compress_layer(feature_maps[0])
        else:
            feats = feature_maps[0]
        feats = feats[:, None, :, :, :]
        KRcam = proj_mats.permute(1, 0, 2, 3).contiguous()
        interval = 1

        if self.lod == 0:
            # Coarsest level: start from the full grid, one copy per batch element,
            # with the batch index prepended to each coordinate.
            coords = generate_grid(self.vol_dims, 1)[0]
            coords = coords.view(3, -1).to(device)
            up_coords = [
                torch.cat([torch.ones(1, coords.shape[-1]).to(coords.device) * b, coords])
                for b in range(bs)
            ]
            up_coords = torch.cat(up_coords, dim=1).permute(1, 0).contiguous()

            # Keep only voxels seen by more than ``minimum_visible_views`` views.
            frustum_mask = back_project_sparse_type(
                up_coords, partial_vol_origin, self.voxel_size,
                feats, KRcam, sizeH=sizeH, sizeW=sizeW, only_mask=True)
            frustum_mask = torch.sum(frustum_mask, dim=-1) > minimum_visible_views
            up_coords = up_coords[frustum_mask]
        else:
            # Finer level: refine the voxels kept by the previous level.
            assert pre_feats is not None
            assert pre_coords is not None
            up_feat, up_coords = self.upsample(pre_feats, pre_coords, 1)

        multiview_features, multiview_masks = back_project_sparse_type(
            up_coords, partial_vol_origin, self.voxel_size, feats,
            KRcam, sizeH=sizeH, sizeW=sizeW)

        if self.lod > 0:
            # Drop upsampled voxels visible in fewer than two views.
            frustum_mask = torch.sum(multiview_masks, dim=-1) > 1
            up_feat = up_feat[frustum_mask]
            up_coords = up_coords[frustum_mask]
            multiview_features = multiview_features[frustum_mask]
            multiview_masks = multiview_masks[frustum_mask]

        volume = self.aggregate_multiview_features(multiview_features, multiview_masks)

        # Release the per-view tensors before running the sparse network.
        del multiview_features, multiview_masks

        if self.lod != 0:
            feat = torch.cat([volume, up_feat], dim=1)
        else:
            feat = volume

        # Reorder the columns from (batch, x, y, z) to (x, y, z, batch) for SparseTensor.
        r_coords = up_coords[:, [1, 2, 3, 0]]

        sparse_feat = SparseTensor(feat, r_coords.to(
            torch.int32))

        feat = self.sparse_costreg_net(sparse_feat)

        dense_volume, valid_mask_volume = self.sparse_to_dense_volume(up_coords[:, 1:], feat, self.vol_dims, interval,
                                                                      device=None)

        outputs['dense_volume_scale%d' % self.lod] = dense_volume
        outputs['valid_mask_volume_scale%d' % self.lod] = valid_mask_volume
        outputs['visible_mask_scale%d' % self.lod] = valid_mask_volume
        outputs['coords_scale%d' % self.lod] = generate_grid(self.vol_dims, interval).to(device)

        return outputs

    def sdf(self, pts, conditional_volume, lod):
        """
        Decode SDF values and features at ``pts``.

        :param pts: [num_pts, 3] query points
        :param conditional_volume: dense feature volume the latent codes are sampled from
        :param lod: level index, used only to name the output keys
        :return: dict with ``sdf_pts_scale{lod}`` (first output column),
                 ``sdf_features_pts_scale{lod}`` (remaining columns) and
                 ``sampled_latent_scale{lod}``
        """
        num_pts = pts.shape[0]
        device = pts.device
        pts_ = pts.clone()
        pts = pts.view(1, 1, 1, num_pts, 3)

        # Reverse the coordinate order of every point before sampling.
        # NOTE(review): presumably the sampler expects the last volume axis
        # first, as ``F.grid_sample`` does -- confirm for ``grid_sample_3d``.
        pts = torch.flip(pts, dims=[-1])

        sampled_feature = grid_sample_3d(conditional_volume, pts)
        sampled_feature = sampled_feature.view(-1, num_pts).permute(1, 0).contiguous().to(device)

        sdf_pts = self.sdf_layer(pts_, sampled_feature)

        outputs = {}
        outputs['sdf_pts_scale%d' % lod] = sdf_pts[:, :1]
        outputs['sdf_features_pts_scale%d' % lod] = sdf_pts[:, 1:]
        outputs['sampled_latent_scale%d' % lod] = sampled_feature

        return outputs

    @torch.no_grad()
    def sdf_from_sdfvolume(self, pts, sdf_volume, lod=0):
        """
        Look up SDF values at ``pts`` by interpolating a precomputed SDF volume.

        :param pts: [num_pts, 3] query points
        :param sdf_volume: dense SDF volume to sample from
        :param lod: level index, used only to name the output key
        :return: dict with ``sdf_pts_scale{lod}``
        """
        num_pts = pts.shape[0]
        device = pts.device
        pts = pts.view(1, 1, 1, num_pts, 3)

        # grid_sample reads the grid's last axis as (x, y, z) with x indexing
        # the volume's last dimension, hence the reversed coordinate order.
        pts = torch.flip(pts, dims=[-1])

        sdf = torch.nn.functional.grid_sample(sdf_volume, pts, mode='bilinear', align_corners=True,
                                              padding_mode='border')
        sdf = sdf.view(-1, num_pts).permute(1, 0).contiguous().to(device)

        outputs = {}
        outputs['sdf_pts_scale%d' % lod] = sdf

        return outputs

    @torch.no_grad()
    def get_sdf_volume(self, conditional_volume, mask_volume, coords_volume, partial_origin):
        """
        Evaluate the SDF decoder on every valid voxel of a dense volume.

        :param conditional_volume: [1,C, dX,dY,dZ]
        :param mask_volume: [1,1, dX,dY,dZ]
        :param coords_volume: [1,3, dX,dY,dZ]
        :param partial_origin: world position of voxel (0,0,0); added to the scaled coordinates
        :return: [1,1,dX,dY,dZ] SDF volume; voxels outside the mask keep the value 1
        """
        device = conditional_volume.device
        chunk_size = 10240

        _, C, dX, dY, dZ = conditional_volume.shape
        conditional_volume = conditional_volume.view(C, dX * dY * dZ).permute(1, 0).contiguous()
        mask_volume = mask_volume.view(-1)
        coords_volume = coords_volume.view(3, dX * dY * dZ).permute(1, 0).contiguous()

        # Voxel indices -> world coordinates.
        pts = coords_volume * self.voxel_size + partial_origin

        sdf_volume = torch.ones([dX * dY * dZ, 1]).float().to(device)

        valid = mask_volume > 0
        # Decode in chunks to bound peak memory.
        feature_chunks = conditional_volume[valid].split(chunk_size)
        pts_chunks = pts[valid].split(chunk_size)

        sdf_all = [
            self.sdf_layer(pts_part, feature_part)[:, :1]
            for pts_part, feature_part in zip(pts_chunks, feature_chunks)
        ]

        sdf_all = torch.cat(sdf_all, dim=0)
        sdf_volume[valid] = sdf_all
        sdf_volume = sdf_volume.view(1, 1, dX, dY, dZ)
        return sdf_volume

    def gradient(self, x, conditional_volume, lod):
        """
        return the gradient of specific lod
        :param x: query points; ``requires_grad`` is switched on in place
        :param conditional_volume: dense feature volume passed on to ``sdf``
        :param lod: level index used to pick the SDF entry from ``sdf``'s output
        :return: gradient of the SDF with respect to ``x``, with an extra axis at dim 1
        """
        x.requires_grad_(True)

        # Build the graph even when the caller is inside a no_grad context.
        with torch.enable_grad():
            output = self.sdf(x, conditional_volume, lod)
            y = output['sdf_pts_scale%d' % lod]

        d_output = torch.ones_like(y, requires_grad=False, device=y.device)

        gradients = torch.autograd.grad(
            outputs=y,
            inputs=x,
            grad_outputs=d_output,
            create_graph=True,
            retain_graph=True,
            only_inputs=True)[0]
        return gradients.unsqueeze(1)
|
|
|
|
def sparse_to_dense_volume(coords, feature, vol_dims, interval, device=None):
    """
    convert the sparse volume into dense volume to enable trilinear sampling
    to save GPU memory;
    :param coords: [num_pts, 3]
    :param feature: [num_pts, C]
    :param vol_dims: [3] dX, dY, dZ
    :param interval: voxel stride; coordinates and dimensions are divided by it
    :param device: target device of the dense tensors; defaults to ``feature.device``
    :return: ``(dense_volume, valid_mask_volume)``, each with a leading batch
             axis and the channel axis moved to position 1
    """
    if device is None:
        device = feature.device

    coords_int = (coords / interval).to(torch.int64)
    vol_dims = (vol_dims / interval).to(torch.int64)

    dense_volume = sparse_to_dense_channel(
        coords_int.to(device), feature.to(device), vol_dims.to(device),
        feature.shape[1], 0, device)

    # Build the occupancy values on the target device: the dense buffer is
    # requested on ``device``, so values left on ``feature.device`` would not
    # match it when an explicit, different ``device`` is given.
    occupancy = torch.ones([feature.shape[0], 1], device=device)
    valid_mask_volume = sparse_to_dense_channel(
        coords_int.to(device),
        occupancy,
        vol_dims.to(device),
        1, 0, device)

    # Move the channel axis (last) to position 0, then add a batch axis in front.
    dense_volume = dense_volume.permute(3, 0, 1, 2).contiguous().unsqueeze(0)
    valid_mask_volume = valid_mask_volume.permute(3, 0, 1, 2).contiguous().unsqueeze(0)

    return dense_volume, valid_mask_volume
|
|
|
|
class SdfVolume(nn.Module):
    """Holds a volume tensor as a trainable parameter so it can be optimized directly."""

    def __init__(self, volume, coords=None, type='dense'):
        """
        :param volume: tensor registered as the learnable ``volume`` parameter
        :param coords: optional coordinates stored next to the volume
        :param type: free-form tag stored on the instance (callers in this file pass ``'sparse'``)
        """
        super().__init__()
        self.type = type
        self.coords = coords
        # nn.Parameter enables gradients by default.
        self.volume = nn.Parameter(volume)

    def forward(self):
        """Return the learnable volume parameter."""
        return self.volume
|
|
|
|
class FinetuneOctreeSdfNetwork(nn.Module):
    '''
    After obtaining the conditional volume from the generalized network,
    directly optimize the conditional volume.
    The conditional volume is still sparse.
    '''

    def __init__(self, voxel_size, vol_dims,
                 origin=(-1., -1., -1.),
                 hidden_dim=128, activation='softplus',
                 regnet_d_out=8,
                 multires=6,
                 if_fitted_rendering=True,
                 num_sdf_layers=4,
                 ):
        """
        :param voxel_size: edge length of one voxel
        :param vol_dims: number of voxels per axis
        :param origin: world position of voxel (0,0,0)
        :param hidden_dim: hidden width of the SDF decoder and the renderer
        :param activation: ``'softplus'`` or ``'relu'``
        :param regnet_d_out: channel count of the conditional volume
        :param multires: number of embedding frequencies of the SDF decoder
        :param if_fitted_rendering: stored on the instance; not read inside this class
        :param num_sdf_layers: ``n_layers`` of the SDF decoder
        """
        super(FinetuneOctreeSdfNetwork, self).__init__()

        self.voxel_size = voxel_size
        self.vol_dims = torch.tensor(vol_dims)

        self.origin = torch.tensor(origin).to(torch.float32)

        self.hidden_dim = hidden_dim

        self.regnet_d_out = regnet_d_out

        self.if_fitted_rendering = if_fitted_rendering
        self.multires = multires

        # Filled in by initialize_conditional_volumes().
        self.sparse_volume_lod0 = None
        self.sparse_coords_lod0 = None
        # Filled in by prune_dense_mask(); None means "not pruned yet".
        self.dense_volume_mask_lod0 = None

        if activation == 'softplus':
            self.activation = nn.Softplus(beta=100)
        else:
            assert activation == 'relu'
            self.activation = nn.ReLU()

        self.sdf_layer = LatentSDFLayer(d_in=3,
                                        d_out=self.hidden_dim + 1,
                                        d_hidden=self.hidden_dim,
                                        n_layers=num_sdf_layers,
                                        multires=multires,
                                        geometric_init=True,
                                        weight_norm=True,
                                        activation=activation,
                                        d_conditional_feature=16
                                        )

        # NOTE(review): presumably point (3) + sampled latent (regnet_d_out)
        # + view direction (3) + normal (3), matching the concatenation in
        # color_blend() -- confirm against the caller.
        d_in_renderer = 3 + self.regnet_d_out + 3 + 3
        self.renderer = BlendingRenderingNetwork(
            d_feature=self.hidden_dim - 1,
            mode='idr',
            d_in=d_in_renderer,
            d_out=50,
            d_hidden=self.hidden_dim,
            n_layers=3,
            weight_norm=True,
            multires_view=4,
            squeeze_out=True,
        )

    def initialize_conditional_volumes(self, dense_volume_lod0, dense_volume_mask_lod0,
                                       sparse_volume_lod0=None, sparse_coords_lod0=None):
        """
        Set up the learnable sparse volume and its voxel coordinates.

        :param dense_volume_lod0: [1,C,dX,dY,dZ]
        :param dense_volume_mask_lod0: [1,1,dX,dY,dZ]
        :param sparse_volume_lod0: already sparse features; when given, the dense
                                   inputs are ignored
        :param sparse_coords_lod0: coordinates belonging to ``sparse_volume_lod0``
        :return: None
        """
        if sparse_volume_lod0 is None:
            device = dense_volume_lod0.device
            _, C, dX, dY, dZ = dense_volume_lod0.shape

            dense_volume_lod0 = dense_volume_lod0.view(C, dX * dY * dZ).permute(1, 0).contiguous()
            mask_lod0 = dense_volume_mask_lod0.view(dX * dY * dZ) > 0

            # Keep only the voxels marked valid by the mask.
            self.sparse_volume_lod0 = SdfVolume(dense_volume_lod0[mask_lod0], type='sparse')

            coords = generate_grid(self.vol_dims, 1)[0]
            coords = coords.view(3, dX * dY * dZ).permute(1, 0).to(device)
            self.sparse_coords_lod0 = torch.nn.Parameter(coords[mask_lod0], requires_grad=False)
        else:
            self.sparse_volume_lod0 = SdfVolume(sparse_volume_lod0, type='sparse')
            self.sparse_coords_lod0 = torch.nn.Parameter(sparse_coords_lod0, requires_grad=False)

    def get_conditional_volume(self):
        """
        Densify the learnable sparse volume.

        :return: dict with ``dense_volume_scale0`` and ``valid_mask_volume_scale0``
        """
        dense_volume, valid_mask_volume = sparse_to_dense_volume(
            self.sparse_coords_lod0,
            self.sparse_volume_lod0(), self.vol_dims, interval=1,
            device=None)

        outputs = {}
        outputs['dense_volume_scale%d' % 0] = dense_volume
        outputs['valid_mask_volume_scale%d' % 0] = valid_mask_volume

        return outputs

    def tv_regularizer(self):
        """
        Total-variation style smoothness term of the conditional volume.

        :return: scalar tensor, the mean over the volume of the masked
                 finite-difference magnitude
        """
        dense_volume, valid_mask_volume = sparse_to_dense_volume(
            self.sparse_coords_lod0,
            self.sparse_volume_lod0(), self.vol_dims, interval=1,
            device=None)

        # Squared forward differences along each spatial axis.
        dx = (dense_volume[:, :, 1:, :, :] - dense_volume[:, :, :-1, :, :]) ** 2
        dy = (dense_volume[:, :, :, 1:, :] - dense_volume[:, :, :, :-1, :]) ** 2
        dz = (dense_volume[:, :, :, :, 1:] - dense_volume[:, :, :, :, :-1]) ** 2

        # Crop to the common region so the three terms align.
        tv = dx[:, :, :, :-1, :-1] + dy[:, :, :-1, :, :-1] + dz[:, :, :-1, :-1, :]

        # A voxel contributes only if it and its three forward neighbours are valid.
        mask = valid_mask_volume[:, :, :-1, :-1, :-1] * valid_mask_volume[:, :, 1:, :-1, :-1] * \
               valid_mask_volume[:, :, :-1, 1:, :-1] * valid_mask_volume[:, :, :-1, :-1, 1:]

        # The epsilon keeps the square root differentiable at zero.
        tv = torch.sqrt(tv + 1e-6).mean(dim=1, keepdim=True) * mask

        assert torch.all(~torch.isnan(tv))

        return torch.mean(tv)

    def sdf(self, pts, conditional_volume, lod):
        """
        Decode SDF values and features at ``pts``.

        :param pts: [num_pts, 3] query points
        :param conditional_volume: dense feature volume the latent codes are sampled from
        :param lod: used only for the ``sampled_latent_scale{lod}`` key; the SDF
                    entries are always stored under scale 0
        :return: dict with ``sampled_latent_scale{lod}``, ``sdf_pts_scale0`` and
                 ``sdf_features_pts_scale0``
        """
        outputs = {}

        num_pts = pts.shape[0]
        pts_ = pts.clone()
        pts = pts.view(1, 1, 1, num_pts, 3)

        # Reverse the coordinate order of every point before sampling.
        pts = torch.flip(pts, dims=[-1])

        sampled_feature = grid_sample_3d(conditional_volume, pts)
        sampled_feature = sampled_feature.view(-1, num_pts).permute(1, 0).contiguous()
        outputs['sampled_latent_scale%d' % lod] = sampled_feature

        sdf_pts = self.sdf_layer(pts_, sampled_feature)

        # This network only has level 0, so the SDF keys ignore the passed ``lod``.
        lod = 0
        outputs['sdf_pts_scale%d' % lod] = sdf_pts[:, :1]
        outputs['sdf_features_pts_scale%d' % lod] = sdf_pts[:, 1:]

        return outputs

    def color_blend(self, pts, position, normals, view_dirs, feature_vectors, img_index,
                    pts_pixel_color, pts_pixel_mask, pts_patch_color=None, pts_patch_mask=None):
        """
        Blend per-view colors with the rendering network.

        ``pts`` and ``position`` are concatenated along the last axis and passed
        as the renderer's position input; all other arguments are forwarded
        unchanged.

        :return: whatever ``self.renderer`` returns
        """
        return self.renderer(torch.cat([pts, position], dim=-1), normals, view_dirs, feature_vectors,
                             img_index, pts_pixel_color, pts_pixel_mask,
                             pts_patch_color=pts_patch_color, pts_patch_mask=pts_patch_mask)

    def gradient(self, x, conditional_volume, lod):
        """
        return the gradient of specific lod
        :param x: query points; ``requires_grad`` is switched on in place
        :param conditional_volume: dense feature volume passed on to ``sdf``
        :param lod: forwarded to ``sdf``; the SDF is always read from scale 0
        :return: gradient of the SDF with respect to ``x``, with an extra axis at dim 1
        """
        x.requires_grad_(True)

        # Build the graph even when the caller is inside a no_grad context,
        # as SparseSdfNetwork.gradient does.
        with torch.enable_grad():
            output = self.sdf(x, conditional_volume, lod)
            y = output['sdf_pts_scale%d' % 0]

        d_output = torch.ones_like(y, requires_grad=False, device=y.device)

        gradients = torch.autograd.grad(
            outputs=y,
            inputs=x,
            grad_outputs=d_output,
            create_graph=True,
            retain_graph=True,
            only_inputs=True)[0]
        return gradients.unsqueeze(1)

    @torch.no_grad()
    def prune_dense_mask(self, threshold=0.02):
        """
        Just gradually prune the mask of dense volume to decrease the number of sdf network inference

        Voxels whose absolute SDF is below ``threshold`` are kept, dilated with
        a 7x7x7 window, and intersected with the current mask. The result is
        stored in ``self.dense_volume_mask_lod0``.

        :param threshold: absolute SDF value below which a voxel counts as near the surface
        :return: None
        """
        chunk_size = 10240
        coords = generate_grid(self.vol_dims, 1)[0]

        _, dX, dY, dZ = coords.shape

        dense_volume, valid_mask_volume = sparse_to_dense_volume(
            self.sparse_coords_lod0,
            self.sparse_volume_lod0(), self.vol_dims, interval=1,
            device=None)
        device = dense_volume.device

        # Voxel indices -> world coordinates, moved to the volume's device
        # before any masked indexing.
        pts = coords.view(3, -1).permute(1, 0).contiguous() * self.voxel_size + self.origin[None, :]
        pts = pts.to(device)

        # Before the first pruning the mask is the occupancy of the sparse volume.
        current_mask = getattr(self, 'dense_volume_mask_lod0', None)
        if current_mask is None:
            current_mask = valid_mask_volume
        current_mask = current_mask.to(device)

        # Voxels outside the mask get a large SDF so they never pass the threshold.
        sdf_volume = torch.ones([dX * dY * dZ, 1]).float().to(device) * 100

        mask = current_mask.view(-1) > 0

        pts_valid = pts[mask]
        feature_valid = dense_volume.view(self.regnet_d_out, -1).permute(1, 0).contiguous()[mask]

        # Decode in chunks to bound peak memory.
        sdf_list = [
            self.sdf_layer(pts_part, feature_part)[:, :1]
            for pts_part, feature_part in zip(pts_valid.split(chunk_size), feature_valid.split(chunk_size))
        ]
        sdf_list = torch.cat(sdf_list, dim=0)

        sdf_volume[mask] = sdf_list

        occupancy_mask = torch.abs(sdf_volume) < threshold

        # Dilation: a voxel stays if any voxel in its 7x7x7 neighbourhood is occupied.
        occupancy_mask = occupancy_mask.float()
        occupancy_mask = occupancy_mask.view(1, 1, dX, dY, dZ)
        occupancy_mask = F.avg_pool3d(occupancy_mask, kernel_size=7, stride=1, padding=3)
        occupancy_mask = occupancy_mask > 0

        self.dense_volume_mask_lod0 = torch.logical_and(current_mask.view(1, 1, dX, dY, dZ),
                                                        occupancy_mask).float()
|
|
|
|
class BlendingRenderingNetwork(nn.Module):
    """
    MLP that predicts per-view blending logits and blends the colors observed
    in the source views into one color per point.
    """

    def __init__(
            self,
            d_feature,
            mode,
            d_in,
            d_out,
            d_hidden,
            n_layers,
            weight_norm=True,
            multires_view=0,
            squeeze_out=True,
    ):
        """
        :param d_feature: width of the feature vector appended to the input
        :param mode: which inputs are concatenated in ``forward`` (``'idr'``,
                     ``'no_view_dir'``, ``'no_normal'``, ``'no_points'`` or
                     ``'no_points_no_view_dir'``)
        :param d_in: width of the remaining inputs before view-direction embedding
        :param d_out: number of output logits
        :param d_hidden: width of the hidden layers
        :param n_layers: number of hidden layers
        :param weight_norm: wrap every linear layer in ``nn.utils.weight_norm``
        :param multires_view: embedding frequencies for view directions; 0 disables it
        :param squeeze_out: stored on the instance; not read inside this class
        """
        super().__init__()

        self.mode = mode
        self.squeeze_out = squeeze_out

        widths = [d_in + d_feature, *([d_hidden] * n_layers), d_out]

        self.embedder = None
        if multires_view > 0:
            self.embedder = Embedding(3, multires_view)
            # The embedded view direction replaces the raw 3-vector in the input.
            widths[0] += (self.embedder.out_channels - 3)

        self.num_layers = len(widths)

        for idx, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
            layer = nn.Linear(fan_in, fan_out)
            if weight_norm:
                layer = nn.utils.weight_norm(layer)
            setattr(self, "lin" + str(idx), layer)

        self.relu = nn.ReLU()

        self.color_volume = None

        self.softmax = nn.Softmax(dim=1)

        self.type = 'blending'

    def _masked_blend_weights(self, logits, mask):
        """Softmax over views, zeroed where ``mask`` is zero and renormalised to sum to ~1."""
        weights = self.softmax(logits) * mask
        return weights / (torch.sum(weights.float(), dim=1, keepdim=True) + 1e-8)

    def sample_pts_from_colorVolume(self, pts):
        """
        Sample ``self.color_volume`` at the given points.

        :param pts: [num_pts, 3] query points
        :return: sampled values, one row per point, on the device of ``pts``
        """
        target_device = pts.device
        n_points = pts.shape[0]

        # Reverse the coordinate order of every point before sampling.
        grid = torch.flip(pts.view(1, 1, 1, n_points, 3), dims=[-1])

        sampled = grid_sample_3d(self.color_volume, grid)
        return sampled.view(-1, n_points).permute(1, 0).contiguous().to(target_device)

    def forward(self, position, normals, view_dirs, feature_vectors, img_index,
                pts_pixel_color, pts_pixel_mask, pts_patch_color=None, pts_patch_mask=None):
        """
        Predict blending weights and blend the per-view colors.

        :param position: can be 3d coord or interpolated volume latent
        :param normals: surface normals of the points
        :param view_dirs: viewing directions (embedded first when an embedder exists)
        :param feature_vectors: per-point features
        :param img_index: [N_views], used to extract corresponding weights
        :param pts_pixel_color: [N_pts, N_views, 3]
        :param pts_pixel_mask: [N_pts, N_views]
        :param pts_patch_color: [N_pts, N_views, Npx, 3]
        :param pts_patch_mask: per-pixel validity of the patches; summed over its last axis
        :return: ``(final_pixel_color, final_pixel_mask, final_patch_color, final_patch_mask)``;
                 the patch outputs are ``None`` when ``pts_patch_color`` is ``None``
        """
        if self.embedder is not None:
            view_dirs = self.embedder(view_dirs)

        # Inputs concatenated for each mode; an unknown mode yields None.
        input_layouts = {
            'idr': (position, view_dirs, normals, feature_vectors),
            'no_view_dir': (position, normals, feature_vectors),
            'no_normal': (position, view_dirs, feature_vectors),
            'no_points': (view_dirs, normals, feature_vectors),
            'no_points_no_view_dir': (normals, feature_vectors),
        }
        parts = input_layouts.get(self.mode)
        x = torch.cat(parts, dim=-1) if parts is not None else None

        final_layer = self.num_layers - 2
        for idx in range(self.num_layers - 1):
            x = getattr(self, "lin" + str(idx))(x)
            # ReLU after every layer except the last one.
            if idx != final_layer:
                x = self.relu(x)

        # Keep only the logits that belong to the given source views.
        x_extracted = torch.index_select(x, 1, img_index.long())

        weights_pixel = self._masked_blend_weights(x_extracted, pts_pixel_mask)
        final_pixel_color = torch.sum(pts_pixel_color * weights_pixel[:, :, None], dim=1,
                                      keepdim=False)

        # A point is valid when at least one view sees it.
        final_pixel_mask = torch.sum(pts_pixel_mask.float(), dim=1, keepdim=True) > 0

        final_patch_color, final_patch_mask = None, None

        if pts_patch_color is not None:
            N_pts, N_views, Npx, _ = pts_patch_color.shape
            # A patch counts for a view only if all of its pixels are valid.
            patch_mask = torch.sum(pts_patch_mask, dim=-1, keepdim=False) > Npx - 1

            weights_patch = self._masked_blend_weights(x_extracted, patch_mask)

            final_patch_color = torch.sum(pts_patch_color * weights_patch[:, :, None, None], dim=1,
                                          keepdim=False)
            final_patch_mask = torch.sum(patch_mask, dim=1, keepdim=True) > 0

        return final_pixel_color, final_pixel_mask, final_patch_color, final_patch_mask
|
|