| | #include <torch/library.h> |
| |
|
| | #include "registration.h" |
| | #include "torch_binding.h" |
| |
|
| | #include "new_cumsum.h" |
| | #include "new_histogram.h" |
| | #include "new_indices.h" |
| | #include "new_replicate.h" |
| | #include "new_sort.h" |
| |
|
| | #include "grouped_gemm/grouped_gemm.h" |
| |
|
| | |
| | torch::Tensor exclusive_cumsum_wrapper(torch::Tensor x, int64_t dim, torch::Tensor out) { |
| | megablocks::exclusive_cumsum(x, dim, out); |
| | return out; |
| | } |
| |
|
| | |
| | torch::Tensor inclusive_cumsum_wrapper(torch::Tensor x, int64_t dim, torch::Tensor out) { |
| | megablocks::inclusive_cumsum(x, dim, out); |
| | return out; |
| | } |
| |
|
| | |
| | torch::Tensor histogram_wrapper(torch::Tensor x, int64_t num_bins) { |
| | return megablocks::histogram(x, num_bins); |
| | } |
| |
|
| | |
| | |
| | |
| | |
| | |
| | torch::Tensor indices_wrapper(torch::Tensor padded_bins, |
| | int64_t block_size, |
| | int64_t output_block_rows, |
| | int64_t output_block_columns, |
| | torch::Tensor out) { |
| | megablocks::indices(padded_bins, block_size, output_block_rows, output_block_columns, out); |
| | return out; |
| | } |
| |
|
| |
|
| |
|
| | |
| | |
| | |
| | |
| | torch::Tensor replicate_forward_wrapper(torch::Tensor x, torch::Tensor bins, torch::Tensor out) { |
| | megablocks::replicate_forward(x, bins, out); |
| | return out; |
| | } |
| |
|
| | |
| | |
| | |
| | |
| | torch::Tensor replicate_backward_wrapper(torch::Tensor grad, torch::Tensor bins, torch::Tensor out) { |
| | megablocks::replicate_backward(grad, bins, out); |
| | return out; |
| | } |
| |
|
| | |
| | |
| | |
| | |
| | |
| | torch::Tensor sort_wrapper(torch::Tensor x, int64_t end_bit, torch::Tensor x_out, torch::Tensor iota_out) { |
| | megablocks::sort(x, end_bit, x_out, iota_out); |
| | return x_out; |
| | } |
| |
|
| | |
| | torch::Tensor gmm(torch::Tensor a, torch::Tensor b, torch::Tensor c, torch::Tensor batch_sizes, bool trans_a, bool trans_b) { |
| | grouped_gemm::GroupedGemm(a, b, c, batch_sizes, trans_a, trans_b); |
| | return c; |
| | } |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { |
| | ops.def("exclusive_cumsum(Tensor x, int dim, Tensor(a!) out) -> Tensor(a!)"); |
| | ops.impl("exclusive_cumsum", torch::kCUDA, &exclusive_cumsum_wrapper); |
| |
|
| | ops.def("inclusive_cumsum(Tensor x, int dim, Tensor(a!) out) -> Tensor(a!)"); |
| | ops.impl("inclusive_cumsum", torch::kCUDA, &inclusive_cumsum_wrapper); |
| |
|
| | ops.def("histogram(Tensor x, int num_bins) -> Tensor"); |
| | ops.impl("histogram", torch::kCUDA, &histogram_wrapper); |
| |
|
| | ops.def("indices(Tensor padded_bins, int block_size, int output_block_rows, int output_block_columns, Tensor(a!) out) -> Tensor(a!)"); |
| | ops.impl("indices", torch::kCUDA, &indices_wrapper); |
| |
|
| | ops.def("replicate_forward(Tensor x, Tensor bins, Tensor(a!) out) -> Tensor(a!)"); |
| | ops.impl("replicate_forward", torch::kCUDA, &replicate_forward_wrapper); |
| |
|
| | ops.def("replicate_backward(Tensor grad, Tensor bins, Tensor(a!) out) -> Tensor(a!)"); |
| | ops.impl("replicate_backward", torch::kCUDA, &replicate_backward_wrapper); |
| | |
| | ops.def("sort(Tensor x, int end_bit, Tensor x_out, Tensor iota_out) -> Tensor(x_out)"); |
| | ops.impl("sort", torch::kCUDA, &sort_wrapper); |
| |
|
| | |
| | ops.def("gmm(Tensor (a!) a, Tensor (b!) b, Tensor(c!) c, Tensor batch_sizes, bool trans_a, bool trans_b) -> Tensor(c!)"); |
| | ops.impl("gmm", torch::kCUDA, &gmm); |
| | } |
| |
|
// Invokes the REGISTER_EXTENSION macro for this extension's name.
// NOTE(review): the macro is not defined in this file (presumably it comes from
// registration.h and emits the Python module init entry point) — confirm there.
| | REGISTER_EXTENSION(TORCH_EXTENSION_NAME) |