| |
| |
| |
| |
| |
|
|
|
|
| import unittest |
|
|
| import pytorch3d as pt3d |
| import torch |
| from pytorch3d.implicitron.models.view_pooler.view_sampler import ViewSampler |
| from pytorch3d.implicitron.tools.config import expand_args_fields |
|
|
|
|
class TestViewsampling(unittest.TestCase):
    """
    Tests for pytorch3d.implicitron's ViewSampler: checks its output against a
    naive reference implementation and against a hand-constructed problem with
    a predictable outcome.
    """

    def setUp(self):
        # Fix the RNG seed so the random points/masks generated below are
        # reproducible; the random-call order in
        # _init_view_sampler_problem must not change for this to hold.
        torch.manual_seed(42)
        # Resolve ViewSampler's configurable fields so it can be
        # instantiated directly with keyword arguments.
        expand_args_fields(ViewSampler)

    def _init_view_sampler_problem(self, random_masks):
        """
        Generates a view-sampling problem:
        - 4 source views, 1st/2nd from the first sequence 'seq1', the rest from 'seq2'
        - 3 sets of 3D points from sequences 'seq1', 'seq2', 'seq2' respectively.
        - first 50 points in each batch correctly project to the source views,
        while the remaining 50 do not land in any projection plane.
        - each source view is labeled with image feature tensors of shape 7x100x50,
        where all elements of the n-th tensor are set to `n+1`.
        - the elements of the source view masks are either set to random binary number
        (if `random_masks==True`), or all set to 1 (`random_masks==False`).
        - the source view cameras are uniformly distributed on a unit circle
        in the x-z plane and look at (0,0,0).
        """
        seq_id_camera = ["seq1", "seq1", "seq2", "seq2"]
        seq_id_pts = ["seq1", "seq2", "seq2"]
        pts_batch = 3
        n_pts = 100
        n_views = 4
        fdim = 7
        H = 100
        W = 50

        # Points near the origin (radius 0.1) are visible from all cameras
        # placed on the unit circle below.
        # NOTE(review): tensors are created on "cuda" — this test requires a
        # GPU to run.
        pts_inside = (
            torch.nn.functional.normalize(
                torch.randn(pts_batch, n_pts // 2, 3, device="cuda"),
                dim=-1,
            )
            * 0.1
        )

        # Shift a copy of the points far along +y so they project outside
        # every source view's image plane.
        pts_outside = pts_inside.clone()
        pts_outside[:, :, 1] += 1e8
        pts = torch.cat([pts_inside, pts_outside], dim=1)

        # Cameras evenly spaced on a unit circle in the x-z plane, all
        # looking at the origin; the endpoint 360 deg is dropped since it
        # coincides with 0 deg.
        R, T = pt3d.renderer.look_at_view_transform(
            dist=1.0,
            elev=0.0,
            azim=torch.linspace(0, 360, n_views + 1)[:n_views],
            degrees=True,
            device=pts.device,
        )
        focal_length = R.new_ones(n_views, 2)
        principal_point = R.new_zeros(n_views, 2)
        camera = pt3d.renderer.PerspectiveCameras(
            R=R,
            T=T,
            focal_length=focal_length,
            principal_point=principal_point,
            device=pts.device,
        )

        # Feature map of view n is constant-valued n+1, which makes sampled
        # features trivially predictable in test_viewsampling.
        feats_map = torch.arange(n_views, device=pts.device, dtype=pts.dtype) + 1
        feats = {"feats": feats_map[:, None, None, None].repeat(1, fdim, H, W)}

        # Random binary masks, cast to the cameras' float dtype.
        masks = (
            torch.rand(n_views, 1, H, W, device=pts.device, dtype=pts.dtype) > 0.5
        ).type_as(R)

        # With random_masks=False, overwrite with all-ones masks.
        if not random_masks:
            masks[:] = 1.0

        return pts, camera, feats, masks, seq_id_camera, seq_id_pts

    def test_compare_with_naive(self):
        """
        Compares the outputs of the efficient ViewSampler module with a
        naive implementation.
        """

        (
            pts,
            camera,
            feats,
            masks,
            seq_id_camera,
            seq_id_pts,
        ) = self._init_view_sampler_problem(True)

        for masked_sampling in (True, False):
            # Reference result from the straightforward per-(batch, view)
            # implementation below.
            feats_sampled_n, masks_sampled_n = _view_sample_naive(
                pts,
                seq_id_pts,
                camera,
                seq_id_camera,
                feats,
                masks,
                masked_sampling,
            )
            # NOTE(review): expand_args_fields was already called in setUp;
            # this repeated call is redundant but harmless.
            expand_args_fields(ViewSampler)
            view_sampler = ViewSampler(masked_sampling=masked_sampling)
            feats_sampled, masks_sampled = view_sampler(
                pts=pts,
                seq_id_pts=seq_id_pts,
                camera=camera,
                seq_id_camera=seq_id_camera,
                feats=feats,
                masks=masks,
            )
            # The efficient module must match the naive reference for every
            # feature key and for the sampled masks.
            for k in feats_sampled.keys():
                self.assertTrue(torch.allclose(feats_sampled[k], feats_sampled_n[k]))
            self.assertTrue(torch.allclose(masks_sampled, masks_sampled_n))

    def test_viewsampling(self):
        """
        Generates a viewsampling problem with predictable outcome, and compares
        the ViewSampler's output to the expected result.
        """

        # random_masks=False -> all-ones masks, so expected outputs are
        # fully determined by sequence matching and the constant feature maps.
        (
            pts,
            camera,
            feats,
            masks,
            seq_id_camera,
            seq_id_pts,
        ) = self._init_view_sampler_problem(False)

        expand_args_fields(ViewSampler)

        for masked_sampling in (True, False):

            view_sampler = ViewSampler(masked_sampling=masked_sampling)

            feats_sampled, masks_sampled = view_sampler(
                pts=pts,
                seq_id_pts=seq_id_pts,
                camera=camera,
                seq_id_camera=seq_id_camera,
                feats=feats,
                masks=masks,
            )

            n_views = camera.R.shape[0]
            n_pts = pts.shape[1]
            feat_dim = feats["feats"].shape[1]
            pts_batch = pts.shape[0]
            # The second half of every point batch was shifted far away and
            # thus projects outside all views.
            n_pts_away = n_pts // 2

            for pts_i in range(pts_batch):
                for view_i in range(n_views):
                    if seq_id_pts[pts_i] != seq_id_camera[view_i]:
                        # Cross-sequence pairs are never sampled: all zeros.
                        gt_masks = pts.new_zeros(n_pts, 1)
                        gt_feats = pts.new_zeros(n_pts, feat_dim)
                    else:
                        # Same sequence: visible points get the view's
                        # constant feature value (view_i + 1); points that
                        # miss the image plane get zero features, and zero
                        # masks only when masked_sampling is enabled.
                        gt_masks = pts.new_ones(n_pts, 1)
                        gt_feats = pts.new_ones(n_pts, feat_dim) * (view_i + 1)
                        gt_feats[n_pts_away:] = 0.0
                        if masked_sampling:
                            gt_masks[n_pts_away:] = 0.0

                    for k in feats_sampled:
                        self.assertTrue(
                            torch.allclose(
                                feats_sampled[k][pts_i, view_i],
                                gt_feats,
                            )
                        )
                    self.assertTrue(
                        torch.allclose(
                            masks_sampled[pts_i, view_i],
                            gt_masks,
                        )
                    )
|
|
|
|
def _view_sample_naive(
    pts,
    seq_id_pts,
    camera,
    seq_id_camera,
    feats,
    masks,
    masked_sampling,
):
    """
    A naive implementation of the forward pass of ViewSampler.
    Refer to ViewSampler's docstring for description of the arguments.
    """

    n_batches = pts.shape[0]
    n_cams = camera.R.shape[0]
    n_points = pts.shape[1]

    def _sample_pair(batch_i, cam_i):
        # Points and cameras from different sequences never see each other:
        # return all-zero features and masks.
        if seq_id_pts[batch_i] != seq_id_camera[cam_i]:
            zero_feats = {
                name: f.new_zeros(n_points, f.shape[1]) for name, f in feats.items()
            }
            return zero_feats, masks.new_zeros(n_points, 1)
        # Matching sequences: sample this single view at the point projections.
        return _sample_one_view_naive(
            camera[cam_i],
            pts[batch_i],
            {name: f[cam_i] for name, f in feats.items()},
            masks[cam_i],
            masked_sampling,
            sampling_mode="bilinear",
        )

    # One (feats, mask) pair per (point batch, view) combination.
    sampled = [
        [_sample_pair(batch_i, cam_i) for cam_i in range(n_cams)]
        for batch_i in range(n_batches)
    ]

    # Stack per-pair results into batched tensors of shape
    # (n_batches, n_cams, n_points, ...).
    masks_sampled_cat = torch.stack(
        [torch.stack([mask for _, mask in row]) for row in sampled]
    )
    feats_sampled_cat = {
        name: torch.stack(
            [torch.stack([f[name] for f, _ in row]) for row in sampled]
        )
        for name in sampled[0][0][0].keys()
    }
    return feats_sampled_cat, masks_sampled_cat
|
|
|
|
def _sample_one_view_naive(
    camera,
    pts,
    feats,
    masks,
    masked_sampling,
    sampling_mode="bilinear",
):
    """
    Sample a single source view.
    """
    # Project the points to NDC and keep only the xy coordinates; the two
    # leading singleton dims give a (1, 1, n_pts, 2) sampling grid.
    ndc_xy = camera.transform_points(pts[None])[None, ..., :-1]

    # Sample each feature map at the projected locations and reshape the
    # result to (n_pts, feat_dim).
    feats_sampled = {}
    for name, feat_map in feats.items():
        sampled = pt3d.renderer.ndc_grid_sample(
            feat_map[None], ndc_xy, mode=sampling_mode
        )
        feats_sampled[name] = sampled.permute(0, 3, 1, 2)[0, :, :, 0]

    if masked_sampling:
        # Sample the view's mask the same way, yielding one value per point.
        masks_sampled = pt3d.renderer.ndc_grid_sample(
            masks[None],
            ndc_xy,
            mode=sampling_mode,
            align_corners=False,
        )[0, 0, 0, :][:, None]
    else:
        # Unmasked sampling: every point is considered valid.
        masks_sampled = ndc_xy.new_ones(pts.shape[0], 1)
    return feats_sampled, masks_sampled
|
|