| | |
| | |
| |
|
| | import numpy as np |
| | import pytest |
| | import torch |
| |
|
| | try: |
| | from megablocks._ops import ops as backend |
| | except ModuleNotFoundError as e: |
| | raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e |
| |
|
| | from megablocks import ops |
| |
|
| |
|
def promote_scalar(x: torch.Tensor) -> torch.Tensor:
    """Ensure ``x`` is at least rank-1: a 0-d tensor becomes shape ``(1,)``.

    Rank>=1 tensors are returned unchanged (same object, no copy).
    """
    if x.dim() == 0:
        return x.view(1)
    return x
| |
|
| |
|
# (tokens, num_centers, top_k) cases for the replicate forward/backward tests.
# The 8-token problems exercise tiny/degenerate shapes (including a single
# center); the 16384-token problems exercise realistic bin counts.
_SMALL_CENTERS = (2, 4, 8)
_LARGE_CENTERS = (2, 4, 8, 16, 32, 64, 128)
_TOP_KS = (1, 2, 4, 8)

REPLICATE_TESTS = (
    [(8, 1, 1)]
    + [(8, centers, k) for k in _TOP_KS for centers in _SMALL_CENTERS]
    + [(16384, centers, k) for k in _TOP_KS for centers in _LARGE_CENTERS]
)
| |
|
| |
|
@pytest.mark.gpu
@pytest.mark.parametrize(("tokens", "num_centers", "top_k"), REPLICATE_TESTS)
def test_replicate(tokens: int, num_centers: int, top_k: int):
    """Compare ops.replicate against a NumPy reference implementation."""
    # Randomly assign every token to a center, then derive the inclusive
    # cumulative bin boundaries the replicate kernel consumes.
    assignments = torch.randint(0, num_centers, (tokens,)).cuda().int()
    counts = ops.histogram(assignments, num_centers)
    bins = promote_scalar(ops.inclusive_cumsum(counts, 0))
    weights = torch.randn(top_k, num_centers).cuda().half()

    def reference_replicate(values: torch.Tensor, boundaries: torch.Tensor, width: int):
        # Expand each per-center value across its bin [previous end, end).
        values_np = values.cpu().numpy()
        ends = boundaries.cpu().numpy()
        result = np.zeros((values_np.shape[0], width))
        for row in range(values_np.shape[0]):
            cursor = 0
            for center, end in enumerate(ends):
                while cursor < end:
                    result[row, cursor] = values_np[row, center]
                    cursor += 1
        return torch.from_numpy(result).cuda().half()

    got = ops.replicate(weights, bins, tokens)
    want = reference_replicate(weights, bins, tokens)
    assert torch.all(torch.eq(got, want))
| |
|
| |
|
@pytest.mark.gpu
@pytest.mark.parametrize(("tokens", "num_centers", "top_k"), REPLICATE_TESTS)
def test_replicate_backward(tokens: int, num_centers: int, top_k: int):
    """replicate_backward should reduce (sum) gradients within each bin.

    The forward pass copies each center's weight across its entire bin, so
    feeding that output back as the gradient must yield weight * bin size.
    """
    assignments = torch.randint(0, num_centers, (tokens,)).cuda().int()
    counts = ops.histogram(assignments, num_centers)
    bins = promote_scalar(ops.inclusive_cumsum(counts, 0))
    weights = torch.randn(top_k, num_centers).cuda().half()

    grad = ops.replicate(weights, bins, tokens)

    got = torch.empty_like(weights)
    backend.replicate_backward(grad, bins, got)
    want = weights * counts.view([1, num_centers])

    # fp16 accumulation order differs from the analytic product; compare
    # with a small relative tolerance rather than exact equality.
    assert torch.allclose(got, want, rtol=1e-2)
| |
|