| |
| |
|
|
| import numpy as np |
| import pytest |
| import torch |
|
|
try:
    # Compiled-kernel bindings; importing them verifies the megablocks
    # extension was actually built for this installation.
    from megablocks._ops import ops as backend
except ModuleNotFoundError as e:
    # NOTE(review): the message names 'megablocks_ops' while the import above
    # targets 'megablocks._ops' — presumably the historical name of the
    # compiled extension; confirm the mismatch is intentional.
    raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
|
|
| from megablocks import ops |
|
|
|
|
def promote_scalar(x: torch.Tensor) -> torch.Tensor:
    """Return ``x`` as a one-element tensor if it is 0-d, else ``x`` unchanged."""
    if x.dim() == 0:
        return x.view(1)
    return x
|
|
|
|
# (tokens, num_centers, top_k) combinations exercised by the replicate tests:
# one degenerate single-center case, then the full cross product of small and
# large token counts with power-of-two center counts and top_k values.
REPLICATE_TESTS = (
    [(8, 1, 1)]
    + [(8, num_centers, top_k)
       for top_k in (1, 2, 4, 8)
       for num_centers in (2, 4, 8)]
    + [(16384, num_centers, top_k)
       for top_k in (1, 2, 4, 8)
       for num_centers in (2, 4, 8, 16, 32, 64, 128)]
)
|
|
|
|
@pytest.mark.gpu
@pytest.mark.parametrize(("tokens", "num_centers", "top_k"), REPLICATE_TESTS)
def test_replicate(tokens: int, num_centers: int, top_k: int):
    """Compare ops.replicate against a pure-numpy reference expansion."""
    # Randomly assign tokens to centers and build the inclusive cumulative
    # per-center token counts ("bins") that replicate consumes.
    assignments = torch.randint(0, num_centers, (tokens,)).cuda().int()
    counts = ops.histogram(assignments, num_centers)
    bins = promote_scalar(ops.inclusive_cumsum(counts, 0))
    center_weights = torch.randn(top_k, num_centers).cuda().half()

    def reference_replicate(values: torch.Tensor, boundaries: torch.Tensor, width: int):
        """Expand each per-center value across its token span, on the CPU."""
        values_np = values.cpu().numpy()
        boundaries_np = boundaries.cpu().numpy()
        expanded = np.zeros((values_np.shape[0], width))
        for row in range(values_np.shape[0]):
            cursor = 0
            # boundaries are inclusive cumsums, so the cursor carries over
            # from one center's span straight into the next.
            for center, limit in enumerate(boundaries_np):
                fill = values_np[row, center]
                while cursor < limit:
                    expanded[row, cursor] = fill
                    cursor += 1
        return torch.from_numpy(expanded).cuda().half()

    out = ops.replicate(center_weights, bins, tokens)
    expected_out = reference_replicate(center_weights, bins, tokens)
    assert torch.all(torch.eq(out, expected_out))
|
|
|
|
@pytest.mark.gpu
@pytest.mark.parametrize(("tokens", "num_centers", "top_k"), REPLICATE_TESTS)
def test_replicate_backward(tokens: int, num_centers: int, top_k: int):
    """Check replicate_backward sums replicated gradients per center."""
    assignments = torch.randint(0, num_centers, (tokens,)).cuda().int()
    counts = ops.histogram(assignments, num_centers)
    bins = promote_scalar(ops.inclusive_cumsum(counts, 0))
    center_weights = torch.randn(top_k, num_centers).cuda().half()

    # Use the forward replicate output as the incoming gradient: every entry
    # in a center's span equals that center's weight.
    grad = ops.replicate(center_weights, bins, tokens)

    out = torch.empty_like(center_weights)
    backend.replicate_backward(grad, bins, out)

    # Backward reduces each center's span by summation, so the expected
    # result is weight * tokens_per_center.
    expected_out = center_weights * counts.view([1, num_centers])

    # Loose relative tolerance: fp16 accumulation over up to 16384 tokens.
    assert torch.allclose(out, expected_out, rtol=1e-2)
|
|