| | |
| | |
| | import glob |
| | import itertools |
| | import os |
| | import subprocess |
| |
|
| | import jinja2 |
| |
|
# Preamble written at the top of every generated .cu file. It opens the
# MARLIN_NAMESPACE_NAME namespace; the generator appends the closing brace
# after the last template instantiation.
FILE_HEAD = "\n".join((
    "// auto generated by generate.py",
    "// clang-format off",
    "",
    '#include "kernel.h"',
    '#include "marlin_template.h"',
    "",
    "namespace MARLIN_NAMESPACE_NAME {",
))
| |
|
# Jinja2 source for one explicit instantiation of the Marlin kernel template.
# The entries below are the template arguments, in the order they are passed
# to Marlin<...>; the two boolean flags are rendered as C++ true/false.
TEMPLATE = (
    "template __global__ void Marlin<"
    + ", ".join((
        "{{scalar_t}}",
        "{{w_type_id}}",
        "{{threads}}",
        "{{thread_m_blocks}}",
        "{{thread_n_blocks}}",
        "{{thread_k_blocks}}",
        "{{'true' if m_block_size_8 else 'false'}}",
        "{{stages}}",
        "{{group_blocks}}",
        "{{'true' if is_zp_float else 'false'}}",
    ))
    + ">( MARLIN_KERNEL_PARAMS );"
)
| |
|
| | |
| | |
# Weight scalar types to generate kernels for. Each name is rendered as
# "<name>.id()" in the template arguments, and the part after "vllm::" is
# lower-cased into the output file name.
SCALAR_TYPES = [
    "vllm::kU4",
    "vllm::kU4B8",
    "vllm::kU8B128",
    "vllm::kFE4M3fn",
    "vllm::kFE2M1f",
]

# Each entry is read by the generator as (k, n, threads): k // 16 becomes
# thread_k_blocks, n // 16 becomes thread_n_blocks, and threads is passed
# through unchanged.
THREAD_CONFIGS = [
    (128, 128, 256),
    (64, 256, 256),
    (64, 128, 128),
    (128, 64, 128),
]

# 0.5 is a marker value: it is rendered as thread_m_blocks = 1 with
# m_block_size_8 = true. Every other value is rendered as-is with
# m_block_size_8 = false.
THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]

# Candidate group_blocks template arguments. The generator filters them per
# scalar type: 0 is kept only for kU4B8/kU8B128, 1 only for kFE2M1f, and
# kFE4M3fn keeps only -1 and 8.
GROUP_BLOCKS = [0, 1, -1, 2, 4, 8]

# Activation dtypes; "fp16" maps to the C type half, anything else to
# nv_bfloat16.
DTYPES = ["fp16", "bf16"]
| |
|
| |
|
def remove_old_kernels():
    """Delete every ``kernel_*.cu`` file that sits next to this script.

    Only files directly inside this script's directory whose names match
    ``kernel_*.cu`` are removed; nothing else is touched and subdirectories
    are not searched. Returns ``None``.

    Raises:
        OSError: if a matching file exists but cannot be removed (for
            example, insufficient permissions). A file that has already
            disappeared is ignored.
    """
    # Resolve to an absolute directory first: if __file__ were a bare relative
    # name, dirname() would be "" and a pattern built as dir + "/kernel_*.cu"
    # would search the filesystem root instead of this directory.
    script_dir = os.path.dirname(os.path.abspath(__file__))

    # Escape the directory part so characters such as "[" in the checkout path
    # are matched literally rather than interpreted as glob syntax.
    pattern = os.path.join(glob.escape(script_dir), "kernel_*.cu")

    for filename in glob.glob(pattern):
        try:
            # In-process removal: no dependency on an external `rm` binary.
            os.remove(filename)
        except FileNotFoundError:
            # Keep the tolerance of `rm -f`: a file that vanished between the
            # glob and the removal is not an error.
            pass
