| |
| |
| |
| |
| |
|
|
| import unittest |
|
|
| import torch |
| from nerf.raysampler import NeRFRaysampler, ProbabilisticRaysampler |
| from pytorch3d.renderer import PerspectiveCameras |
| from pytorch3d.transforms.rotation_conversions import random_rotations |
|
|
|
|
class TestRaysampler(unittest.TestCase):
    """Tests for `NeRFRaysampler` caching and `ProbabilisticRaysampler`."""

    def setUp(self) -> None:
        # Seed the global torch RNG so the random cameras and ray weights
        # generated in each test are reproducible.
        torch.manual_seed(42)

    def test_raysampler_caching(self, batch_size=10):
        """
        Tests the consistency of the NeRF raysampler caching.

        Rays are first sampled directly for `batch_size` random cameras.
        The same cameras are then passed to `precache_rays` with hashes
        0..batch_size-1. Finally, rays are re-requested with `camera_hash`
        set and `caching=False`. Every field of each re-requested ray bundle
        must be close to the directly sampled one.

        Args:
            batch_size: Number of random cameras to sample and compare.
        """
        raysampler = NeRFRaysampler(
            min_x=0.0,
            max_x=10.0,
            min_y=0.0,
            max_y=10.0,
            n_pts_per_ray=10,
            min_depth=0.1,
            max_depth=10.0,
            n_rays_per_image=12,
            image_width=10,
            image_height=10,
            stratified=False,
            stratified_test=False,
            invert_directions=True,
        )

        # Eval mode with stratification disabled; presumably this makes both
        # sampling passes deterministic so they can be compared directly.
        # TODO(review): confirm against NeRFRaysampler.forward.
        raysampler.eval()

        cameras, rays = [], []

        for _ in range(batch_size):
            R = random_rotations(1)
            T = torch.randn(1, 3)
            # Offset by 0.5 to keep focal lengths away from zero.
            focal_length = torch.rand(1, 2) + 0.5
            principal_point = torch.randn(1, 2)

            camera = PerspectiveCameras(
                focal_length=focal_length,
                principal_point=principal_point,
                R=R,
                T=T,
            )

            cameras.append(camera)
            rays.append(raysampler(camera))

        # Populate the cache, keyed by each camera's index in `cameras`.
        raysampler.precache_rays(cameras, list(range(batch_size)))

        for cam_index, rays_ in enumerate(rays):
            # Re-request with the camera's hash; presumably served from the
            # cache filled by `precache_rays` above — confirm in forward().
            rays_cached_ = raysampler(
                cameras=cameras[cam_index],
                chunksize=None,
                chunk_idx=0,
                camera_hash=cam_index,
                caching=False,
            )

            fields = tuple(rays_)
            fields_cached = tuple(rays_cached_)
            # `zip` would silently drop trailing fields on a length mismatch.
            self.assertEqual(
                len(fields),
                len(fields_cached),
                msg=f"ray bundle field count differs for camera {cam_index}",
            )

            for field_index, (v, v_cached) in enumerate(zip(fields, fields_cached)):
                # subTest reports exactly which camera / field mismatched.
                with self.subTest(cam_index=cam_index, field_index=field_index):
                    self.assertTrue(
                        torch.allclose(v, v_cached),
                        msg=(
                            f"cached ray field {field_index} differs "
                            f"for camera {cam_index}"
                        ),
                    )

    def test_probabilistic_raysampler(self, batch_size=1, n_pts_per_ray=60):
        """
        Check that the probabilistic ray sampler does not crash for various
        settings.

        A grid raysampler produces an input ray bundle for `batch_size`
        random cameras. Random weights with the shape of its `lengths` are
        drawn once. A `ProbabilisticRaysampler` is then run repeatedly in
        train and eval mode for every combination of `stratified` and
        `stratified_test`.

        Args:
            batch_size: Number of random cameras in the input ray bundle.
            n_pts_per_ray: Points per ray for both the grid and the
                probabilistic raysampler.
        """
        raysampler_grid = NeRFRaysampler(
            min_x=0.0,
            max_x=10.0,
            min_y=0.0,
            max_y=10.0,
            n_pts_per_ray=n_pts_per_ray,
            min_depth=1.0,
            max_depth=10.0,
            n_rays_per_image=12,
            image_width=10,
            image_height=10,
            stratified=False,
            stratified_test=False,
            invert_directions=True,
        )

        R = random_rotations(batch_size)
        T = torch.randn(batch_size, 3)
        focal_length = torch.rand(batch_size, 2) + 0.5
        principal_point = torch.randn(batch_size, 2)
        camera = PerspectiveCameras(
            focal_length=focal_length,
            principal_point=principal_point,
            R=R,
            T=T,
        )

        raysampler_grid.eval()

        ray_bundle = raysampler_grid(cameras=camera)

        # Uniform random weights, one per sampled length along each ray.
        ray_weights = torch.rand_like(ray_bundle.lengths)

        # Only check that sampling does not crash for every combination of
        # settings; the sampled outputs are not inspected.
        for stratified_test in (True, False):
            for stratified in (True, False):
                raysampler_prob = ProbabilisticRaysampler(
                    n_pts_per_ray=n_pts_per_ray,
                    stratified=stratified,
                    stratified_test=stratified_test,
                    add_input_samples=True,
                )
                for mode in ("train", "eval"):
                    # subTest reports which configuration raised, if any.
                    with self.subTest(
                        stratified=stratified,
                        stratified_test=stratified_test,
                        mode=mode,
                    ):
                        # Switch the sampler into train / eval mode.
                        getattr(raysampler_prob, mode)()
                        for _ in range(10):
                            raysampler_prob(ray_bundle, ray_weights)
|
|