| |
| |
| |
| |
| |
|
|
| import unittest |
|
|
| import numpy as np |
| import torch |
| from pytorch3d.ops import utils as oputil |
|
|
| from .common_testing import TestCaseMixin |
|
|
|
|
class TestOpsUtils(TestCaseMixin, unittest.TestCase):
    """Tests for helpers in ``pytorch3d.ops.utils`` (``wmean``, ``masked_gather``)."""

    def setUp(self) -> None:
        super().setUp()
        # Seed both RNGs so the randomly generated inputs are reproducible.
        torch.manual_seed(42)
        np.random.seed(42)

    def test_wmean(self):
        """Compare ``oputil.wmean`` against ``np.average`` in several configurations."""
        # Prefer the GPU when available, but fall back to the CPU so the test
        # does not error out on machines without CUDA.
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        n_points = 20

        x = torch.rand(n_points, 3, device=device)
        weight = torch.rand(n_points, device=device)
        x_np = x.cpu().numpy()
        weight_np = weight.cpu().numpy()

        # Unweighted mean over axis -2.
        mean = oputil.wmean(x, keepdim=False)
        mean_gt = np.average(x_np, axis=-2)
        self.assertClose(mean.cpu().numpy(), mean_gt)

        # Weighted mean over axis -2.
        mean = oputil.wmean(x, weight=weight, keepdim=False)
        mean_gt = np.average(x_np, axis=-2, weights=weight_np)
        self.assertClose(mean.cpu().numpy(), mean_gt)

        # keepdim=True: indexing the kept leading dim recovers the same result.
        mean = oputil.wmean(x, weight=weight, keepdim=True)
        self.assertClose(mean[0].cpu().numpy(), mean_gt)

        # Boolean (binary) weights.
        mean = oputil.wmean(x, weight=weight > 0.5, keepdim=False)
        mean_gt = np.average(x_np, axis=-2, weights=weight_np > 0.5)
        self.assertClose(mean.cpu().numpy(), mean_gt)

        # Broadcasting: a (n_points,) weight applied to a (10, n_points, 3) input.
        x = torch.rand(10, n_points, 3, device=device)
        x_np = x.cpu().numpy()
        mean = oputil.wmean(x, weight=weight, keepdim=False)
        mean_gt = np.average(x_np, axis=-2, weights=weight_np)
        self.assertClose(mean.cpu().numpy(), mean_gt)

        # Broadcasting with extra leading weight dims: weight is (3, 1, n_points)
        # and its first slice must reproduce the previous result.
        weight = weight[None, None, :].repeat(3, 1, 1)
        mean = oputil.wmean(x, weight=weight, keepdim=False)
        self.assertClose(mean[0].cpu().numpy(), mean_gt)

        # A weight whose shape cannot broadcast against x must be rejected.
        weight = torch.rand(x.shape[0], device=device)
        with self.assertRaisesRegex(ValueError, "weights are not compatible"):
            oputil.wmean(x, weight=weight, keepdim=False)

        # Explicit reduction dim. np.average requires weights of the same shape
        # as x for a multi-dim input, so tile the weight along the last axis.
        weight = torch.rand(x.shape[0], n_points, device=device)
        weight_np = np.tile(
            weight[:, :, None].cpu().numpy(), (1, 1, x_np.shape[-1])
        )
        mean = oputil.wmean(x, dim=0, weight=weight, keepdim=False)
        mean_gt = np.average(x_np, axis=0, weights=weight_np)
        self.assertClose(mean.cpu().numpy(), mean_gt)

        # Reduction over a tuple of dims.
        mean = oputil.wmean(x, dim=(0, 1), weight=weight, keepdim=False)
        mean_gt = np.average(x_np, axis=(0, 1), weights=weight_np)
        self.assertClose(mean.cpu().numpy(), mean_gt)

    def test_masked_gather_errors(self):
        """``masked_gather`` rejects a 4-D index and mismatched batch sizes."""
        # A 4-D index tensor is not an accepted format.
        idx = torch.randint(0, 10, size=(5, 10, 4, 2))
        points = torch.randn(size=(5, 10, 3))
        with self.assertRaisesRegex(ValueError, "format is not supported"):
            oputil.masked_gather(points, idx)

        # points has batch size 2 while idx has batch size 5. Drop idx's
        # trailing dim (4-D -> 3-D) so this case exercises the batch-size check
        # on its own, instead of passing only because that check happens to
        # run before the format check.
        idx = idx[..., 0]
        points = torch.randn(size=(2, 10, 3))
        with self.assertRaisesRegex(ValueError, "same batch dimension"):
            oputil.masked_gather(points, idx)
|
|