| | #include <torch/library.h> |
| |
|
| | #include "registration.h" |
| | #include "torch_binding.h" |
| |
|
| | TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { |
| | ops.def("get_mla_metadata(Tensor! seqlens_k, int num_heads_per_head_k, int num_heads_k) -> Tensor[]"); |
| | ops.impl("get_mla_metadata", torch::kCUDA, &get_mla_metadata); |
| |
|
| | |
| | ops.def("mha_fwd_kvcache_mla(Tensor! q, Tensor! kcache, Tensor? vcache_, int head_size_v, Tensor! seqlens_k, Tensor! block_table, float softmax_scale, bool is_causal_, Tensor! tile_scheduler_metadata, Tensor! num_splits) -> Tensor[]"); |
| | ops.impl("mha_fwd_kvcache_mla", torch::kCUDA, &mha_fwd_kvcache_mla); |
| | } |
| |
|
| | REGISTER_EXTENSION(TORCH_EXTENSION_NAME) |
| |
|