// Name of the namespace that encloses the kernel declaration below. A value
// defined before this point (by an includer or a -D compiler flag) is kept;
// otherwise it falls back to `marlin`.
#if !defined(MARLIN_NAMESPACE_NAME)
#define MARLIN_NAMESPACE_NAME marlin
#endif

#include "marlin.cuh"
#include "marlin_dtypes.cuh"
#include "core/scalar_type.hpp"

// Full parameter list of the `Marlin` kernel, factored into one macro so that
// every place spelling out the signature (the declaration in this header and,
// presumably, the definition elsewhere) stays in sync.
// NOTE: comments inside this replacement list must be /* ... */. All
// continuation lines are spliced into one logical line, so a `//` comment
// would comment out everything after it.
#define MARLIN_KERNEL_PARAMS                                                  \
  /* data buffers (roles inferred from names: inputs A/B, output C,          \
     scratch C_tmp) */                                                        \
  const int4* __restrict__ A, const int4* __restrict__ B,                     \
  int4* __restrict__ C, int4* __restrict__ C_tmp,                             \
  /* quantization metadata (inferred from names) */                          \
  const int4* __restrict__ scales_ptr,                                        \
  const uint16_t* __restrict__ scale2_ptr,                                    \
  const int4* __restrict__ zp_ptr,                                            \
  const int* __restrict__ g_idx,                                              \
  /* sizes: group count, problem dims M/N/K, leading dim of A (by name) */   \
  int num_groups, int prob_m, int prob_n, int prob_k, int lda,                \
  /* reduction controls (locks presumably coordinate thread blocks) */       \
  int* locks, bool use_atomic_add, bool use_fp32_reduce,                      \
  /* presumably the shared-memory budget available to the kernel */          \
  int max_shared_mem

namespace MARLIN_NAMESPACE_NAME {

// Declaration of the `Marlin` kernel template. Only the declaration is here;
// the definition and its instantiations are not in this header (presumably a
// separate .cu file; confirm). The runtime parameter list comes from
// MARLIN_KERNEL_PARAMS. Launch geometry, shared-memory requirements and
// stream semantics cannot be determined from this declaration.
//
// Template parameters (meanings inferred from names; TODO confirm against the
// kernel definition):
//   scalar_t         presumably the element/compute type (e.g. half or
//                    bfloat16; see marlin_dtypes.cuh)
//   w_type_id        vllm::ScalarTypeId, presumably of the quantized weights
//   threads          presumably threads per block; should match blockDim
//   thread_m_blocks  per-block tile extent along M, in blocks
//   thread_n_blocks  per-block tile extent along N, in blocks
//   thread_k_blocks  per-block tile extent along K, in blocks
//   m_block_size_8   flag, presumably selecting an 8-row M block
//   stages           presumably depth of the global->shared copy pipeline
//   group_blocks     presumably quantization group size expressed in blocks
//   is_zp_float      presumably whether zero points (zp_ptr) are floating point
template <class scalar_t,
          const vllm::ScalarTypeId w_type_id,
          const int threads,
          const int thread_m_blocks,
          const int thread_n_blocks,
          const int thread_k_blocks,
          const bool m_block_size_8,
          const int stages,
          const int group_blocks,
          const bool is_zp_float>
__global__ void Marlin(MARLIN_KERNEL_PARAMS);

}  // namespace MARLIN_NAMESPACE_NAME