| |
|
| |
|
def _is_supported_config(scalar_type, group_blocks, m_blocks, thread_configs):
    """Return True when this parameter combination should get a kernel."""
    # group_blocks == 0 is only generated for these two weight types.
    if group_blocks == 0 and scalar_type not in ("vllm::kU4B8",
                                                 "vllm::kU8B128"):
        return False

    # For 256-thread configs keep a single tile shape per m_blocks range:
    # first element 128 when m_blocks <= 1, 64 when m_blocks > 1.
    if thread_configs[2] == 256:
        required_k = 128 if m_blocks <= 1 else 64
        if thread_configs[0] != required_k:
            return False

    # kFE4M3fn is only generated for group_blocks of -1 or 8.
    if scalar_type == "vllm::kFE4M3fn" and group_blocks not in (-1, 8):
        return False

    # group_blocks == 1 is generated for kFE2M1f only, and kFE2M1f is
    # generated for group_blocks == 1 only.
    return (scalar_type == "vllm::kFE2M1f") == (group_blocks == 1)


def generate_new_kernels():
    """Write one ``kernel_<dtype>_<type>.cu`` file per scalar type and dtype.

    For every pair from ``SCALAR_TYPES`` x ``DTYPES`` this renders ``TEMPLATE``
    once for each supported combination of ``GROUP_BLOCKS``,
    ``THREAD_M_BLOCKS`` and ``THREAD_CONFIGS``. The file is written into this
    script's directory and overwrites any existing file of the same name.

    File layout: ``FILE_HEAD``, a blank line, the instantiations separated by
    blank lines, then a closing ``}`` for the namespace that ``FILE_HEAD``
    opens. A file is written even when no combination passes the filters.

    Returns ``None``. Errors from ``open``/``write`` propagate to the caller.
    """
    # The template text never changes, so compile it once instead of once per
    # rendered instantiation.
    template = jinja2.Template(TEMPLATE)
    out_dir = os.path.dirname(os.path.abspath(__file__))
    # Length of the namespace prefix stripped from scalar type names when
    # building file names ("vllm::kU4B8" -> "ku4b8").
    prefix_len = len("vllm::")

    for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES):
        # Depends only on dtype, so compute it once per output file.
        c_dtype = "half" if dtype == "fp16" else "nv_bfloat16"
        instantiations = []

        for group_blocks, m_blocks, thread_configs in itertools.product(
                GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS):
            if not _is_supported_config(scalar_type, group_blocks, m_blocks,
                                        thread_configs):
                continue

            # The is_zp_float=true variant is emitted in addition to the
            # default one only for fp16 + kU4 + group_blocks == 4.
            is_zp_float_list = [False]
            if (dtype == "fp16" and scalar_type == "vllm::kU4"
                    and group_blocks == 4):
                is_zp_float_list.append(True)

            for is_zp_float in is_zp_float_list:
                instantiations.append(
                    template.render(
                        scalar_t=c_dtype,
                        w_type_id=scalar_type + ".id()",
                        threads=thread_configs[2],
                        # The 0.5 marker is emitted as 1 block; the
                        # m_block_size_8 flag below carries the distinction.
                        thread_m_blocks=max(m_blocks, 1),
                        thread_n_blocks=thread_configs[1] // 16,
                        thread_k_blocks=thread_configs[0] // 16,
                        m_block_size_8=m_blocks == 0.5,
                        # Emitted verbatim as an identifier; presumably
                        # defined by the included headers -- TODO confirm.
                        stages="pipe_stages",
                        group_blocks=group_blocks,
                        is_zp_float=is_zp_float,
                    ))

        file_content = (FILE_HEAD + "\n\n" + "\n\n".join(instantiations) +
                        "\n\n}\n")
        filename = f"kernel_{dtype}_{scalar_type[prefix_len:].lower()}.cu"

        with open(os.path.join(out_dir, filename), "w") as f:
            f.write(file_content)
| |
|
| |
|
if __name__ == "__main__":
    # Order matters: clear the previously generated kernel_*.cu files first,
    # then write the fresh set.
    for step in (remove_old_kernels, generate_new_kernels):
        step()
| |
|