from typing import Dict, Optional, Union

import numpy as np
import pytest
import torch

from megablocks import ops
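
# Test cases are (num_elements, dtype, max_val); max_val=None draws values
# from the full positive range of the dtype.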
SORT_TESTS = [
    (32, torch.int16, None),
    (1024, torch.int16, None),
    (16384, torch.int16, None),
    (32, torch.int32, None),
    (1024, torch.int32, None),
    (16384, torch.int32, None),
    (32, torch.int64, None),
    (1024, torch.int64, None),
    (16384, torch.int64, None),
    (32, torch.int16, 128),
    (1024, torch.int16, 128),
    (16384, torch.int16, 128),
    (32, torch.int32, 128),
    (1024, torch.int32, 128),
    (16384, torch.int32, 128),
    (32, torch.int64, 128),
    (1024, torch.int64, 128),
    (16384, torch.int64, 128),
]


def torch_to_numpy_dtype(dtype: torch.dtype) -> Union[np.int16, np.int32, np.int64]:
    """Map a torch integer dtype to the matching numpy integer type."""
    types: Dict[torch.dtype, Union[np.int16, np.int32, np.int64]] = {
        torch.int16: np.int16,
        torch.int32: np.int32,
        torch.int64: np.int64,
    }
    return types[dtype]


@pytest.mark.gpu
@pytest.mark.parametrize(
    ('n', 'dtype', 'max_val'),
    SORT_TESTS,
)
def test_sort(n: int, dtype: torch.dtype, max_val: Optional[int]):
    if max_val is None:
        max_val = np.iinfo(torch_to_numpy_dtype(dtype)).max
    # Sort on only the bits needed to represent values below max_val.
    end_bit = int(np.ceil(np.log2(max_val)))
    x = torch.randint(0, max_val, (n,)).cuda().to(dtype)

    out, indices = ops.sort(x, end_bit)
    expected_out, expected_indices = torch.sort(x)
    assert torch.all(torch.eq(out, expected_out))
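
    # The indices are not compared directly: when the input contains duplicate
    # values, the two sorts may order the tied indices differently. Instead,
    # scatter the sorted values back through each index set and check that the
    # reconstructed tensors match.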
    data = torch.empty_like(x)
    data.scatter_(0, indices.long(), out)
    expected_data = torch.empty_like(x)
    expected_data.scatter_(0, expected_indices, expected_out)
    assert torch.all(torch.eq(data, expected_data))